[Kernel] Refactor FP8 kv-cache with NVIDIA float8_e4m3 support (#4535)
This commit is contained in:
parent
379da6dcb5
commit
c833101740
@ -1,7 +1,7 @@
|
|||||||
import os
|
import os
|
||||||
import zipfile
|
import zipfile
|
||||||
|
|
||||||
MAX_SIZE_MB = 100
|
MAX_SIZE_MB = 150
|
||||||
|
|
||||||
|
|
||||||
def print_top_10_largest_files(zip_file):
|
def print_top_10_largest_files(zip_file):
|
||||||
|
@ -167,7 +167,7 @@ set(VLLM_EXT_SRC
|
|||||||
"csrc/layernorm_kernels.cu"
|
"csrc/layernorm_kernels.cu"
|
||||||
"csrc/quantization/squeezellm/quant_cuda_kernel.cu"
|
"csrc/quantization/squeezellm/quant_cuda_kernel.cu"
|
||||||
"csrc/quantization/gptq/q_gemm.cu"
|
"csrc/quantization/gptq/q_gemm.cu"
|
||||||
"csrc/quantization/fp8/fp8_cuda_kernels.cu"
|
"csrc/quantization/fp8/common.cu"
|
||||||
"csrc/cuda_utils_kernels.cu"
|
"csrc/cuda_utils_kernels.cu"
|
||||||
"csrc/moe_align_block_size_kernels.cu"
|
"csrc/moe_align_block_size_kernels.cu"
|
||||||
"csrc/pybind.cpp")
|
"csrc/pybind.cpp")
|
||||||
|
@ -99,7 +99,7 @@ function (get_torch_gpu_compiler_flags OUT_GPU_FLAGS GPU_LANG)
|
|||||||
"Failed to determine torch nvcc compiler flags")
|
"Failed to determine torch nvcc compiler flags")
|
||||||
|
|
||||||
if (CUDA_VERSION VERSION_GREATER_EQUAL 11.8)
|
if (CUDA_VERSION VERSION_GREATER_EQUAL 11.8)
|
||||||
list(APPEND GPU_FLAGS "-DENABLE_FP8_E5M2")
|
list(APPEND GPU_FLAGS "-DENABLE_FP8")
|
||||||
endif()
|
endif()
|
||||||
if (CUDA_VERSION VERSION_GREATER_EQUAL 12.0)
|
if (CUDA_VERSION VERSION_GREATER_EQUAL 12.0)
|
||||||
list(REMOVE_ITEM GPU_FLAGS
|
list(REMOVE_ITEM GPU_FLAGS
|
||||||
@ -119,7 +119,7 @@ function (get_torch_gpu_compiler_flags OUT_GPU_FLAGS GPU_LANG)
|
|||||||
|
|
||||||
list(APPEND GPU_FLAGS
|
list(APPEND GPU_FLAGS
|
||||||
"-DUSE_ROCM"
|
"-DUSE_ROCM"
|
||||||
"-DENABLE_FP8_E4M3"
|
"-DENABLE_FP8"
|
||||||
"-U__HIP_NO_HALF_CONVERSIONS__"
|
"-U__HIP_NO_HALF_CONVERSIONS__"
|
||||||
"-U__HIP_NO_HALF_OPERATORS__"
|
"-U__HIP_NO_HALF_OPERATORS__"
|
||||||
"-fno-gpu-rdc")
|
"-fno-gpu-rdc")
|
||||||
|
@ -19,21 +19,17 @@
|
|||||||
#include <torch/extension.h>
|
#include <torch/extension.h>
|
||||||
#include <ATen/cuda/CUDAContext.h>
|
#include <ATen/cuda/CUDAContext.h>
|
||||||
#include <c10/cuda/CUDAGuard.h>
|
#include <c10/cuda/CUDAGuard.h>
|
||||||
|
#include <algorithm>
|
||||||
|
|
||||||
#include "attention_dtypes.h"
|
#include "attention_dtypes.h"
|
||||||
#include "attention_utils.cuh"
|
#include "attention_utils.cuh"
|
||||||
|
|
||||||
#if defined(ENABLE_FP8_E5M2)
|
|
||||||
#include "../quantization/fp8_e5m2_kvcache/quant_utils.cuh"
|
|
||||||
#elif defined(ENABLE_FP8_E4M3)
|
|
||||||
#include "../quantization/fp8/amd_detail/quant_utils.cuh"
|
|
||||||
#endif
|
|
||||||
|
|
||||||
#include <algorithm>
|
|
||||||
|
|
||||||
#ifdef USE_ROCM
|
#ifdef USE_ROCM
|
||||||
#include <hip/hip_bf16.h>
|
#include <hip/hip_bf16.h>
|
||||||
|
#include "../quantization/fp8/amd/quant_utils.cuh"
|
||||||
typedef __hip_bfloat16 __nv_bfloat16;
|
typedef __hip_bfloat16 __nv_bfloat16;
|
||||||
|
#else
|
||||||
|
#include "../quantization/fp8/nvidia/quant_utils.cuh"
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
#ifndef USE_ROCM
|
#ifndef USE_ROCM
|
||||||
@ -92,7 +88,7 @@ template<
|
|||||||
int HEAD_SIZE,
|
int HEAD_SIZE,
|
||||||
int BLOCK_SIZE,
|
int BLOCK_SIZE,
|
||||||
int NUM_THREADS,
|
int NUM_THREADS,
|
||||||
bool IS_FP8_KV_CACHE,
|
vllm::Fp8KVCacheDataType KV_DTYPE,
|
||||||
int PARTITION_SIZE = 0> // Zero means no partitioning.
|
int PARTITION_SIZE = 0> // Zero means no partitioning.
|
||||||
__device__ void paged_attention_kernel(
|
__device__ void paged_attention_kernel(
|
||||||
float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions]
|
float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions]
|
||||||
@ -157,9 +153,7 @@ __device__ void paged_attention_kernel(
|
|||||||
constexpr int VEC_SIZE = MAX(16 / (THREAD_GROUP_SIZE * sizeof(scalar_t)), 1);
|
constexpr int VEC_SIZE = MAX(16 / (THREAD_GROUP_SIZE * sizeof(scalar_t)), 1);
|
||||||
using K_vec = typename Vec<scalar_t, VEC_SIZE>::Type;
|
using K_vec = typename Vec<scalar_t, VEC_SIZE>::Type;
|
||||||
using Q_vec = typename Vec<scalar_t, VEC_SIZE>::Type;
|
using Q_vec = typename Vec<scalar_t, VEC_SIZE>::Type;
|
||||||
#if defined(ENABLE_FP8_E5M2) || defined(ENABLE_FP8_E4M3)
|
|
||||||
using Quant_vec = typename Vec<cache_t, VEC_SIZE>::Type;
|
using Quant_vec = typename Vec<cache_t, VEC_SIZE>::Type;
|
||||||
#endif
|
|
||||||
|
|
||||||
constexpr int NUM_ELEMS_PER_THREAD = HEAD_SIZE / THREAD_GROUP_SIZE;
|
constexpr int NUM_ELEMS_PER_THREAD = HEAD_SIZE / THREAD_GROUP_SIZE;
|
||||||
constexpr int NUM_VECS_PER_THREAD = NUM_ELEMS_PER_THREAD / VEC_SIZE;
|
constexpr int NUM_VECS_PER_THREAD = NUM_ELEMS_PER_THREAD / VEC_SIZE;
|
||||||
@ -223,21 +217,14 @@ __device__ void paged_attention_kernel(
|
|||||||
const int vec_idx = thread_group_offset + j * THREAD_GROUP_SIZE;
|
const int vec_idx = thread_group_offset + j * THREAD_GROUP_SIZE;
|
||||||
const int offset1 = (vec_idx * VEC_SIZE) / x;
|
const int offset1 = (vec_idx * VEC_SIZE) / x;
|
||||||
const int offset2 = (vec_idx * VEC_SIZE) % x;
|
const int offset2 = (vec_idx * VEC_SIZE) % x;
|
||||||
if constexpr (IS_FP8_KV_CACHE) {
|
|
||||||
#if defined(ENABLE_FP8_E5M2)
|
if constexpr (KV_DTYPE == Fp8KVCacheDataType::kAuto) {
|
||||||
Quant_vec k_vec_quant = *reinterpret_cast<const Quant_vec*>(k_ptr + offset1 * BLOCK_SIZE * x + offset2);
|
|
||||||
// Vector conversion from Quant_vec to K_vec.
|
|
||||||
k_vecs[j] = fp8_e5m2_unscaled::vec_conversion<K_vec, Quant_vec>(k_vec_quant);
|
|
||||||
#elif defined(ENABLE_FP8_E4M3)
|
|
||||||
Quant_vec k_vec_quant = *reinterpret_cast<const Quant_vec*>(k_ptr + offset1 * BLOCK_SIZE * x + offset2);
|
|
||||||
// Vector conversion from Quant_vec to K_vec. Use scaled_vec_conversion to convert FP8_E4M3 quantized k
|
|
||||||
// cache vec to k vec in higher precision (FP16, BFloat16, etc.)
|
|
||||||
k_vecs[j] = fp8_e4m3::scaled_vec_conversion<K_vec, Quant_vec>(k_vec_quant, kv_scale);
|
|
||||||
#else
|
|
||||||
assert(false);
|
|
||||||
#endif
|
|
||||||
} else {
|
|
||||||
k_vecs[j] = *reinterpret_cast<const K_vec*>(k_ptr + offset1 * BLOCK_SIZE * x + offset2);
|
k_vecs[j] = *reinterpret_cast<const K_vec*>(k_ptr + offset1 * BLOCK_SIZE * x + offset2);
|
||||||
|
} else {
|
||||||
|
// Vector conversion from Quant_vec to K_vec.
|
||||||
|
Quant_vec k_vec_quant = *reinterpret_cast<const Quant_vec*>(
|
||||||
|
k_ptr + offset1 * BLOCK_SIZE * x + offset2);
|
||||||
|
k_vecs[j] = fp8::scaled_convert<K_vec, Quant_vec, KV_DTYPE>(k_vec_quant, kv_scale);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -312,9 +299,7 @@ __device__ void paged_attention_kernel(
|
|||||||
constexpr int V_VEC_SIZE = MIN(16 / sizeof(scalar_t), BLOCK_SIZE);
|
constexpr int V_VEC_SIZE = MIN(16 / sizeof(scalar_t), BLOCK_SIZE);
|
||||||
using V_vec = typename Vec<scalar_t, V_VEC_SIZE>::Type;
|
using V_vec = typename Vec<scalar_t, V_VEC_SIZE>::Type;
|
||||||
using L_vec = typename Vec<scalar_t, V_VEC_SIZE>::Type;
|
using L_vec = typename Vec<scalar_t, V_VEC_SIZE>::Type;
|
||||||
#if defined(ENABLE_FP8_E5M2) || defined(ENABLE_FP8_E4M3)
|
|
||||||
using V_quant_vec = typename Vec<cache_t, V_VEC_SIZE>::Type;
|
using V_quant_vec = typename Vec<cache_t, V_VEC_SIZE>::Type;
|
||||||
#endif
|
|
||||||
using Float_L_vec = typename FloatVec<L_vec>::Type;
|
using Float_L_vec = typename FloatVec<L_vec>::Type;
|
||||||
|
|
||||||
constexpr int NUM_V_VECS_PER_ROW = BLOCK_SIZE / V_VEC_SIZE;
|
constexpr int NUM_V_VECS_PER_ROW = BLOCK_SIZE / V_VEC_SIZE;
|
||||||
@ -348,21 +333,13 @@ __device__ void paged_attention_kernel(
|
|||||||
if (row_idx < HEAD_SIZE) {
|
if (row_idx < HEAD_SIZE) {
|
||||||
const int offset = row_idx * BLOCK_SIZE + physical_block_offset;
|
const int offset = row_idx * BLOCK_SIZE + physical_block_offset;
|
||||||
V_vec v_vec;
|
V_vec v_vec;
|
||||||
if constexpr (IS_FP8_KV_CACHE) {
|
|
||||||
#if defined(ENABLE_FP8_E5M2)
|
if constexpr (KV_DTYPE == Fp8KVCacheDataType::kAuto) {
|
||||||
|
v_vec = *reinterpret_cast<const V_vec*>(v_ptr + offset);
|
||||||
|
} else {
|
||||||
V_quant_vec v_quant_vec = *reinterpret_cast<const V_quant_vec*>(v_ptr + offset);
|
V_quant_vec v_quant_vec = *reinterpret_cast<const V_quant_vec*>(v_ptr + offset);
|
||||||
// Vector conversion from V_quant_vec to V_vec.
|
// Vector conversion from V_quant_vec to V_vec.
|
||||||
v_vec = fp8_e5m2_unscaled::vec_conversion<V_vec, V_quant_vec>(v_quant_vec);
|
v_vec = fp8::scaled_convert<V_vec, V_quant_vec, KV_DTYPE>(v_quant_vec, kv_scale);
|
||||||
#elif defined(ENABLE_FP8_E4M3)
|
|
||||||
V_quant_vec v_quant_vec = *reinterpret_cast<const V_quant_vec*>(v_ptr + offset);
|
|
||||||
// Vector conversion from V_quant_vec to V_vec. Use scaled_vec_conversion to convert
|
|
||||||
// FP8_E4M3 quantized v cache vec to v vec in higher precision (FP16, BFloat16, etc.)
|
|
||||||
v_vec = fp8_e4m3::scaled_vec_conversion<V_vec, V_quant_vec>(v_quant_vec, kv_scale);
|
|
||||||
#else
|
|
||||||
assert(false);
|
|
||||||
#endif
|
|
||||||
} else {
|
|
||||||
v_vec = *reinterpret_cast<const V_vec*>(v_ptr + offset);
|
|
||||||
}
|
}
|
||||||
if (block_idx == num_seq_blocks - 1) {
|
if (block_idx == num_seq_blocks - 1) {
|
||||||
// NOTE(woosuk): When v_vec contains the tokens that are out of the context,
|
// NOTE(woosuk): When v_vec contains the tokens that are out of the context,
|
||||||
@ -448,7 +425,7 @@ template<
|
|||||||
int HEAD_SIZE,
|
int HEAD_SIZE,
|
||||||
int BLOCK_SIZE,
|
int BLOCK_SIZE,
|
||||||
int NUM_THREADS,
|
int NUM_THREADS,
|
||||||
bool IS_FP8_KV_CACHE>
|
vllm::Fp8KVCacheDataType KV_DTYPE>
|
||||||
__global__ void paged_attention_v1_kernel(
|
__global__ void paged_attention_v1_kernel(
|
||||||
scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size]
|
scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size]
|
||||||
const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size]
|
const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size]
|
||||||
@ -464,7 +441,7 @@ __global__ void paged_attention_v1_kernel(
|
|||||||
const int kv_block_stride,
|
const int kv_block_stride,
|
||||||
const int kv_head_stride,
|
const int kv_head_stride,
|
||||||
const float kv_scale) {
|
const float kv_scale) {
|
||||||
paged_attention_kernel<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, IS_FP8_KV_CACHE>(
|
paged_attention_kernel<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, KV_DTYPE>(
|
||||||
/* exp_sums */ nullptr, /* max_logits */ nullptr,
|
/* exp_sums */ nullptr, /* max_logits */ nullptr,
|
||||||
out, q, k_cache, v_cache, num_kv_heads, scale, block_tables, seq_lens,
|
out, q, k_cache, v_cache, num_kv_heads, scale, block_tables, seq_lens,
|
||||||
max_num_blocks_per_seq, alibi_slopes, q_stride, kv_block_stride, kv_head_stride, kv_scale);
|
max_num_blocks_per_seq, alibi_slopes, q_stride, kv_block_stride, kv_head_stride, kv_scale);
|
||||||
@ -477,7 +454,7 @@ template<
|
|||||||
int HEAD_SIZE,
|
int HEAD_SIZE,
|
||||||
int BLOCK_SIZE,
|
int BLOCK_SIZE,
|
||||||
int NUM_THREADS,
|
int NUM_THREADS,
|
||||||
bool IS_FP8_KV_CACHE,
|
vllm::Fp8KVCacheDataType KV_DTYPE,
|
||||||
int PARTITION_SIZE>
|
int PARTITION_SIZE>
|
||||||
__global__ void paged_attention_v2_kernel(
|
__global__ void paged_attention_v2_kernel(
|
||||||
float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions]
|
float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions]
|
||||||
@ -496,7 +473,7 @@ __global__ void paged_attention_v2_kernel(
|
|||||||
const int kv_block_stride,
|
const int kv_block_stride,
|
||||||
const int kv_head_stride,
|
const int kv_head_stride,
|
||||||
const float kv_scale) {
|
const float kv_scale) {
|
||||||
paged_attention_kernel<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, IS_FP8_KV_CACHE, PARTITION_SIZE>(
|
paged_attention_kernel<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, KV_DTYPE, PARTITION_SIZE>(
|
||||||
exp_sums, max_logits, tmp_out, q, k_cache, v_cache, num_kv_heads, scale,
|
exp_sums, max_logits, tmp_out, q, k_cache, v_cache, num_kv_heads, scale,
|
||||||
block_tables, seq_lens, max_num_blocks_per_seq, alibi_slopes,
|
block_tables, seq_lens, max_num_blocks_per_seq, alibi_slopes,
|
||||||
q_stride, kv_block_stride, kv_head_stride, kv_scale);
|
q_stride, kv_block_stride, kv_head_stride, kv_scale);
|
||||||
@ -606,9 +583,9 @@ __global__ void paged_attention_v2_reduce_kernel(
|
|||||||
#define LAUNCH_PAGED_ATTENTION_V1(HEAD_SIZE) \
|
#define LAUNCH_PAGED_ATTENTION_V1(HEAD_SIZE) \
|
||||||
VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize( \
|
VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize( \
|
||||||
((void*)vllm::paged_attention_v1_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, \
|
((void*)vllm::paged_attention_v1_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, \
|
||||||
IS_FP8_KV_CACHE>), shared_mem_size); \
|
KV_DTYPE>), shared_mem_size); \
|
||||||
vllm::paged_attention_v1_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, \
|
vllm::paged_attention_v1_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, \
|
||||||
IS_FP8_KV_CACHE><<<grid, block, shared_mem_size, stream>>>( \
|
KV_DTYPE><<<grid, block, shared_mem_size, stream>>>( \
|
||||||
out_ptr, \
|
out_ptr, \
|
||||||
query_ptr, \
|
query_ptr, \
|
||||||
key_cache_ptr, \
|
key_cache_ptr, \
|
||||||
@ -629,7 +606,7 @@ template<
|
|||||||
typename T,
|
typename T,
|
||||||
typename CACHE_T,
|
typename CACHE_T,
|
||||||
int BLOCK_SIZE,
|
int BLOCK_SIZE,
|
||||||
bool IS_FP8_KV_CACHE,
|
vllm::Fp8KVCacheDataType KV_DTYPE,
|
||||||
int NUM_THREADS = 128>
|
int NUM_THREADS = 128>
|
||||||
void paged_attention_v1_launcher(
|
void paged_attention_v1_launcher(
|
||||||
torch::Tensor& out,
|
torch::Tensor& out,
|
||||||
@ -706,36 +683,36 @@ void paged_attention_v1_launcher(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#define CALL_V1_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE) \
|
#define CALL_V1_LAUNCHER(T, CACHE_T, BLOCK_SIZE, KV_DTYPE) \
|
||||||
paged_attention_v1_launcher<T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE>( \
|
paged_attention_v1_launcher<T, CACHE_T, BLOCK_SIZE, KV_DTYPE>( \
|
||||||
out, \
|
out, \
|
||||||
query, \
|
query, \
|
||||||
key_cache, \
|
key_cache, \
|
||||||
value_cache, \
|
value_cache, \
|
||||||
num_kv_heads, \
|
num_kv_heads, \
|
||||||
scale, \
|
scale, \
|
||||||
block_tables, \
|
block_tables, \
|
||||||
seq_lens, \
|
seq_lens, \
|
||||||
max_seq_len, \
|
max_seq_len, \
|
||||||
alibi_slopes, \
|
alibi_slopes, \
|
||||||
kv_scale);
|
kv_scale);
|
||||||
|
|
||||||
// NOTE(woosuk): To reduce the compilation time, we omitted block sizes
|
// NOTE(woosuk): To reduce the compilation time, we omitted block sizes
|
||||||
// 1, 2, 4, 64, 128, 256.
|
// 1, 2, 4, 64, 128, 256.
|
||||||
#define CALL_V1_LAUNCHER_BLOCK_SIZE(T, CACHE_T, IS_FP8_KV_CACHE) \
|
#define CALL_V1_LAUNCHER_BLOCK_SIZE(T, CACHE_T, KV_DTYPE) \
|
||||||
switch (block_size) { \
|
switch (block_size) { \
|
||||||
case 8: \
|
case 8: \
|
||||||
CALL_V1_LAUNCHER(T, CACHE_T, 8, IS_FP8_KV_CACHE); \
|
CALL_V1_LAUNCHER(T, CACHE_T, 8, KV_DTYPE); \
|
||||||
break; \
|
break; \
|
||||||
case 16: \
|
case 16: \
|
||||||
CALL_V1_LAUNCHER(T, CACHE_T, 16, IS_FP8_KV_CACHE); \
|
CALL_V1_LAUNCHER(T, CACHE_T, 16, KV_DTYPE); \
|
||||||
break; \
|
break; \
|
||||||
case 32: \
|
case 32: \
|
||||||
CALL_V1_LAUNCHER(T, CACHE_T, 32, IS_FP8_KV_CACHE); \
|
CALL_V1_LAUNCHER(T, CACHE_T, 32, KV_DTYPE); \
|
||||||
break; \
|
break; \
|
||||||
default: \
|
default: \
|
||||||
TORCH_CHECK(false, "Unsupported block size: ", block_size); \
|
TORCH_CHECK(false, "Unsupported block size: ", block_size); \
|
||||||
break; \
|
break; \
|
||||||
}
|
}
|
||||||
|
|
||||||
void paged_attention_v1(
|
void paged_attention_v1(
|
||||||
@ -752,65 +729,44 @@ void paged_attention_v1(
|
|||||||
const c10::optional<torch::Tensor>& alibi_slopes,
|
const c10::optional<torch::Tensor>& alibi_slopes,
|
||||||
const std::string& kv_cache_dtype,
|
const std::string& kv_cache_dtype,
|
||||||
float kv_scale) {
|
float kv_scale) {
|
||||||
if (kv_cache_dtype == "auto") {
|
|
||||||
if (query.dtype() == at::ScalarType::Float) {
|
DISPATCH_BY_KV_CACHE_DTYPE(query.dtype(), kv_cache_dtype, CALL_V1_LAUNCHER_BLOCK_SIZE)
|
||||||
CALL_V1_LAUNCHER_BLOCK_SIZE(float, float, false);
|
|
||||||
} else if (query.dtype() == at::ScalarType::Half) {
|
|
||||||
CALL_V1_LAUNCHER_BLOCK_SIZE(uint16_t, uint16_t, false);
|
|
||||||
} else if (query.dtype() == at::ScalarType::BFloat16) {
|
|
||||||
CALL_V1_LAUNCHER_BLOCK_SIZE(__nv_bfloat16, __nv_bfloat16, false);
|
|
||||||
} else {
|
|
||||||
TORCH_CHECK(false, "Unsupported data type: ", query.dtype());
|
|
||||||
}
|
|
||||||
} else if (kv_cache_dtype == "fp8") {
|
|
||||||
if (query.dtype() == at::ScalarType::Float) {
|
|
||||||
CALL_V1_LAUNCHER_BLOCK_SIZE(float, uint8_t, true);
|
|
||||||
} else if (query.dtype() == at::ScalarType::Half) {
|
|
||||||
CALL_V1_LAUNCHER_BLOCK_SIZE(uint16_t, uint8_t, true);
|
|
||||||
} else if (query.dtype() == at::ScalarType::BFloat16) {
|
|
||||||
CALL_V1_LAUNCHER_BLOCK_SIZE(__nv_bfloat16, uint8_t, true);
|
|
||||||
} else {
|
|
||||||
TORCH_CHECK(false, "Unsupported data type: ", query.dtype());
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
TORCH_CHECK(false, "Unsupported data type of kv cache: ", kv_cache_dtype);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#define LAUNCH_PAGED_ATTENTION_V2(HEAD_SIZE) \
|
#define LAUNCH_PAGED_ATTENTION_V2(HEAD_SIZE) \
|
||||||
vllm::paged_attention_v2_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, \
|
vllm::paged_attention_v2_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, \
|
||||||
IS_FP8_KV_CACHE, PARTITION_SIZE> \
|
KV_DTYPE, PARTITION_SIZE> \
|
||||||
<<<grid, block, shared_mem_size, stream>>>( \
|
<<<grid, block, shared_mem_size, stream>>>( \
|
||||||
exp_sums_ptr, \
|
exp_sums_ptr, \
|
||||||
max_logits_ptr, \
|
max_logits_ptr, \
|
||||||
tmp_out_ptr, \
|
tmp_out_ptr, \
|
||||||
query_ptr, \
|
query_ptr, \
|
||||||
key_cache_ptr, \
|
key_cache_ptr, \
|
||||||
value_cache_ptr, \
|
value_cache_ptr, \
|
||||||
num_kv_heads, \
|
num_kv_heads, \
|
||||||
scale, \
|
scale, \
|
||||||
block_tables_ptr, \
|
block_tables_ptr, \
|
||||||
seq_lens_ptr, \
|
seq_lens_ptr, \
|
||||||
max_num_blocks_per_seq, \
|
max_num_blocks_per_seq, \
|
||||||
alibi_slopes_ptr, \
|
alibi_slopes_ptr, \
|
||||||
q_stride, \
|
q_stride, \
|
||||||
kv_block_stride, \
|
kv_block_stride, \
|
||||||
kv_head_stride, \
|
kv_head_stride, \
|
||||||
kv_scale); \
|
kv_scale); \
|
||||||
vllm::paged_attention_v2_reduce_kernel<T, HEAD_SIZE, NUM_THREADS, PARTITION_SIZE> \
|
vllm::paged_attention_v2_reduce_kernel<T, HEAD_SIZE, NUM_THREADS, PARTITION_SIZE> \
|
||||||
<<<reduce_grid, block, reduce_shared_mem_size, stream>>>( \
|
<<<reduce_grid, block, reduce_shared_mem_size, stream>>>( \
|
||||||
out_ptr, \
|
out_ptr, \
|
||||||
exp_sums_ptr, \
|
exp_sums_ptr, \
|
||||||
max_logits_ptr, \
|
max_logits_ptr, \
|
||||||
tmp_out_ptr, \
|
tmp_out_ptr, \
|
||||||
seq_lens_ptr, \
|
seq_lens_ptr, \
|
||||||
max_num_partitions);
|
max_num_partitions);
|
||||||
|
|
||||||
template<
|
template<
|
||||||
typename T,
|
typename T,
|
||||||
typename CACHE_T,
|
typename CACHE_T,
|
||||||
int BLOCK_SIZE,
|
int BLOCK_SIZE,
|
||||||
bool IS_FP8_KV_CACHE,
|
vllm::Fp8KVCacheDataType KV_DTYPE,
|
||||||
int NUM_THREADS = 128,
|
int NUM_THREADS = 128,
|
||||||
int PARTITION_SIZE = 512>
|
int PARTITION_SIZE = 512>
|
||||||
void paged_attention_v2_launcher(
|
void paged_attention_v2_launcher(
|
||||||
@ -897,39 +853,39 @@ void paged_attention_v2_launcher(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#define CALL_V2_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE) \
|
#define CALL_V2_LAUNCHER(T, CACHE_T, BLOCK_SIZE, KV_DTYPE) \
|
||||||
paged_attention_v2_launcher<T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE>( \
|
paged_attention_v2_launcher<T, CACHE_T, BLOCK_SIZE, KV_DTYPE>( \
|
||||||
out, \
|
out, \
|
||||||
exp_sums, \
|
exp_sums, \
|
||||||
max_logits, \
|
max_logits, \
|
||||||
tmp_out, \
|
tmp_out, \
|
||||||
query, \
|
query, \
|
||||||
key_cache, \
|
key_cache, \
|
||||||
value_cache, \
|
value_cache, \
|
||||||
num_kv_heads, \
|
num_kv_heads, \
|
||||||
scale, \
|
scale, \
|
||||||
block_tables, \
|
block_tables, \
|
||||||
seq_lens, \
|
seq_lens, \
|
||||||
max_seq_len, \
|
max_seq_len, \
|
||||||
alibi_slopes, \
|
alibi_slopes, \
|
||||||
kv_scale);
|
kv_scale);
|
||||||
|
|
||||||
// NOTE(woosuk): To reduce the compilation time, we omitted block sizes
|
// NOTE(woosuk): To reduce the compilation time, we omitted block sizes
|
||||||
// 1, 2, 4, 64, 128, 256.
|
// 1, 2, 4, 64, 128, 256.
|
||||||
#define CALL_V2_LAUNCHER_BLOCK_SIZE(T, CACHE_T, IS_FP8_KV_CACHE) \
|
#define CALL_V2_LAUNCHER_BLOCK_SIZE(T, CACHE_T, KV_DTYPE) \
|
||||||
switch (block_size) { \
|
switch (block_size) { \
|
||||||
case 8: \
|
case 8: \
|
||||||
CALL_V2_LAUNCHER(T, CACHE_T, 8, IS_FP8_KV_CACHE); \
|
CALL_V2_LAUNCHER(T, CACHE_T, 8, KV_DTYPE); \
|
||||||
break; \
|
break; \
|
||||||
case 16: \
|
case 16: \
|
||||||
CALL_V2_LAUNCHER(T, CACHE_T, 16, IS_FP8_KV_CACHE); \
|
CALL_V2_LAUNCHER(T, CACHE_T, 16, KV_DTYPE); \
|
||||||
break; \
|
break; \
|
||||||
case 32: \
|
case 32: \
|
||||||
CALL_V2_LAUNCHER(T, CACHE_T, 32, IS_FP8_KV_CACHE); \
|
CALL_V2_LAUNCHER(T, CACHE_T, 32, KV_DTYPE); \
|
||||||
break; \
|
break; \
|
||||||
default: \
|
default: \
|
||||||
TORCH_CHECK(false, "Unsupported block size: ", block_size); \
|
TORCH_CHECK(false, "Unsupported block size: ", block_size); \
|
||||||
break; \
|
break; \
|
||||||
}
|
}
|
||||||
|
|
||||||
void paged_attention_v2(
|
void paged_attention_v2(
|
||||||
@ -949,29 +905,7 @@ void paged_attention_v2(
|
|||||||
const c10::optional<torch::Tensor>& alibi_slopes,
|
const c10::optional<torch::Tensor>& alibi_slopes,
|
||||||
const std::string& kv_cache_dtype,
|
const std::string& kv_cache_dtype,
|
||||||
float kv_scale) {
|
float kv_scale) {
|
||||||
if (kv_cache_dtype == "auto") {
|
DISPATCH_BY_KV_CACHE_DTYPE(query.dtype(), kv_cache_dtype, CALL_V2_LAUNCHER_BLOCK_SIZE)
|
||||||
if (query.dtype() == at::ScalarType::Float) {
|
|
||||||
CALL_V2_LAUNCHER_BLOCK_SIZE(float, float, false);
|
|
||||||
} else if (query.dtype() == at::ScalarType::Half) {
|
|
||||||
CALL_V2_LAUNCHER_BLOCK_SIZE(uint16_t, uint16_t, false);
|
|
||||||
} else if (query.dtype() == at::ScalarType::BFloat16) {
|
|
||||||
CALL_V2_LAUNCHER_BLOCK_SIZE(__nv_bfloat16, __nv_bfloat16, false);
|
|
||||||
} else {
|
|
||||||
TORCH_CHECK(false, "Unsupported data type: ", query.dtype());
|
|
||||||
}
|
|
||||||
} else if (kv_cache_dtype == "fp8") {
|
|
||||||
if (query.dtype() == at::ScalarType::Float) {
|
|
||||||
CALL_V2_LAUNCHER_BLOCK_SIZE(float, uint8_t, true);
|
|
||||||
} else if (query.dtype() == at::ScalarType::Half) {
|
|
||||||
CALL_V2_LAUNCHER_BLOCK_SIZE(uint16_t, uint8_t, true);
|
|
||||||
} else if (query.dtype() == at::ScalarType::BFloat16) {
|
|
||||||
CALL_V2_LAUNCHER_BLOCK_SIZE(__nv_bfloat16, uint8_t, true);
|
|
||||||
} else {
|
|
||||||
TORCH_CHECK(false, "Unsupported data type: ", query.dtype());
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
TORCH_CHECK(false, "Unsupported data type of kv cache: ", kv_cache_dtype);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#undef WARP_SIZE
|
#undef WARP_SIZE
|
||||||
|
@ -3,14 +3,21 @@
|
|||||||
#include "attention_generic.cuh"
|
#include "attention_generic.cuh"
|
||||||
|
|
||||||
#include <stdint.h>
|
#include <stdint.h>
|
||||||
#ifdef ENABLE_FP8_E5M2
|
#ifdef ENABLE_FP8
|
||||||
|
#ifndef USE_ROCM
|
||||||
#include <cuda_fp8.h>
|
#include <cuda_fp8.h>
|
||||||
#endif
|
#endif // USE_ROCM
|
||||||
|
#endif // ENABLE_FP8
|
||||||
|
|
||||||
namespace vllm {
|
namespace vllm {
|
||||||
#if defined(ENABLE_FP8_E5M2) || defined(ENABLE_FP8_E4M3)
|
|
||||||
// fp8 vector types for quantization of kv cache
|
|
||||||
|
|
||||||
|
enum class Fp8KVCacheDataType {
|
||||||
|
kAuto = 0,
|
||||||
|
kFp8E4M3 = 1,
|
||||||
|
kFp8E5M2 = 2,
|
||||||
|
};
|
||||||
|
|
||||||
|
// fp8 vector types for quantization of kv cache
|
||||||
template<>
|
template<>
|
||||||
struct Vec<uint8_t, 1> {
|
struct Vec<uint8_t, 1> {
|
||||||
using Type = uint8_t;
|
using Type = uint8_t;
|
||||||
@ -30,6 +37,5 @@ template<>
|
|||||||
struct Vec<uint8_t, 8> {
|
struct Vec<uint8_t, 8> {
|
||||||
using Type = uint2;
|
using Type = uint2;
|
||||||
};
|
};
|
||||||
#endif // ENABLE_FP8_E5M2
|
|
||||||
|
|
||||||
} // namespace vllm
|
} // namespace vllm
|
||||||
|
@ -34,5 +34,7 @@ void reshape_and_cache_flash(
|
|||||||
|
|
||||||
// Just for unittest
|
// Just for unittest
|
||||||
void convert_fp8(
|
void convert_fp8(
|
||||||
|
torch::Tensor& dst_cache,
|
||||||
torch::Tensor& src_cache,
|
torch::Tensor& src_cache,
|
||||||
torch::Tensor& dst_cache);
|
const float scale,
|
||||||
|
const std::string& kv_cache_dtype);
|
||||||
|
@ -4,10 +4,11 @@
|
|||||||
|
|
||||||
#include "cuda_compat.h"
|
#include "cuda_compat.h"
|
||||||
#include "dispatch_utils.h"
|
#include "dispatch_utils.h"
|
||||||
#if defined(ENABLE_FP8_E5M2)
|
|
||||||
#include "quantization/fp8_e5m2_kvcache/quant_utils.cuh"
|
#ifdef USE_ROCM
|
||||||
#elif defined(ENABLE_FP8_E4M3)
|
#include "quantization/fp8/amd/quant_utils.cuh"
|
||||||
#include "quantization/fp8/amd_detail/quant_utils.cuh"
|
#else
|
||||||
|
#include "quantization/fp8/nvidia/quant_utils.cuh"
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
@ -149,7 +150,7 @@ void copy_blocks(
|
|||||||
|
|
||||||
namespace vllm {
|
namespace vllm {
|
||||||
|
|
||||||
template<typename scalar_t, typename cache_t, bool is_fp8_kv_cache>
|
template<typename scalar_t, typename cache_t, Fp8KVCacheDataType kv_dt>
|
||||||
__global__ void reshape_and_cache_kernel(
|
__global__ void reshape_and_cache_kernel(
|
||||||
const scalar_t* __restrict__ key, // [num_tokens, num_heads, head_size]
|
const scalar_t* __restrict__ key, // [num_tokens, num_heads, head_size]
|
||||||
const scalar_t* __restrict__ value, // [num_tokens, num_heads, head_size]
|
const scalar_t* __restrict__ value, // [num_tokens, num_heads, head_size]
|
||||||
@ -194,19 +195,12 @@ __global__ void reshape_and_cache_kernel(
|
|||||||
+ block_offset;
|
+ block_offset;
|
||||||
scalar_t tgt_key = key[src_key_idx];
|
scalar_t tgt_key = key[src_key_idx];
|
||||||
scalar_t tgt_value = value[src_value_idx];
|
scalar_t tgt_value = value[src_value_idx];
|
||||||
if constexpr (is_fp8_kv_cache) {
|
if constexpr (kv_dt == Fp8KVCacheDataType::kAuto) {
|
||||||
#if defined(ENABLE_FP8_E5M2)
|
|
||||||
key_cache[tgt_key_idx] = fp8_e5m2_unscaled::vec_conversion<uint8_t, scalar_t>(tgt_key);
|
|
||||||
value_cache[tgt_value_idx] = fp8_e5m2_unscaled::vec_conversion<uint8_t, scalar_t>(tgt_value);
|
|
||||||
#elif defined(ENABLE_FP8_E4M3)
|
|
||||||
key_cache[tgt_key_idx] = fp8_e4m3::scaled_vec_conversion<uint8_t, scalar_t>(tgt_key, kv_scale);
|
|
||||||
value_cache[tgt_value_idx] = fp8_e4m3::scaled_vec_conversion<uint8_t, scalar_t>(tgt_value, kv_scale);
|
|
||||||
#else
|
|
||||||
assert(false);
|
|
||||||
#endif
|
|
||||||
} else {
|
|
||||||
key_cache[tgt_key_idx] = tgt_key;
|
key_cache[tgt_key_idx] = tgt_key;
|
||||||
value_cache[tgt_value_idx] = tgt_value;
|
value_cache[tgt_value_idx] = tgt_value;
|
||||||
|
} else {
|
||||||
|
key_cache[tgt_key_idx] = fp8::scaled_convert<cache_t, scalar_t, kv_dt>(tgt_key, kv_scale);
|
||||||
|
value_cache[tgt_value_idx] = fp8::scaled_convert<cache_t, scalar_t, kv_dt>(tgt_value, kv_scale);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -248,19 +242,22 @@ __global__ void reshape_and_cache_flash_kernel(
|
|||||||
}
|
}
|
||||||
} // namespace vllm
|
} // namespace vllm
|
||||||
|
|
||||||
#define CALL_RESHAPE_AND_CACHE(KV_T, CACHE_T, IS_FP8_KV_CACHE) \
|
// KV_T is the stored data type of kv-cache.
|
||||||
vllm::reshape_and_cache_kernel<KV_T, CACHE_T, IS_FP8_KV_CACHE><<<grid, block, 0, stream>>>( \
|
// CACHE_T is the data type of key and value tensors.
|
||||||
reinterpret_cast<KV_T*>(key.data_ptr()), \
|
// KV_DTYPE is the real data type of kv-cache.
|
||||||
reinterpret_cast<KV_T*>(value.data_ptr()), \
|
#define CALL_RESHAPE_AND_CACHE(KV_T, CACHE_T, KV_DTYPE) \
|
||||||
reinterpret_cast<CACHE_T*>(key_cache.data_ptr()), \
|
vllm::reshape_and_cache_kernel<KV_T, CACHE_T, KV_DTYPE><<<grid, block, 0, stream>>>( \
|
||||||
reinterpret_cast<CACHE_T*>(value_cache.data_ptr()), \
|
reinterpret_cast<KV_T*>(key.data_ptr()), \
|
||||||
slot_mapping.data_ptr<int64_t>(), \
|
reinterpret_cast<KV_T*>(value.data_ptr()), \
|
||||||
key_stride, \
|
reinterpret_cast<CACHE_T*>(key_cache.data_ptr()), \
|
||||||
value_stride, \
|
reinterpret_cast<CACHE_T*>(value_cache.data_ptr()), \
|
||||||
num_heads, \
|
slot_mapping.data_ptr<int64_t>(), \
|
||||||
head_size, \
|
key_stride, \
|
||||||
block_size, \
|
value_stride, \
|
||||||
x, \
|
num_heads, \
|
||||||
|
head_size, \
|
||||||
|
block_size, \
|
||||||
|
x, \
|
||||||
kv_scale);
|
kv_scale);
|
||||||
|
|
||||||
void reshape_and_cache(
|
void reshape_and_cache(
|
||||||
@ -285,25 +282,8 @@ void reshape_and_cache(
|
|||||||
dim3 block(std::min(num_heads * head_size, 512));
|
dim3 block(std::min(num_heads * head_size, 512));
|
||||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(key));
|
const at::cuda::OptionalCUDAGuard device_guard(device_of(key));
|
||||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||||
if (kv_cache_dtype == "auto") {
|
|
||||||
if (key.dtype() == at::ScalarType::Float) {
|
DISPATCH_BY_KV_CACHE_DTYPE(key.dtype(), kv_cache_dtype, CALL_RESHAPE_AND_CACHE)
|
||||||
CALL_RESHAPE_AND_CACHE(float, float, false);
|
|
||||||
} else if (key.dtype() == at::ScalarType::Half) {
|
|
||||||
CALL_RESHAPE_AND_CACHE(uint16_t, uint16_t, false);
|
|
||||||
} else if (key.dtype() == at::ScalarType::BFloat16) {
|
|
||||||
CALL_RESHAPE_AND_CACHE(__nv_bfloat16, __nv_bfloat16, false);
|
|
||||||
}
|
|
||||||
} else if (kv_cache_dtype == "fp8") {
|
|
||||||
if (key.dtype() == at::ScalarType::Float) {
|
|
||||||
CALL_RESHAPE_AND_CACHE(float, uint8_t, true);
|
|
||||||
} else if (key.dtype() == at::ScalarType::Half) {
|
|
||||||
CALL_RESHAPE_AND_CACHE(uint16_t, uint8_t, true);
|
|
||||||
} else if (key.dtype() == at::ScalarType::BFloat16) {
|
|
||||||
CALL_RESHAPE_AND_CACHE(__nv_bfloat16, uint8_t, true);
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
TORCH_CHECK(false, "Unsupported data type of kv cache: ", kv_cache_dtype);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void reshape_and_cache_flash(
|
void reshape_and_cache_flash(
|
||||||
@ -353,35 +333,34 @@ void reshape_and_cache_flash(
|
|||||||
|
|
||||||
namespace vllm {
|
namespace vllm {
|
||||||
|
|
||||||
template<typename Tout, typename Tin>
|
template<typename Tout, typename Tin, Fp8KVCacheDataType kv_dt>
|
||||||
__global__ void convert_fp8_kernel(
|
__global__ void convert_fp8_kernel(
|
||||||
const Tin* __restrict__ src_cache,
|
const Tin* __restrict__ src_cache,
|
||||||
Tout* __restrict__ dst_cache,
|
Tout* __restrict__ dst_cache,
|
||||||
|
const float kv_scale,
|
||||||
const int64_t block_stride) {
|
const int64_t block_stride) {
|
||||||
const int64_t block_idx = blockIdx.x;
|
const int64_t block_idx = blockIdx.x;
|
||||||
for (int i = threadIdx.x; i < block_stride; i += blockDim.x) {
|
for (int i = threadIdx.x; i < block_stride; i += blockDim.x) {
|
||||||
int64_t idx = block_idx * block_stride + i;
|
int64_t idx = block_idx * block_stride + i;
|
||||||
#if defined(ENABLE_FP8_E5M2)
|
dst_cache[idx] = fp8::scaled_convert<Tout, Tin, kv_dt>(src_cache[idx], kv_scale);
|
||||||
dst_cache[idx] = fp8_e5m2_unscaled::vec_conversion<Tout, Tin>(src_cache[idx]);
|
|
||||||
#elif defined(ENABLE_FP8_E4M3)
|
|
||||||
dst_cache[idx] = fp8_e4m3::vec_conversion<Tout, Tin>(src_cache[idx]);
|
|
||||||
#else
|
|
||||||
assert(false);
|
|
||||||
#endif
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace vllm
|
} // namespace vllm
|
||||||
|
|
||||||
#define CALL_CONVERT_FP8(Tout, Tin) \
|
#define CALL_CONVERT_FP8(Tout, Tin, KV_DTYPE) \
|
||||||
vllm::convert_fp8_kernel<Tout, Tin><<<grid, block, 0, stream>>>( \
|
vllm::convert_fp8_kernel<Tout, Tin, KV_DTYPE><<<grid, block, 0, stream>>>( \
|
||||||
reinterpret_cast<Tin*>(src_cache.data_ptr()), \
|
reinterpret_cast<Tin*>(src_cache.data_ptr()), \
|
||||||
reinterpret_cast<Tout*>(dst_cache.data_ptr()), \
|
reinterpret_cast<Tout*>(dst_cache.data_ptr()), \
|
||||||
|
kv_scale, \
|
||||||
block_stride);
|
block_stride);
|
||||||
|
|
||||||
|
// Only for testing.
|
||||||
void convert_fp8(
|
void convert_fp8(
|
||||||
|
torch::Tensor& dst_cache,
|
||||||
torch::Tensor& src_cache,
|
torch::Tensor& src_cache,
|
||||||
torch::Tensor& dst_cache)
|
const float kv_scale,
|
||||||
|
const std::string& kv_cache_dtype)
|
||||||
{
|
{
|
||||||
torch::Device src_device = src_cache.device();
|
torch::Device src_device = src_cache.device();
|
||||||
torch::Device dst_device = dst_cache.device();
|
torch::Device dst_device = dst_cache.device();
|
||||||
@ -399,17 +378,35 @@ void convert_fp8(
|
|||||||
dim3 block(std::min(block_stride, int64_t(512)));
|
dim3 block(std::min(block_stride, int64_t(512)));
|
||||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||||
|
|
||||||
if (src_cache.dtype() == at::ScalarType::Float) {
|
if (kv_cache_dtype == "auto") {
|
||||||
CALL_CONVERT_FP8(uint8_t, float);
|
if (src_cache.dtype() == at::ScalarType::Float) {
|
||||||
} else if (src_cache.dtype() == at::ScalarType::Half) {
|
CALL_CONVERT_FP8(uint8_t, float, vllm::Fp8KVCacheDataType::kAuto);
|
||||||
CALL_CONVERT_FP8(uint8_t, uint16_t);
|
} else if (src_cache.dtype() == at::ScalarType::Half) {
|
||||||
} else if (src_cache.dtype() == at::ScalarType::BFloat16) {
|
CALL_CONVERT_FP8(uint8_t, uint16_t, vllm::Fp8KVCacheDataType::kAuto);
|
||||||
CALL_CONVERT_FP8(uint8_t, __nv_bfloat16);
|
} else if (src_cache.dtype() == at::ScalarType::BFloat16) {
|
||||||
} else if (dst_cache.dtype() == at::ScalarType::Float) {
|
CALL_CONVERT_FP8(uint8_t, __nv_bfloat16, vllm::Fp8KVCacheDataType::kAuto);
|
||||||
CALL_CONVERT_FP8(float, uint8_t);
|
} else if (dst_cache.dtype() == at::ScalarType::Float) {
|
||||||
} else if (dst_cache.dtype() == at::ScalarType::Half) {
|
CALL_CONVERT_FP8(float, uint8_t, vllm::Fp8KVCacheDataType::kAuto);
|
||||||
CALL_CONVERT_FP8(uint16_t, uint8_t);
|
} else if (dst_cache.dtype() == at::ScalarType::Half) {
|
||||||
} else if (dst_cache.dtype() == at::ScalarType::BFloat16) {
|
CALL_CONVERT_FP8(uint16_t, uint8_t, vllm::Fp8KVCacheDataType::kAuto);
|
||||||
CALL_CONVERT_FP8(__nv_bfloat16, uint8_t);
|
} else if (dst_cache.dtype() == at::ScalarType::BFloat16) {
|
||||||
|
CALL_CONVERT_FP8(__nv_bfloat16, uint8_t, vllm::Fp8KVCacheDataType::kAuto);
|
||||||
|
}
|
||||||
|
} else if (kv_cache_dtype == "fp8" || kv_cache_dtype == "fp8_e4m3") {
|
||||||
|
if (src_cache.dtype() == at::ScalarType::Float) {
|
||||||
|
CALL_CONVERT_FP8(uint8_t, float, vllm::Fp8KVCacheDataType::kFp8E4M3);
|
||||||
|
} else if (src_cache.dtype() == at::ScalarType::Half) {
|
||||||
|
CALL_CONVERT_FP8(uint8_t, uint16_t, vllm::Fp8KVCacheDataType::kFp8E4M3);
|
||||||
|
} else if (src_cache.dtype() == at::ScalarType::BFloat16) {
|
||||||
|
CALL_CONVERT_FP8(uint8_t, __nv_bfloat16, vllm::Fp8KVCacheDataType::kFp8E4M3);
|
||||||
|
} else if (dst_cache.dtype() == at::ScalarType::Float) {
|
||||||
|
CALL_CONVERT_FP8(float, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3);
|
||||||
|
} else if (dst_cache.dtype() == at::ScalarType::Half) {
|
||||||
|
CALL_CONVERT_FP8(uint16_t, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3);
|
||||||
|
} else if (dst_cache.dtype() == at::ScalarType::BFloat16) {
|
||||||
|
CALL_CONVERT_FP8(__nv_bfloat16, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
TORCH_CHECK(false, "Unsupported data type: ", kv_cache_dtype);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -5,12 +5,17 @@
|
|||||||
#include <hip/hip_bf16.h>
|
#include <hip/hip_bf16.h>
|
||||||
#include <hip/hip_bfloat16.h>
|
#include <hip/hip_bfloat16.h>
|
||||||
|
|
||||||
|
#include "../../../attention/dtype_fp8.cuh"
|
||||||
#include "../../../attention/dtype_float32.cuh"
|
#include "../../../attention/dtype_float32.cuh"
|
||||||
#include "../../../attention/dtype_bfloat16.cuh"
|
#include "../../../attention/dtype_bfloat16.cuh"
|
||||||
|
|
||||||
namespace vllm
|
namespace vllm
|
||||||
{
|
{
|
||||||
namespace fp8_e4m3 {
|
#ifdef USE_ROCM
|
||||||
|
|
||||||
|
namespace fp8 {
|
||||||
|
#ifdef ENABLE_FP8
|
||||||
|
|
||||||
template <typename Tout, typename Tin>
|
template <typename Tout, typename Tin>
|
||||||
__inline__ __device__ Tout vec_conversion(const Tin& x)
|
__inline__ __device__ Tout vec_conversion(const Tin& x)
|
||||||
{
|
{
|
||||||
@ -512,6 +517,58 @@ __inline__ __device__ float4 scaled_vec_conversion<float4, uint32_t>(const uint3
|
|||||||
float4 res = make_float4(tmp.x.x, tmp.x.y, tmp.y.x, tmp.y.y);
|
float4 res = make_float4(tmp.x.x, tmp.x.y, tmp.y.x, tmp.y.y);
|
||||||
return res;
|
return res;
|
||||||
}
|
}
|
||||||
|
#endif // ENABLE_FP8
|
||||||
|
|
||||||
|
template <typename Tout, typename Tin, Fp8KVCacheDataType kv_dt>
|
||||||
|
__inline__ __device__ Tout convert(const Tin &x) {
|
||||||
|
#ifdef ENABLE_FP8
|
||||||
|
if constexpr (kv_dt == Fp8KVCacheDataType::kFp8E4M3) {
|
||||||
|
return vec_conversion<Tout, Tin>(x);
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
assert(false);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <typename Tout, typename Tin, Fp8KVCacheDataType kv_dt>
|
||||||
|
__inline__ __device__ Tout scaled_convert(const Tin &x, const float scale) {
|
||||||
|
#ifdef ENABLE_FP8
|
||||||
|
if constexpr (kv_dt == Fp8KVCacheDataType::kFp8E4M3) {
|
||||||
|
return scaled_vec_conversion<Tout, Tin>(x, scale);
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
assert(false);
|
||||||
|
}
|
||||||
|
|
||||||
|
// The following macro is used to dispatch the conversion function based on the
|
||||||
|
// data type of the key and value cache. The FN is a macro that calls a function
|
||||||
|
// with template<typename scalar_t, typename cache_t, Fp8KVCacheDataType kv_dt>.
|
||||||
|
#define DISPATCH_BY_KV_CACHE_DTYPE(SRC_DTYPE, KV_DTYPE, FN) \
|
||||||
|
if (KV_DTYPE == "auto") { \
|
||||||
|
if (SRC_DTYPE == at::ScalarType::Float) { \
|
||||||
|
FN(float, float, vllm::Fp8KVCacheDataType::kAuto); \
|
||||||
|
} else if (SRC_DTYPE == at::ScalarType::Half) { \
|
||||||
|
FN(uint16_t, uint16_t, vllm::Fp8KVCacheDataType::kAuto); \
|
||||||
|
} else if (SRC_DTYPE == at::ScalarType::BFloat16) { \
|
||||||
|
FN(__nv_bfloat16, __nv_bfloat16, vllm::Fp8KVCacheDataType::kAuto); \
|
||||||
|
} else { \
|
||||||
|
TORCH_CHECK(false, "Unsupported input type of kv cache: ", SRC_DTYPE); \
|
||||||
|
} \
|
||||||
|
} else { \
|
||||||
|
if (KV_DTYPE == "fp8" || KV_DTYPE == "fp8_e4m3") { \
|
||||||
|
if (SRC_DTYPE == at::ScalarType::Float) { \
|
||||||
|
FN(float, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \
|
||||||
|
} else if (SRC_DTYPE == at::ScalarType::Half) { \
|
||||||
|
FN(uint16_t, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \
|
||||||
|
} else if (SRC_DTYPE == at::ScalarType::BFloat16) { \
|
||||||
|
FN(__nv_bfloat16, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \
|
||||||
|
} else { \
|
||||||
|
TORCH_CHECK(false, "Unsupported input type of kv cache: ", SRC_DTYPE); \
|
||||||
|
} \
|
||||||
|
} else { \
|
||||||
|
TORCH_CHECK(false, "Unsupported data type of kv cache: ", KV_DTYPE); \
|
||||||
|
} \
|
||||||
|
}
|
||||||
|
|
||||||
|
} // fp8
|
||||||
|
#endif // USE_ROCM
|
||||||
} // namespace vllm
|
} // namespace vllm
|
568
csrc/quantization/fp8/nvidia/quant_utils.cuh
Normal file
568
csrc/quantization/fp8/nvidia/quant_utils.cuh
Normal file
@ -0,0 +1,568 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "../../../attention/attention_dtypes.h"
|
||||||
|
#include <assert.h>
|
||||||
|
#include <float.h>
|
||||||
|
#include <stdint.h>
|
||||||
|
#include <type_traits>
|
||||||
|
|
||||||
|
namespace vllm {
|
||||||
|
#ifndef USE_ROCM
|
||||||
|
|
||||||
|
namespace fp8 {
|
||||||
|
#ifdef ENABLE_FP8
|
||||||
|
|
||||||
|
#if 0 // Disable the following code to reduce the binary size.
|
||||||
|
template <typename Tout, typename Tin>
|
||||||
|
__inline__ __device__ Tout
|
||||||
|
vec_conversion(const Tin &x, const __nv_fp8_interpretation_t fp8_type) {
|
||||||
|
return x;
|
||||||
|
}
|
||||||
|
|
||||||
|
// fp8 -> half
|
||||||
|
template <>
|
||||||
|
__inline__ __device__ uint16_t vec_conversion<uint16_t, uint8_t>(
|
||||||
|
const uint8_t &a, const __nv_fp8_interpretation_t fp8_type) {
|
||||||
|
__half_raw res = __nv_cvt_fp8_to_halfraw(a, fp8_type);
|
||||||
|
return res.x;
|
||||||
|
}
|
||||||
|
|
||||||
|
// fp8x2 -> half2
|
||||||
|
template <>
|
||||||
|
__inline__ __device__ uint32_t vec_conversion<uint32_t, uint16_t>(
|
||||||
|
const uint16_t &a, const __nv_fp8_interpretation_t fp8_type) {
|
||||||
|
union {
|
||||||
|
uint16_t u16[2];
|
||||||
|
uint32_t u32;
|
||||||
|
} tmp;
|
||||||
|
__half2_raw res = __nv_cvt_fp8x2_to_halfraw2(a, fp8_type);
|
||||||
|
tmp.u16[0] = res.x;
|
||||||
|
tmp.u16[1] = res.y;
|
||||||
|
return tmp.u32;
|
||||||
|
}
|
||||||
|
|
||||||
|
// fp8x4 -> half2x2
|
||||||
|
template <>
|
||||||
|
__inline__ __device__ uint2 vec_conversion<uint2, uint32_t>(
|
||||||
|
const uint32_t &a, const __nv_fp8_interpretation_t fp8_type) {
|
||||||
|
union {
|
||||||
|
uint2 u32x2;
|
||||||
|
uint32_t u32[2];
|
||||||
|
} tmp;
|
||||||
|
tmp.u32[0] = vec_conversion<uint32_t, uint16_t>((uint16_t)a, fp8_type);
|
||||||
|
tmp.u32[1] =
|
||||||
|
vec_conversion<uint32_t, uint16_t>((uint16_t)(a >> 16U), fp8_type);
|
||||||
|
return tmp.u32x2;
|
||||||
|
}
|
||||||
|
|
||||||
|
// fp8x8 -> half2x4
|
||||||
|
template <>
|
||||||
|
__inline__ __device__ uint4 vec_conversion<uint4, uint2>(
|
||||||
|
const uint2 &a, const __nv_fp8_interpretation_t fp8_type) {
|
||||||
|
union {
|
||||||
|
uint4 u64x2;
|
||||||
|
uint2 u64[2];
|
||||||
|
} tmp;
|
||||||
|
tmp.u64[0] = vec_conversion<uint2, uint32_t>(a.x, fp8_type);
|
||||||
|
tmp.u64[1] = vec_conversion<uint2, uint32_t>(a.y, fp8_type);
|
||||||
|
return tmp.u64x2;
|
||||||
|
}
|
||||||
|
|
||||||
|
// fp8 -> __nv_bfloat16
|
||||||
|
template <>
|
||||||
|
__inline__ __device__ __nv_bfloat16 vec_conversion<__nv_bfloat16, uint8_t>(
|
||||||
|
const uint8_t &a, const __nv_fp8_interpretation_t fp8_type) {
|
||||||
|
// Note there is no direct convert function from fp8 to bf16.
|
||||||
|
// fp8 -> half
|
||||||
|
__half_raw res = __nv_cvt_fp8_to_halfraw(a, fp8_type);
|
||||||
|
// half -> float -> bf16
|
||||||
|
float tmp = half_to_float(res.x);
|
||||||
|
return __float2bfloat16(tmp);
|
||||||
|
}
|
||||||
|
|
||||||
|
// fp8x2 -> __nv_bfloat162
|
||||||
|
template <>
|
||||||
|
__inline__ __device__ __nv_bfloat162 vec_conversion<__nv_bfloat162, uint16_t>(
|
||||||
|
const uint16_t &a, const __nv_fp8_interpretation_t fp8_type) {
|
||||||
|
__nv_bfloat162 res;
|
||||||
|
res.x = vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)a, fp8_type);
|
||||||
|
res.y = vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)(a >> 8U), fp8_type);
|
||||||
|
return res;
|
||||||
|
}
|
||||||
|
|
||||||
|
// fp8x4 -> bf16_4_t
|
||||||
|
template <>
|
||||||
|
__inline__ __device__ bf16_4_t vec_conversion<bf16_4_t, uint32_t>(
|
||||||
|
const uint32_t &a, const __nv_fp8_interpretation_t fp8_type) {
|
||||||
|
bf16_4_t res;
|
||||||
|
res.x = vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)a, fp8_type);
|
||||||
|
res.y =
|
||||||
|
vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)(a >> 16U), fp8_type);
|
||||||
|
return res;
|
||||||
|
}
|
||||||
|
|
||||||
|
// fp8x8 -> bf16_8_t
|
||||||
|
template <>
|
||||||
|
__inline__ __device__ bf16_8_t vec_conversion<bf16_8_t, uint2>(
|
||||||
|
const uint2 &a, const __nv_fp8_interpretation_t fp8_type) {
|
||||||
|
bf16_4_t tmp1, tmp2;
|
||||||
|
tmp1 = vec_conversion<bf16_4_t, uint32_t>(a.x, fp8_type);
|
||||||
|
tmp2 = vec_conversion<bf16_4_t, uint32_t>(a.y, fp8_type);
|
||||||
|
bf16_8_t res;
|
||||||
|
res.x = tmp1.x;
|
||||||
|
res.y = tmp1.y;
|
||||||
|
res.z = tmp2.x;
|
||||||
|
res.w = tmp2.y;
|
||||||
|
return res;
|
||||||
|
}
|
||||||
|
|
||||||
|
// fp8 -> float
|
||||||
|
template <>
|
||||||
|
__inline__ __device__ float
|
||||||
|
vec_conversion<float, uint8_t>(const uint8_t &a,
|
||||||
|
const __nv_fp8_interpretation_t fp8_type) {
|
||||||
|
// fp8 -> half
|
||||||
|
uint16_t tmp = vec_conversion<uint16_t, uint8_t>(a, fp8_type);
|
||||||
|
// half -> float
|
||||||
|
return half_to_float(tmp);
|
||||||
|
}
|
||||||
|
|
||||||
|
// fp8x2 -> float2
|
||||||
|
template <>
|
||||||
|
__inline__ __device__ float2 vec_conversion<float2, uint16_t>(
|
||||||
|
const uint16_t &a, const __nv_fp8_interpretation_t fp8_type) {
|
||||||
|
// fp8x2 -> half2
|
||||||
|
uint32_t tmp = vec_conversion<uint32_t, uint16_t>(a, fp8_type);
|
||||||
|
// half2 -> float2
|
||||||
|
return half2_to_float2(tmp);
|
||||||
|
}
|
||||||
|
|
||||||
|
// fp8x4 -> float4
|
||||||
|
template <>
|
||||||
|
__inline__ __device__ Float4_ vec_conversion<Float4_, uint32_t>(
|
||||||
|
const uint32_t &a, const __nv_fp8_interpretation_t fp8_type) {
|
||||||
|
Float4_ res;
|
||||||
|
res.x = vec_conversion<float2, uint16_t>((uint16_t)a, fp8_type);
|
||||||
|
res.y = vec_conversion<float2, uint16_t>((uint16_t)(a >> 16U), fp8_type);
|
||||||
|
return res;
|
||||||
|
}
|
||||||
|
|
||||||
|
// fp8x8 -> float8
|
||||||
|
template <>
|
||||||
|
__inline__ __device__ Float8_ vec_conversion<Float8_, uint2>(
|
||||||
|
const uint2 &a, const __nv_fp8_interpretation_t fp8_type) {
|
||||||
|
Float4_ tmp1, tmp2;
|
||||||
|
tmp1 = vec_conversion<Float4_, uint32_t>(a.x, fp8_type);
|
||||||
|
tmp2 = vec_conversion<Float4_, uint32_t>(a.y, fp8_type);
|
||||||
|
Float8_ res;
|
||||||
|
res.x = tmp1.x;
|
||||||
|
res.y = tmp1.y;
|
||||||
|
res.z = tmp2.x;
|
||||||
|
res.w = tmp2.y;
|
||||||
|
return res;
|
||||||
|
}
|
||||||
|
|
||||||
|
// half -> fp8
|
||||||
|
template <>
|
||||||
|
__inline__ __device__ uint8_t vec_conversion<uint8_t, uint16_t>(
|
||||||
|
const uint16_t &a, const __nv_fp8_interpretation_t fp8_type) {
|
||||||
|
__half_raw tmp;
|
||||||
|
tmp.x = a;
|
||||||
|
__nv_fp8_storage_t res =
|
||||||
|
__nv_cvt_halfraw_to_fp8(tmp, __NV_SATFINITE, fp8_type);
|
||||||
|
return (uint8_t)res;
|
||||||
|
}
|
||||||
|
|
||||||
|
// bf16 -> fp8
|
||||||
|
template <>
|
||||||
|
__inline__ __device__ uint8_t vec_conversion<uint8_t, __nv_bfloat16>(
|
||||||
|
const __nv_bfloat16 &a, const __nv_fp8_interpretation_t fp8_type) {
|
||||||
|
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
|
||||||
|
assert(false);
|
||||||
|
#else
|
||||||
|
__nv_fp8_storage_t res = __nv_cvt_bfloat16raw_to_fp8(
|
||||||
|
__nv_bfloat16_raw(a), __NV_SATFINITE, fp8_type);
|
||||||
|
return (uint8_t)res;
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
|
||||||
|
// float -> fp8
|
||||||
|
template <>
|
||||||
|
__inline__ __device__ uint8_t vec_conversion<uint8_t, float>(
|
||||||
|
const float &a, const __nv_fp8_interpretation_t fp8_type) {
|
||||||
|
__nv_fp8_storage_t res = __nv_cvt_float_to_fp8(a, __NV_SATFINITE, fp8_type);
|
||||||
|
return (uint8_t)res;
|
||||||
|
}
|
||||||
|
|
||||||
|
// fp8x4 -> float4
|
||||||
|
template <>
|
||||||
|
__inline__ __device__ float4 vec_conversion<float4, uint32_t>(
|
||||||
|
const uint32_t &a, const __nv_fp8_interpretation_t fp8_type) {
|
||||||
|
Float4_ tmp = vec_conversion<Float4_, uint32_t>(a, fp8_type);
|
||||||
|
float4 res = make_float4(tmp.x.x, tmp.x.y, tmp.y.x, tmp.y.y);
|
||||||
|
return res;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
__inline__ __device__ uint32_t vec_conversion<uint32_t, float2>(
|
||||||
|
const float2 &a, const __nv_fp8_interpretation_t fp8_type) {
|
||||||
|
union {
|
||||||
|
half2 float16;
|
||||||
|
uint32_t uint32;
|
||||||
|
};
|
||||||
|
|
||||||
|
float16 = __float22half2_rn(a);
|
||||||
|
return uint32;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
__inline__ __device__ uint2 vec_conversion<uint2, Float4_>(
|
||||||
|
const Float4_ &a, const __nv_fp8_interpretation_t fp8_type) {
|
||||||
|
uint2 b;
|
||||||
|
float2 val;
|
||||||
|
val.x = a.x.x;
|
||||||
|
val.y = a.x.y;
|
||||||
|
b.x = vec_conversion<uint32_t, float2>(val, fp8_type);
|
||||||
|
|
||||||
|
val.x = a.y.x;
|
||||||
|
val.y = a.y.y;
|
||||||
|
b.y = vec_conversion<uint32_t, float2>(val, fp8_type);
|
||||||
|
|
||||||
|
return b;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
__inline__ __device__ float4 vec_conversion<float4, Float4_>(
|
||||||
|
const Float4_ &a, const __nv_fp8_interpretation_t fp8_type) {
|
||||||
|
float4 b;
|
||||||
|
b.x = a.x.x;
|
||||||
|
b.y = a.x.y;
|
||||||
|
b.z = a.y.x;
|
||||||
|
b.w = a.y.y;
|
||||||
|
return b;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
__inline__ __device__ uint4 vec_conversion<uint4, Float8_>(
|
||||||
|
const Float8_ &a, const __nv_fp8_interpretation_t fp8_type) {
|
||||||
|
uint4 b;
|
||||||
|
b.x = vec_conversion<uint32_t, float2>(a.x, fp8_type);
|
||||||
|
b.y = vec_conversion<uint32_t, float2>(a.y, fp8_type);
|
||||||
|
b.z = vec_conversion<uint32_t, float2>(a.z, fp8_type);
|
||||||
|
b.w = vec_conversion<uint32_t, float2>(a.w, fp8_type);
|
||||||
|
return b;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
__inline__ __device__ __nv_bfloat162 vec_conversion<__nv_bfloat162, float2>(
|
||||||
|
const float2 &a, const __nv_fp8_interpretation_t fp8_type) {
|
||||||
|
__nv_bfloat162 b;
|
||||||
|
from_float(b, a);
|
||||||
|
return b;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
__inline__ __device__ bf16_4_t vec_conversion<bf16_4_t, Float4_>(
|
||||||
|
const Float4_ &a, const __nv_fp8_interpretation_t fp8_type) {
|
||||||
|
bf16_4_t b;
|
||||||
|
from_float(b, a);
|
||||||
|
return b;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
__inline__ __device__ bf16_8_t vec_conversion<bf16_8_t, Float8_>(
|
||||||
|
const Float8_ &a, const __nv_fp8_interpretation_t fp8_type) {
|
||||||
|
bf16_8_t b;
|
||||||
|
from_float(b, a);
|
||||||
|
return b;
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
|
/* Scaled and vectorized conversions, for data exchange between high and low
|
||||||
|
precision domains Convention of the scale in API, e.g: FP8_data =
|
||||||
|
Quantization( High_Precision_data / scale ) s.t. Quantize(HP / scale) => FP8
|
||||||
|
Dequant(FP8) * scale => HP
|
||||||
|
*/
|
||||||
|
|
||||||
|
template <typename Tout, typename Tin>
|
||||||
|
__inline__ __device__ Tout scaled_vec_conversion(
|
||||||
|
const Tin &x, const float scale, const __nv_fp8_interpretation_t fp8_type) {
|
||||||
|
return x;
|
||||||
|
}
|
||||||
|
|
||||||
|
// fp8 -> half
|
||||||
|
template <>
|
||||||
|
__inline__ __device__ uint16_t scaled_vec_conversion<uint16_t, uint8_t>(
|
||||||
|
const uint8_t &a, const float scale,
|
||||||
|
const __nv_fp8_interpretation_t fp8_type) {
|
||||||
|
__half_raw tmp = __nv_cvt_fp8_to_halfraw(a, fp8_type);
|
||||||
|
return float_to_half(half_to_float(tmp.x) * scale);
|
||||||
|
}
|
||||||
|
|
||||||
|
// fp8x2 -> half2
|
||||||
|
template <>
|
||||||
|
__inline__ __device__ uint32_t scaled_vec_conversion<uint32_t, uint16_t>(
|
||||||
|
const uint16_t &a, const float scale,
|
||||||
|
const __nv_fp8_interpretation_t fp8_type) {
|
||||||
|
union {
|
||||||
|
uint16_t u16[2];
|
||||||
|
uint32_t u32;
|
||||||
|
} tmp;
|
||||||
|
__half2_raw res = __nv_cvt_fp8x2_to_halfraw2(a, fp8_type);
|
||||||
|
tmp.u16[0] = float_to_half(half_to_float(res.x) * scale);
|
||||||
|
tmp.u16[1] = float_to_half(half_to_float(res.y) * scale);
|
||||||
|
return tmp.u32;
|
||||||
|
}
|
||||||
|
|
||||||
|
// fp8x4 -> half2x2
|
||||||
|
template <>
|
||||||
|
__inline__ __device__ uint2 scaled_vec_conversion<uint2, uint32_t>(
|
||||||
|
const uint32_t &a, const float scale,
|
||||||
|
const __nv_fp8_interpretation_t fp8_type) {
|
||||||
|
union {
|
||||||
|
uint2 u32x2;
|
||||||
|
uint32_t u32[2];
|
||||||
|
} tmp;
|
||||||
|
tmp.u32[0] =
|
||||||
|
scaled_vec_conversion<uint32_t, uint16_t>((uint16_t)a, scale, fp8_type);
|
||||||
|
tmp.u32[1] = scaled_vec_conversion<uint32_t, uint16_t>((uint16_t)(a >> 16U),
|
||||||
|
scale, fp8_type);
|
||||||
|
return tmp.u32x2;
|
||||||
|
}
|
||||||
|
|
||||||
|
// fp8x8 -> half2x4
|
||||||
|
template <>
|
||||||
|
__inline__ __device__ uint4
|
||||||
|
scaled_vec_conversion<uint4, uint2>(const uint2 &a, const float scale,
|
||||||
|
const __nv_fp8_interpretation_t fp8_type) {
|
||||||
|
union {
|
||||||
|
uint4 u64x2;
|
||||||
|
uint2 u64[2];
|
||||||
|
} tmp;
|
||||||
|
tmp.u64[0] = scaled_vec_conversion<uint2, uint32_t>(a.x, scale, fp8_type);
|
||||||
|
tmp.u64[1] = scaled_vec_conversion<uint2, uint32_t>(a.y, scale, fp8_type);
|
||||||
|
return tmp.u64x2;
|
||||||
|
}
|
||||||
|
|
||||||
|
// fp8 -> __nv_bfloat16
|
||||||
|
template <>
|
||||||
|
__inline__ __device__ __nv_bfloat16
|
||||||
|
scaled_vec_conversion<__nv_bfloat16, uint8_t>(
|
||||||
|
const uint8_t &a, const float scale,
|
||||||
|
const __nv_fp8_interpretation_t fp8_type) {
|
||||||
|
// Note there is no direct convert function from fp8 to bf16.
|
||||||
|
// fp8 -> half
|
||||||
|
__half_raw res = __nv_cvt_fp8_to_halfraw(a, fp8_type);
|
||||||
|
// half -> float -> bf16
|
||||||
|
float tmp = half_to_float(res.x);
|
||||||
|
return __float2bfloat16(tmp * scale);
|
||||||
|
}
|
||||||
|
|
||||||
|
// fp8x2 -> __nv_bfloat162
|
||||||
|
template <>
|
||||||
|
__inline__ __device__ __nv_bfloat162
|
||||||
|
scaled_vec_conversion<__nv_bfloat162, uint16_t>(
|
||||||
|
const uint16_t &a, const float scale,
|
||||||
|
const __nv_fp8_interpretation_t fp8_type) {
|
||||||
|
__nv_bfloat162 res;
|
||||||
|
res.x = scaled_vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)a, scale,
|
||||||
|
fp8_type);
|
||||||
|
res.y = scaled_vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)(a >> 8U),
|
||||||
|
scale, fp8_type);
|
||||||
|
return res;
|
||||||
|
}
|
||||||
|
|
||||||
|
// fp8x4 -> bf16_4_t
|
||||||
|
template <>
|
||||||
|
__inline__ __device__ bf16_4_t scaled_vec_conversion<bf16_4_t, uint32_t>(
|
||||||
|
const uint32_t &a, const float scale,
|
||||||
|
const __nv_fp8_interpretation_t fp8_type) {
|
||||||
|
bf16_4_t res;
|
||||||
|
res.x = scaled_vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)a, scale,
|
||||||
|
fp8_type);
|
||||||
|
res.y = scaled_vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)(a >> 16U),
|
||||||
|
scale, fp8_type);
|
||||||
|
return res;
|
||||||
|
}
|
||||||
|
|
||||||
|
// fp8x8 -> bf16_8_t
|
||||||
|
template <>
|
||||||
|
__inline__ __device__ bf16_8_t scaled_vec_conversion<bf16_8_t, uint2>(
|
||||||
|
const uint2 &a, const float scale,
|
||||||
|
const __nv_fp8_interpretation_t fp8_type) {
|
||||||
|
bf16_4_t tmp1, tmp2;
|
||||||
|
tmp1 = scaled_vec_conversion<bf16_4_t, uint32_t>(a.x, scale, fp8_type);
|
||||||
|
tmp2 = scaled_vec_conversion<bf16_4_t, uint32_t>(a.y, scale, fp8_type);
|
||||||
|
bf16_8_t res;
|
||||||
|
res.x = tmp1.x;
|
||||||
|
res.y = tmp1.y;
|
||||||
|
res.z = tmp2.x;
|
||||||
|
res.w = tmp2.y;
|
||||||
|
return res;
|
||||||
|
}
|
||||||
|
|
||||||
|
// fp8 -> float
|
||||||
|
template <>
|
||||||
|
__inline__ __device__ float scaled_vec_conversion<float, uint8_t>(
|
||||||
|
const uint8_t &a, const float scale,
|
||||||
|
const __nv_fp8_interpretation_t fp8_type) {
|
||||||
|
|
||||||
|
// fp8 -> half
|
||||||
|
__half_raw res = __nv_cvt_fp8_to_halfraw(a, fp8_type);
|
||||||
|
uint16_t tmp = res.x;
|
||||||
|
|
||||||
|
// half -> float
|
||||||
|
return half_to_float(tmp) * scale;
|
||||||
|
}
|
||||||
|
|
||||||
|
// fp8x2 -> float2
|
||||||
|
template <>
|
||||||
|
__inline__ __device__ float2 scaled_vec_conversion<float2, uint16_t>(
|
||||||
|
const uint16_t &a, const float scale,
|
||||||
|
const __nv_fp8_interpretation_t fp8_type) {
|
||||||
|
// fp8x2 -> half2
|
||||||
|
uint32_t tmp = scaled_vec_conversion<uint32_t, uint16_t>(a, scale, fp8_type);
|
||||||
|
// half2 -> float2
|
||||||
|
return half2_to_float2(tmp);
|
||||||
|
}
|
||||||
|
|
||||||
|
// fp8x4 -> float4
|
||||||
|
template <>
|
||||||
|
__inline__ __device__ Float4_ scaled_vec_conversion<Float4_, uint32_t>(
|
||||||
|
const uint32_t &a, const float scale,
|
||||||
|
const __nv_fp8_interpretation_t fp8_type) {
|
||||||
|
Float4_ res;
|
||||||
|
res.x = scaled_vec_conversion<float2, uint16_t>((uint16_t)a, scale, fp8_type);
|
||||||
|
res.y = scaled_vec_conversion<float2, uint16_t>((uint16_t)(a >> 16U), scale,
|
||||||
|
fp8_type);
|
||||||
|
return res;
|
||||||
|
}
|
||||||
|
|
||||||
|
// fp8x8 -> float8
|
||||||
|
template <>
|
||||||
|
__inline__ __device__ Float8_ scaled_vec_conversion<Float8_, uint2>(
|
||||||
|
const uint2 &a, const float scale,
|
||||||
|
const __nv_fp8_interpretation_t fp8_type) {
|
||||||
|
Float4_ tmp1, tmp2;
|
||||||
|
tmp1 = scaled_vec_conversion<Float4_, uint32_t>(a.x, scale, fp8_type);
|
||||||
|
tmp2 = scaled_vec_conversion<Float4_, uint32_t>(a.y, scale, fp8_type);
|
||||||
|
Float8_ res;
|
||||||
|
res.x = tmp1.x;
|
||||||
|
res.y = tmp1.y;
|
||||||
|
res.z = tmp2.x;
|
||||||
|
res.w = tmp2.y;
|
||||||
|
return res;
|
||||||
|
}
|
||||||
|
|
||||||
|
// half -> fp8
|
||||||
|
template <>
|
||||||
|
__inline__ __device__ uint8_t scaled_vec_conversion<uint8_t, uint16_t>(
|
||||||
|
const uint16_t &a, const float scale,
|
||||||
|
const __nv_fp8_interpretation_t fp8_type) {
|
||||||
|
__nv_fp8_storage_t res =
|
||||||
|
__nv_cvt_float_to_fp8(half_to_float(a) / scale, __NV_SATFINITE, fp8_type);
|
||||||
|
return (uint8_t)res;
|
||||||
|
}
|
||||||
|
|
||||||
|
// bf16 -> fp8
|
||||||
|
template <>
|
||||||
|
__inline__ __device__ uint8_t scaled_vec_conversion<uint8_t, __nv_bfloat16>(
|
||||||
|
const __nv_bfloat16 &a, const float scale,
|
||||||
|
const __nv_fp8_interpretation_t fp8_type) {
|
||||||
|
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
|
||||||
|
assert(false);
|
||||||
|
#else
|
||||||
|
__nv_fp8_storage_t res = __nv_cvt_float_to_fp8(__bfloat162float(a) / scale,
|
||||||
|
__NV_SATFINITE, fp8_type);
|
||||||
|
return (uint8_t)res;
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
|
||||||
|
// float -> fp8
|
||||||
|
template <>
|
||||||
|
__inline__ __device__ uint8_t scaled_vec_conversion<uint8_t, float>(
|
||||||
|
const float &a, const float scale,
|
||||||
|
const __nv_fp8_interpretation_t fp8_type) {
|
||||||
|
__nv_fp8_storage_t res =
|
||||||
|
__nv_cvt_float_to_fp8(a / scale, __NV_SATFINITE, fp8_type);
|
||||||
|
return (uint8_t)res;
|
||||||
|
}
|
||||||
|
|
||||||
|
// fp8x4 -> float4
|
||||||
|
template <>
|
||||||
|
__inline__ __device__ float4 scaled_vec_conversion<float4, uint32_t>(
|
||||||
|
const uint32_t &a, const float scale,
|
||||||
|
const __nv_fp8_interpretation_t fp8_type) {
|
||||||
|
Float4_ tmp = scaled_vec_conversion<Float4_, uint32_t>(a, scale, fp8_type);
|
||||||
|
float4 res = make_float4(tmp.x.x, tmp.x.y, tmp.y.x, tmp.y.y);
|
||||||
|
return res;
|
||||||
|
}
|
||||||
|
#endif // ENABLE_FP8
|
||||||
|
|
||||||
|
template <typename Tout, typename Tin, Fp8KVCacheDataType kv_dt>
|
||||||
|
__inline__ __device__ Tout convert(const Tin &x) {
|
||||||
|
#if 0 // Disable the following code to reduce the binary size.
|
||||||
|
if constexpr (kv_dt == Fp8KVCacheDataType::kFp8E4M3) {
|
||||||
|
return vec_conversion<Tout, Tin>(x, __NV_E4M3);
|
||||||
|
} else if constexpr (kv_dt == Fp8KVCacheDataType::kFp8E5M2) {
|
||||||
|
return vec_conversion<Tout, Tin>(x, __NV_E5M2);
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
assert(false);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename Tout, typename Tin, Fp8KVCacheDataType kv_dt>
|
||||||
|
__inline__ __device__ Tout scaled_convert(const Tin &x, const float scale) {
|
||||||
|
#ifdef ENABLE_FP8
|
||||||
|
if constexpr (kv_dt == Fp8KVCacheDataType::kFp8E4M3) {
|
||||||
|
return scaled_vec_conversion<Tout, Tin>(x, scale, __NV_E4M3);
|
||||||
|
} else if constexpr (kv_dt == Fp8KVCacheDataType::kFp8E5M2) {
|
||||||
|
return scaled_vec_conversion<Tout, Tin>(x, scale, __NV_E5M2);
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
assert(false);
|
||||||
|
}
|
||||||
|
|
||||||
|
// The following macro is used to dispatch the conversion function based on the
|
||||||
|
// data type of the key and value cache. The FN is a macro that calls a function
|
||||||
|
// with template<typename scalar_t, typename cache_t, Fp8KVCacheDataType kv_dt>.
|
||||||
|
#define DISPATCH_BY_KV_CACHE_DTYPE(SRC_DTYPE, KV_DTYPE, FN) \
|
||||||
|
if (KV_DTYPE == "auto") { \
|
||||||
|
if (SRC_DTYPE == at::ScalarType::Float) { \
|
||||||
|
FN(float, float, vllm::Fp8KVCacheDataType::kAuto); \
|
||||||
|
} else if (SRC_DTYPE == at::ScalarType::Half) { \
|
||||||
|
FN(uint16_t, uint16_t, vllm::Fp8KVCacheDataType::kAuto); \
|
||||||
|
} else if (SRC_DTYPE == at::ScalarType::BFloat16) { \
|
||||||
|
FN(__nv_bfloat16, __nv_bfloat16, vllm::Fp8KVCacheDataType::kAuto); \
|
||||||
|
} else { \
|
||||||
|
TORCH_CHECK(false, "Unsupported input type of kv cache: ", SRC_DTYPE); \
|
||||||
|
} \
|
||||||
|
} else { \
|
||||||
|
if (KV_DTYPE == "fp8" || KV_DTYPE == "fp8_e4m3") { \
|
||||||
|
if (SRC_DTYPE == at::ScalarType::Float) { \
|
||||||
|
FN(float, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \
|
||||||
|
} else if (SRC_DTYPE == at::ScalarType::Half) { \
|
||||||
|
FN(uint16_t, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \
|
||||||
|
} else if (SRC_DTYPE == at::ScalarType::BFloat16) { \
|
||||||
|
FN(__nv_bfloat16, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \
|
||||||
|
} else { \
|
||||||
|
TORCH_CHECK(false, "Unsupported input type of kv cache: ", SRC_DTYPE); \
|
||||||
|
} \
|
||||||
|
} else if (KV_DTYPE == "fp8_e5m2") { \
|
||||||
|
if (SRC_DTYPE == at::ScalarType::Float) { \
|
||||||
|
FN(float, uint8_t, vllm::Fp8KVCacheDataType::kFp8E5M2); \
|
||||||
|
} else if (SRC_DTYPE == at::ScalarType::Half) { \
|
||||||
|
FN(uint16_t, uint8_t, vllm::Fp8KVCacheDataType::kFp8E5M2); \
|
||||||
|
} else if (SRC_DTYPE == at::ScalarType::BFloat16) { \
|
||||||
|
FN(__nv_bfloat16, uint8_t, vllm::Fp8KVCacheDataType::kFp8E5M2); \
|
||||||
|
} else { \
|
||||||
|
TORCH_CHECK(false, "Unsupported input type of kv cache: ", SRC_DTYPE); \
|
||||||
|
} \
|
||||||
|
} else { \
|
||||||
|
TORCH_CHECK(false, "Unsupported data type of kv cache: ", KV_DTYPE); \
|
||||||
|
} \
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace fp8
|
||||||
|
#endif // not USE_ROCM
|
||||||
|
} // namespace vllm
|
@ -1,277 +0,0 @@
|
|||||||
#pragma once
|
|
||||||
|
|
||||||
#include <assert.h>
|
|
||||||
#include <stdint.h>
|
|
||||||
#include <float.h>
|
|
||||||
#include <type_traits>
|
|
||||||
#include "../../attention/attention_dtypes.h"
|
|
||||||
#include "../../attention/dtype_float32.cuh"
|
|
||||||
#include "../../attention/dtype_float16.cuh"
|
|
||||||
#include "../../attention/dtype_bfloat16.cuh"
|
|
||||||
|
|
||||||
|
|
||||||
namespace vllm {
|
|
||||||
#ifdef ENABLE_FP8_E5M2
|
|
||||||
namespace fp8_e5m2_unscaled {
|
|
||||||
|
|
||||||
template<typename Tout, typename Tin>
|
|
||||||
__inline__ __device__ Tout vec_conversion(const Tin& x)
|
|
||||||
{
|
|
||||||
return x;
|
|
||||||
}
|
|
||||||
|
|
||||||
// fp8 -> half
|
|
||||||
template<>
|
|
||||||
__inline__ __device__ uint16_t vec_conversion<uint16_t, uint8_t>(const uint8_t& a)
|
|
||||||
{
|
|
||||||
__half_raw res = __nv_cvt_fp8_to_halfraw(a, __NV_E5M2);
|
|
||||||
return res.x;
|
|
||||||
}
|
|
||||||
|
|
||||||
// fp8x2 -> half2
|
|
||||||
template<>
|
|
||||||
__inline__ __device__ uint32_t vec_conversion<uint32_t, uint16_t>(const uint16_t& a)
|
|
||||||
{
|
|
||||||
union {
|
|
||||||
uint16_t u16[2];
|
|
||||||
uint32_t u32;
|
|
||||||
} tmp;
|
|
||||||
__half2_raw res = __nv_cvt_fp8x2_to_halfraw2(a, __NV_E5M2);
|
|
||||||
tmp.u16[0] = res.x;
|
|
||||||
tmp.u16[1] = res.y;
|
|
||||||
return tmp.u32;
|
|
||||||
}
|
|
||||||
|
|
||||||
// fp8x4 -> half2x2
|
|
||||||
template<>
|
|
||||||
__inline__ __device__ uint2 vec_conversion<uint2, uint32_t>(const uint32_t& a)
|
|
||||||
{
|
|
||||||
union {
|
|
||||||
uint2 u32x2;
|
|
||||||
uint32_t u32[2];
|
|
||||||
} tmp;
|
|
||||||
tmp.u32[0] = vec_conversion<uint32_t, uint16_t>((uint16_t)a);
|
|
||||||
tmp.u32[1] = vec_conversion<uint32_t, uint16_t>((uint16_t)(a >> 16U));
|
|
||||||
return tmp.u32x2;
|
|
||||||
}
|
|
||||||
|
|
||||||
// fp8x8 -> half2x4
|
|
||||||
template<>
|
|
||||||
__inline__ __device__ uint4 vec_conversion<uint4, uint2>(const uint2& a)
|
|
||||||
{
|
|
||||||
union {
|
|
||||||
uint4 u64x2;
|
|
||||||
uint2 u64[2];
|
|
||||||
} tmp;
|
|
||||||
tmp.u64[0] = vec_conversion<uint2, uint32_t>(a.x);
|
|
||||||
tmp.u64[1] = vec_conversion<uint2, uint32_t>(a.y);
|
|
||||||
return tmp.u64x2;
|
|
||||||
}
|
|
||||||
|
|
||||||
// fp8 -> __nv_bfloat16
|
|
||||||
template<>
|
|
||||||
__inline__ __device__ __nv_bfloat16 vec_conversion<__nv_bfloat16, uint8_t>(const uint8_t& a)
|
|
||||||
{
|
|
||||||
// Note there is no direct convert function from fp8 to bf16.
|
|
||||||
// fp8 -> half
|
|
||||||
__half_raw res = __nv_cvt_fp8_to_halfraw(a, __NV_E5M2);
|
|
||||||
// half -> float -> bf16
|
|
||||||
float tmp = half_to_float(res.x);
|
|
||||||
return __float2bfloat16(tmp);
|
|
||||||
}
|
|
||||||
|
|
||||||
// fp8x2 -> __nv_bfloat162
|
|
||||||
template<>
|
|
||||||
__inline__ __device__ __nv_bfloat162 vec_conversion<__nv_bfloat162, uint16_t>(const uint16_t& a)
|
|
||||||
{
|
|
||||||
__nv_bfloat162 res;
|
|
||||||
res.x = vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)a);
|
|
||||||
res.y = vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)(a >> 8U));
|
|
||||||
return res;
|
|
||||||
}
|
|
||||||
|
|
||||||
// fp8x4 -> bf16_4_t
|
|
||||||
template<>
|
|
||||||
__inline__ __device__ bf16_4_t vec_conversion<bf16_4_t, uint32_t>(const uint32_t& a)
|
|
||||||
{
|
|
||||||
bf16_4_t res;
|
|
||||||
res.x = vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)a);
|
|
||||||
res.y = vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)(a >> 16U));
|
|
||||||
return res;
|
|
||||||
}
|
|
||||||
|
|
||||||
// fp8x8 -> bf16_8_t
|
|
||||||
template<>
|
|
||||||
__inline__ __device__ bf16_8_t vec_conversion<bf16_8_t, uint2>(const uint2& a)
|
|
||||||
{
|
|
||||||
bf16_4_t tmp1, tmp2;
|
|
||||||
tmp1 = vec_conversion<bf16_4_t, uint32_t>(a.x);
|
|
||||||
tmp2 = vec_conversion<bf16_4_t, uint32_t>(a.y);
|
|
||||||
bf16_8_t res;
|
|
||||||
res.x = tmp1.x;
|
|
||||||
res.y = tmp1.y;
|
|
||||||
res.z = tmp2.x;
|
|
||||||
res.w = tmp2.y;
|
|
||||||
return res;
|
|
||||||
}
|
|
||||||
|
|
||||||
// fp8 -> float
|
|
||||||
template<>
|
|
||||||
__inline__ __device__ float vec_conversion<float, uint8_t>(const uint8_t& a)
|
|
||||||
{
|
|
||||||
// fp8 -> half
|
|
||||||
uint16_t tmp = vec_conversion<uint16_t, uint8_t>(a);
|
|
||||||
// half -> float
|
|
||||||
return half_to_float(tmp);
|
|
||||||
}
|
|
||||||
|
|
||||||
// fp8x2 -> float2
|
|
||||||
template<>
|
|
||||||
__inline__ __device__ float2 vec_conversion<float2, uint16_t>(const uint16_t& a)
|
|
||||||
{
|
|
||||||
// fp8x2 -> half2
|
|
||||||
uint32_t tmp = vec_conversion<uint32_t, uint16_t>(a);
|
|
||||||
// half2 -> float2
|
|
||||||
return half2_to_float2(tmp);
|
|
||||||
}
|
|
||||||
|
|
||||||
// fp8x4 -> float4
|
|
||||||
template<>
|
|
||||||
__inline__ __device__ Float4_ vec_conversion<Float4_, uint32_t>(const uint32_t& a)
|
|
||||||
{
|
|
||||||
Float4_ res;
|
|
||||||
res.x = vec_conversion<float2, uint16_t>((uint16_t)a);
|
|
||||||
res.y = vec_conversion<float2, uint16_t>((uint16_t)(a >> 16U));
|
|
||||||
return res;
|
|
||||||
}
|
|
||||||
|
|
||||||
// fp8x8 -> float8
|
|
||||||
template<>
|
|
||||||
__inline__ __device__ Float8_ vec_conversion<Float8_, uint2>(const uint2& a)
|
|
||||||
{
|
|
||||||
Float4_ tmp1, tmp2;
|
|
||||||
tmp1 = vec_conversion<Float4_, uint32_t>(a.x);
|
|
||||||
tmp2 = vec_conversion<Float4_, uint32_t>(a.y);
|
|
||||||
Float8_ res;
|
|
||||||
res.x = tmp1.x;
|
|
||||||
res.y = tmp1.y;
|
|
||||||
res.z = tmp2.x;
|
|
||||||
res.w = tmp2.y;
|
|
||||||
return res;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
// half -> fp8
|
|
||||||
template<>
|
|
||||||
__inline__ __device__ uint8_t vec_conversion<uint8_t, uint16_t>(const uint16_t& a)
|
|
||||||
{
|
|
||||||
__half_raw tmp;
|
|
||||||
tmp.x = a;
|
|
||||||
__nv_fp8_storage_t res = __nv_cvt_halfraw_to_fp8(tmp, __NV_SATFINITE, __NV_E5M2);
|
|
||||||
return (uint8_t)res;
|
|
||||||
}
|
|
||||||
|
|
||||||
// bf16 -> fp8
|
|
||||||
template<>
|
|
||||||
__inline__ __device__ uint8_t vec_conversion<uint8_t, __nv_bfloat16>(const __nv_bfloat16& a)
|
|
||||||
{
|
|
||||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
|
|
||||||
assert(false);
|
|
||||||
#else
|
|
||||||
__nv_fp8_storage_t res = __nv_cvt_bfloat16raw_to_fp8(__nv_bfloat16_raw(a), __NV_SATFINITE, __NV_E5M2);
|
|
||||||
return (uint8_t)res;
|
|
||||||
#endif
|
|
||||||
}
|
|
||||||
|
|
||||||
// float -> fp8
|
|
||||||
template<>
|
|
||||||
__inline__ __device__ uint8_t vec_conversion<uint8_t, float>(const float& a)
|
|
||||||
{
|
|
||||||
__nv_fp8_storage_t res = __nv_cvt_float_to_fp8(a, __NV_SATFINITE, __NV_E5M2);
|
|
||||||
return (uint8_t)res;
|
|
||||||
}
|
|
||||||
|
|
||||||
// fp8x4 -> float4
|
|
||||||
template<>
|
|
||||||
__inline__ __device__ float4 vec_conversion<float4, uint32_t>(const uint32_t& a)
|
|
||||||
{
|
|
||||||
Float4_ tmp = vec_conversion<Float4_, uint32_t>(a);
|
|
||||||
float4 res = make_float4(tmp.x.x, tmp.x.y, tmp.y.x, tmp.y.y);
|
|
||||||
return res;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
template<>
|
|
||||||
__inline__ __device__ uint32_t vec_conversion<uint32_t, float2>(const float2& a)
|
|
||||||
{
|
|
||||||
union {
|
|
||||||
half2 float16;
|
|
||||||
uint32_t uint32;
|
|
||||||
};
|
|
||||||
|
|
||||||
float16 = __float22half2_rn(a);
|
|
||||||
return uint32;
|
|
||||||
}
|
|
||||||
|
|
||||||
template<>
|
|
||||||
__inline__ __device__ uint2 vec_conversion<uint2, Float4_>(const Float4_& a)
|
|
||||||
{
|
|
||||||
uint2 b;
|
|
||||||
float2 val;
|
|
||||||
val.x = a.x.x;
|
|
||||||
val.y = a.x.y;
|
|
||||||
b.x = vec_conversion<uint32_t, float2>(val);
|
|
||||||
|
|
||||||
val.x = a.y.x;
|
|
||||||
val.y = a.y.y;
|
|
||||||
b.y = vec_conversion<uint32_t, float2>(val);
|
|
||||||
|
|
||||||
return b;
|
|
||||||
}
|
|
||||||
|
|
||||||
template<>
|
|
||||||
__inline__ __device__ float4 vec_conversion<float4, Float4_>(const Float4_& a)
|
|
||||||
{
|
|
||||||
float4 b;
|
|
||||||
b.x = a.x.x;
|
|
||||||
b.y = a.x.y;
|
|
||||||
b.z = a.y.x;
|
|
||||||
b.w = a.y.y;
|
|
||||||
return b;
|
|
||||||
}
|
|
||||||
|
|
||||||
template<>
|
|
||||||
__inline__ __device__ uint4 vec_conversion<uint4, Float8_>(const Float8_& a)
|
|
||||||
{
|
|
||||||
uint4 b;
|
|
||||||
b.x = vec_conversion<uint32_t, float2>(a.x);
|
|
||||||
b.y = vec_conversion<uint32_t, float2>(a.y);
|
|
||||||
b.z = vec_conversion<uint32_t, float2>(a.z);
|
|
||||||
b.w = vec_conversion<uint32_t, float2>(a.w);
|
|
||||||
return b;
|
|
||||||
}
|
|
||||||
|
|
||||||
template<>
|
|
||||||
__inline__ __device__ __nv_bfloat162 vec_conversion<__nv_bfloat162, float2>(const float2 &a) {
|
|
||||||
__nv_bfloat162 b;
|
|
||||||
from_float(b, a);
|
|
||||||
return b;
|
|
||||||
}
|
|
||||||
|
|
||||||
template<>
|
|
||||||
__inline__ __device__ bf16_4_t vec_conversion<bf16_4_t, Float4_>(const Float4_ &a) {
|
|
||||||
bf16_4_t b;
|
|
||||||
from_float(b, a);
|
|
||||||
return b;
|
|
||||||
}
|
|
||||||
|
|
||||||
template<>
|
|
||||||
__inline__ __device__ bf16_8_t vec_conversion<bf16_8_t, Float8_>(const Float8_ &a) {
|
|
||||||
bf16_8_t b;
|
|
||||||
from_float(b, a);
|
|
||||||
return b;
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace fp8_e5m2_unscaled
|
|
||||||
#endif // ENABLE_FP8_E5M2
|
|
||||||
} // namespace vllm
|
|
@ -236,14 +236,14 @@ def test_paged_attention(
|
|||||||
dequantized_key_cache = torch.empty(size=key_cache_shape,
|
dequantized_key_cache = torch.empty(size=key_cache_shape,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
device=device)
|
device=device)
|
||||||
ops.convert_fp8(key_cache, dequantized_key_cache)
|
ops.convert_fp8(dequantized_key_cache, key_cache)
|
||||||
key_cache = dequantized_key_cache
|
key_cache = dequantized_key_cache
|
||||||
|
|
||||||
value_cache_shape = value_cache.shape
|
value_cache_shape = value_cache.shape
|
||||||
dequantized_value_cache = torch.empty(size=value_cache_shape,
|
dequantized_value_cache = torch.empty(size=value_cache_shape,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
device=device)
|
device=device)
|
||||||
ops.convert_fp8(value_cache, dequantized_value_cache)
|
ops.convert_fp8(dequantized_value_cache, value_cache)
|
||||||
value_cache = dequantized_value_cache
|
value_cache = dequantized_value_cache
|
||||||
|
|
||||||
ref_output = torch.empty_like(query)
|
ref_output = torch.empty_like(query)
|
||||||
|
@ -5,8 +5,6 @@ import pytest
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
from vllm import _custom_ops as ops
|
from vllm import _custom_ops as ops
|
||||||
from vllm._C import cache_ops
|
|
||||||
from vllm.utils import is_hip
|
|
||||||
|
|
||||||
COPYING_DIRECTION = [('cuda', 'cpu'), ('cuda', 'cuda'), ('cpu', 'cuda')]
|
COPYING_DIRECTION = [('cuda', 'cpu'), ('cuda', 'cuda'), ('cpu', 'cuda')]
|
||||||
DTYPES = [torch.half, torch.bfloat16, torch.float]
|
DTYPES = [torch.half, torch.bfloat16, torch.float]
|
||||||
@ -25,6 +23,8 @@ SEEDS = [0]
|
|||||||
CUDA_DEVICES = [
|
CUDA_DEVICES = [
|
||||||
f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
|
f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
|
||||||
]
|
]
|
||||||
|
|
||||||
|
# We assume fp8 is always enabled for testing.
|
||||||
KV_CACHE_DTYPE = ["auto", "fp8"]
|
KV_CACHE_DTYPE = ["auto", "fp8"]
|
||||||
|
|
||||||
|
|
||||||
@ -124,8 +124,6 @@ def test_reshape_and_cache(
|
|||||||
device: str,
|
device: str,
|
||||||
kv_cache_dtype: str,
|
kv_cache_dtype: str,
|
||||||
) -> None:
|
) -> None:
|
||||||
if not is_hip() and kv_cache_dtype == "fp8":
|
|
||||||
pytest.skip() # This test is not tuned for e5m2 cuda precision
|
|
||||||
random.seed(seed)
|
random.seed(seed)
|
||||||
torch.random.manual_seed(seed)
|
torch.random.manual_seed(seed)
|
||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
@ -149,9 +147,9 @@ def test_reshape_and_cache(
|
|||||||
# Clone the KV caches.
|
# Clone the KV caches.
|
||||||
if kv_cache_dtype == "fp8":
|
if kv_cache_dtype == "fp8":
|
||||||
cloned_key_cache = torch.empty_like(key_cache, dtype=torch.float16)
|
cloned_key_cache = torch.empty_like(key_cache, dtype=torch.float16)
|
||||||
ops.convert_fp8(key_cache, cloned_key_cache)
|
ops.convert_fp8(cloned_key_cache, key_cache)
|
||||||
cloned_value_cache = torch.empty_like(value_cache, dtype=torch.float16)
|
cloned_value_cache = torch.empty_like(value_cache, dtype=torch.float16)
|
||||||
ops.convert_fp8(value_cache, cloned_value_cache)
|
ops.convert_fp8(cloned_value_cache, value_cache)
|
||||||
else:
|
else:
|
||||||
cloned_key_cache = key_cache.clone()
|
cloned_key_cache = key_cache.clone()
|
||||||
cloned_value_cache = value_cache.clone()
|
cloned_value_cache = value_cache.clone()
|
||||||
@ -165,9 +163,9 @@ def test_reshape_and_cache(
|
|||||||
|
|
||||||
if kv_cache_dtype == "fp8":
|
if kv_cache_dtype == "fp8":
|
||||||
result_key_cache = torch.empty_like(key_cache, dtype=torch.float16)
|
result_key_cache = torch.empty_like(key_cache, dtype=torch.float16)
|
||||||
ops.convert_fp8(key_cache, result_key_cache)
|
ops.convert_fp8(result_key_cache, key_cache)
|
||||||
result_value_cache = torch.empty_like(value_cache, dtype=torch.float16)
|
result_value_cache = torch.empty_like(value_cache, dtype=torch.float16)
|
||||||
ops.convert_fp8(value_cache, result_value_cache)
|
ops.convert_fp8(result_value_cache, value_cache)
|
||||||
|
|
||||||
# Run the reference implementation.
|
# Run the reference implementation.
|
||||||
reshaped_key = key.reshape(num_tokens, *key_cache[0, :, :, 0, :].shape)
|
reshaped_key = key.reshape(num_tokens, *key_cache[0, :, :, 0, :].shape)
|
||||||
@ -255,8 +253,8 @@ def test_reshape_and_cache_flash(
|
|||||||
cloned_value_cache = value_cache.clone()
|
cloned_value_cache = value_cache.clone()
|
||||||
|
|
||||||
# Call the reshape_and_cache kernel.
|
# Call the reshape_and_cache kernel.
|
||||||
cache_ops.reshape_and_cache_flash(key, value, key_cache, value_cache,
|
ops.reshape_and_cache_flash(key, value, key_cache, value_cache,
|
||||||
slot_mapping, kv_cache_dtype)
|
slot_mapping, kv_cache_dtype)
|
||||||
|
|
||||||
# Run the reference implementation.
|
# Run the reference implementation.
|
||||||
block_indicies = torch.div(slot_mapping, block_size, rounding_mode='floor')
|
block_indicies = torch.div(slot_mapping, block_size, rounding_mode='floor')
|
||||||
@ -299,8 +297,6 @@ def test_swap_blocks(
|
|||||||
) -> None:
|
) -> None:
|
||||||
if kv_cache_dtype == "fp8" and "cpu" in direction:
|
if kv_cache_dtype == "fp8" and "cpu" in direction:
|
||||||
pytest.skip()
|
pytest.skip()
|
||||||
if not is_hip() and kv_cache_dtype == "fp8":
|
|
||||||
pytest.skip() # This test is not tuned for e5m2 cuda precision
|
|
||||||
random.seed(seed)
|
random.seed(seed)
|
||||||
torch.random.manual_seed(seed)
|
torch.random.manual_seed(seed)
|
||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
@ -348,7 +344,6 @@ def test_swap_blocks(
|
|||||||
dist_value_caches[0][dst].cpu())
|
dist_value_caches[0][dst].cpu())
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(not is_hip(), reason="FP8 conversion test requires e4m3")
|
|
||||||
@pytest.mark.parametrize("num_heads", NUM_HEADS)
|
@pytest.mark.parametrize("num_heads", NUM_HEADS)
|
||||||
@pytest.mark.parametrize("head_size", HEAD_SIZES)
|
@pytest.mark.parametrize("head_size", HEAD_SIZES)
|
||||||
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
|
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
|
||||||
@ -357,7 +352,7 @@ def test_swap_blocks(
|
|||||||
@pytest.mark.parametrize("seed", SEEDS)
|
@pytest.mark.parametrize("seed", SEEDS)
|
||||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||||
@torch.inference_mode()
|
@torch.inference_mode()
|
||||||
def test_fp8_conversion(
|
def test_fp8_e4m3_conversion(
|
||||||
num_heads: int,
|
num_heads: int,
|
||||||
head_size: int,
|
head_size: int,
|
||||||
block_size: int,
|
block_size: int,
|
||||||
@ -377,9 +372,9 @@ def test_fp8_conversion(
|
|||||||
cache.uniform_(low, high)
|
cache.uniform_(low, high)
|
||||||
|
|
||||||
cache_fp8 = torch.empty_like(cache, dtype=torch.uint8)
|
cache_fp8 = torch.empty_like(cache, dtype=torch.uint8)
|
||||||
ops.convert_fp8(cache, cache_fp8)
|
ops.convert_fp8(cache_fp8, cache)
|
||||||
|
|
||||||
converted_cache = torch.empty_like(cache)
|
converted_cache = torch.empty_like(cache)
|
||||||
ops.convert_fp8(cache_fp8, converted_cache)
|
ops.convert_fp8(converted_cache, cache_fp8)
|
||||||
|
|
||||||
assert torch.allclose(cache, converted_cache, atol=0.001, rtol=0.1)
|
assert torch.allclose(cache, converted_cache, atol=0.001, rtol=0.1)
|
||||||
|
@ -270,8 +270,11 @@ def swap_blocks(src: torch.Tensor, dst: torch.Tensor,
|
|||||||
vllm_cache_ops.swap_blocks(src, dst, block_mapping)
|
vllm_cache_ops.swap_blocks(src, dst, block_mapping)
|
||||||
|
|
||||||
|
|
||||||
def convert_fp8(output: torch.Tensor, input: torch.Tensor) -> None:
|
def convert_fp8(output: torch.Tensor,
|
||||||
vllm_cache_ops.convert_fp8(output, input)
|
input: torch.Tensor,
|
||||||
|
scale: float = 1.0,
|
||||||
|
kv_dtype: str = "fp8") -> None:
|
||||||
|
vllm_cache_ops.convert_fp8(output, input, scale, kv_dtype)
|
||||||
|
|
||||||
|
|
||||||
#TODO: cuda_utils, custom_ar
|
#TODO: cuda_utils, custom_ar
|
||||||
|
@ -329,7 +329,7 @@ def _generate_random_fp8(
|
|||||||
from vllm import _custom_ops as ops
|
from vllm import _custom_ops as ops
|
||||||
tensor_tmp = torch.empty_like(tensor, dtype=torch.float16)
|
tensor_tmp = torch.empty_like(tensor, dtype=torch.float16)
|
||||||
tensor_tmp.uniform_(low, high)
|
tensor_tmp.uniform_(low, high)
|
||||||
ops.convert_fp8(tensor_tmp, tensor)
|
ops.convert_fp8(tensor, tensor_tmp)
|
||||||
del tensor_tmp
|
del tensor_tmp
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user