[Kernel] Refactor FP8 kv-cache with NVIDIA float8_e4m3 support (#4535)

Cody Yu 2024-05-09 17:04:17 -07:00 committed by GitHub
parent 379da6dcb5
commit c833101740
17 changed files with 843 additions and 558 deletions
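
At its core, the change replaces the boolean IS_FP8_KV_CACHE template parameter and the ENABLE_FP8_E5M2 / ENABLE_FP8_E4M3 preprocessor branches with a single vllm::Fp8KVCacheDataType enum dispatched at compile time, so one kernel template covers the unquantized, E4M3, and E5M2 cache layouts on both NVIDIA and AMD. A minimal, self-contained sketch of that pattern, with hypothetical names and a placeholder standing in for the real fp8 decode:

// Sketch only: mirrors the dispatch style used in paged_attention_kernel below.
#include <cstdint>
#include <cstdio>

enum class Fp8KVCacheDataType { kAuto = 0, kFp8E4M3 = 1, kFp8E5M2 = 2 };

// Toy stand-in for vllm::fp8::scaled_convert; the real helper decodes fp8
// bytes (E4M3 or E5M2) and multiplies by kv_scale.
template <Fp8KVCacheDataType kv_dt>
float scaled_convert(uint8_t q, float scale) {
  return static_cast<float>(q) * scale;  // placeholder math, not real fp8 decoding
}

// One branch for an unquantized ("auto") cache, one for an fp8 cache,
// resolved at compile time instead of via #ifdef.
template <Fp8KVCacheDataType KV_DTYPE>
float load_kv(const void* cache, int idx, float kv_scale) {
  if constexpr (KV_DTYPE == Fp8KVCacheDataType::kAuto) {
    return static_cast<const float*>(cache)[idx];
  } else {
    uint8_t q = static_cast<const uint8_t*>(cache)[idx];
    return scaled_convert<KV_DTYPE>(q, kv_scale);
  }
}

int main() {
  float plain[2] = {1.0f, 2.0f};
  uint8_t quant[2] = {16, 32};
  std::printf("%f %f\n",
              load_kv<Fp8KVCacheDataType::kAuto>(plain, 1, 1.0f),
              load_kv<Fp8KVCacheDataType::kFp8E4M3>(quant, 1, 0.5f));
}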

View File

@ -1,7 +1,7 @@
import os
import zipfile
MAX_SIZE_MB = 100
MAX_SIZE_MB = 150
def print_top_10_largest_files(zip_file):

View File

@ -167,7 +167,7 @@ set(VLLM_EXT_SRC
"csrc/layernorm_kernels.cu"
"csrc/quantization/squeezellm/quant_cuda_kernel.cu"
"csrc/quantization/gptq/q_gemm.cu"
"csrc/quantization/fp8/fp8_cuda_kernels.cu"
"csrc/quantization/fp8/common.cu"
"csrc/cuda_utils_kernels.cu"
"csrc/moe_align_block_size_kernels.cu"
"csrc/pybind.cpp")

View File

@ -99,7 +99,7 @@ function (get_torch_gpu_compiler_flags OUT_GPU_FLAGS GPU_LANG)
"Failed to determine torch nvcc compiler flags")
if (CUDA_VERSION VERSION_GREATER_EQUAL 11.8)
list(APPEND GPU_FLAGS "-DENABLE_FP8_E5M2")
list(APPEND GPU_FLAGS "-DENABLE_FP8")
endif()
if (CUDA_VERSION VERSION_GREATER_EQUAL 12.0)
list(REMOVE_ITEM GPU_FLAGS
@ -119,7 +119,7 @@ function (get_torch_gpu_compiler_flags OUT_GPU_FLAGS GPU_LANG)
list(APPEND GPU_FLAGS
"-DUSE_ROCM"
"-DENABLE_FP8_E4M3"
"-DENABLE_FP8"
"-U__HIP_NO_HALF_CONVERSIONS__"
"-U__HIP_NO_HALF_OPERATORS__"
"-fno-gpu-rdc")

View File

@ -19,21 +19,17 @@
#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <algorithm>
#include "attention_dtypes.h"
#include "attention_utils.cuh"
#if defined(ENABLE_FP8_E5M2)
#include "../quantization/fp8_e5m2_kvcache/quant_utils.cuh"
#elif defined(ENABLE_FP8_E4M3)
#include "../quantization/fp8/amd_detail/quant_utils.cuh"
#endif
#include <algorithm>
#ifdef USE_ROCM
#include <hip/hip_bf16.h>
#include "../quantization/fp8/amd/quant_utils.cuh"
typedef __hip_bfloat16 __nv_bfloat16;
#else
#include "../quantization/fp8/nvidia/quant_utils.cuh"
#endif
#ifndef USE_ROCM
@ -92,7 +88,7 @@ template<
int HEAD_SIZE,
int BLOCK_SIZE,
int NUM_THREADS,
bool IS_FP8_KV_CACHE,
vllm::Fp8KVCacheDataType KV_DTYPE,
int PARTITION_SIZE = 0> // Zero means no partitioning.
__device__ void paged_attention_kernel(
float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions]
@ -157,9 +153,7 @@ __device__ void paged_attention_kernel(
constexpr int VEC_SIZE = MAX(16 / (THREAD_GROUP_SIZE * sizeof(scalar_t)), 1);
using K_vec = typename Vec<scalar_t, VEC_SIZE>::Type;
using Q_vec = typename Vec<scalar_t, VEC_SIZE>::Type;
#if defined(ENABLE_FP8_E5M2) || defined(ENABLE_FP8_E4M3)
using Quant_vec = typename Vec<cache_t, VEC_SIZE>::Type;
#endif
constexpr int NUM_ELEMS_PER_THREAD = HEAD_SIZE / THREAD_GROUP_SIZE;
constexpr int NUM_VECS_PER_THREAD = NUM_ELEMS_PER_THREAD / VEC_SIZE;
@ -223,21 +217,14 @@ __device__ void paged_attention_kernel(
const int vec_idx = thread_group_offset + j * THREAD_GROUP_SIZE;
const int offset1 = (vec_idx * VEC_SIZE) / x;
const int offset2 = (vec_idx * VEC_SIZE) % x;
if constexpr (IS_FP8_KV_CACHE) {
#if defined(ENABLE_FP8_E5M2)
Quant_vec k_vec_quant = *reinterpret_cast<const Quant_vec*>(k_ptr + offset1 * BLOCK_SIZE * x + offset2);
// Vector conversion from Quant_vec to K_vec.
k_vecs[j] = fp8_e5m2_unscaled::vec_conversion<K_vec, Quant_vec>(k_vec_quant);
#elif defined(ENABLE_FP8_E4M3)
Quant_vec k_vec_quant = *reinterpret_cast<const Quant_vec*>(k_ptr + offset1 * BLOCK_SIZE * x + offset2);
// Vector conversion from Quant_vec to K_vec. Use scaled_vec_conversion to convert FP8_E4M3 quantized k
// cache vec to k vec in higher precision (FP16, BFloat16, etc.)
k_vecs[j] = fp8_e4m3::scaled_vec_conversion<K_vec, Quant_vec>(k_vec_quant, kv_scale);
#else
assert(false);
#endif
} else {
if constexpr (KV_DTYPE == Fp8KVCacheDataType::kAuto) {
k_vecs[j] = *reinterpret_cast<const K_vec*>(k_ptr + offset1 * BLOCK_SIZE * x + offset2);
} else {
// Vector conversion from Quant_vec to K_vec.
Quant_vec k_vec_quant = *reinterpret_cast<const Quant_vec*>(
k_ptr + offset1 * BLOCK_SIZE * x + offset2);
k_vecs[j] = fp8::scaled_convert<K_vec, Quant_vec, KV_DTYPE>(k_vec_quant, kv_scale);
}
}
@ -312,9 +299,7 @@ __device__ void paged_attention_kernel(
constexpr int V_VEC_SIZE = MIN(16 / sizeof(scalar_t), BLOCK_SIZE);
using V_vec = typename Vec<scalar_t, V_VEC_SIZE>::Type;
using L_vec = typename Vec<scalar_t, V_VEC_SIZE>::Type;
#if defined(ENABLE_FP8_E5M2) || defined(ENABLE_FP8_E4M3)
using V_quant_vec = typename Vec<cache_t, V_VEC_SIZE>::Type;
#endif
using Float_L_vec = typename FloatVec<L_vec>::Type;
constexpr int NUM_V_VECS_PER_ROW = BLOCK_SIZE / V_VEC_SIZE;
@ -348,21 +333,13 @@ __device__ void paged_attention_kernel(
if (row_idx < HEAD_SIZE) {
const int offset = row_idx * BLOCK_SIZE + physical_block_offset;
V_vec v_vec;
if constexpr (IS_FP8_KV_CACHE) {
#if defined(ENABLE_FP8_E5M2)
V_quant_vec v_quant_vec = *reinterpret_cast<const V_quant_vec*>(v_ptr + offset);
// Vector conversion from V_quant_vec to V_vec.
v_vec = fp8_e5m2_unscaled::vec_conversion<V_vec, V_quant_vec>(v_quant_vec);
#elif defined(ENABLE_FP8_E4M3)
V_quant_vec v_quant_vec = *reinterpret_cast<const V_quant_vec*>(v_ptr + offset);
// Vector conversion from V_quant_vec to V_vec. Use scaled_vec_conversion to convert
// FP8_E4M3 quantized v cache vec to v vec in higher precision (FP16, BFloat16, etc.)
v_vec = fp8_e4m3::scaled_vec_conversion<V_vec, V_quant_vec>(v_quant_vec, kv_scale);
#else
assert(false);
#endif
} else {
v_vec = *reinterpret_cast<const V_vec*>(v_ptr + offset);
}
if constexpr (KV_DTYPE == Fp8KVCacheDataType::kAuto) {
v_vec = *reinterpret_cast<const V_vec*>(v_ptr + offset);
} else {
V_quant_vec v_quant_vec = *reinterpret_cast<const V_quant_vec*>(v_ptr + offset);
// Vector conversion from V_quant_vec to V_vec.
v_vec = fp8::scaled_convert<V_vec, V_quant_vec, KV_DTYPE>(v_quant_vec, kv_scale);
}
if (block_idx == num_seq_blocks - 1) {
// NOTE(woosuk): When v_vec contains the tokens that are out of the context,
@ -448,7 +425,7 @@ template<
int HEAD_SIZE,
int BLOCK_SIZE,
int NUM_THREADS,
bool IS_FP8_KV_CACHE>
vllm::Fp8KVCacheDataType KV_DTYPE>
__global__ void paged_attention_v1_kernel(
scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size]
const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size]
@ -464,7 +441,7 @@ __global__ void paged_attention_v1_kernel(
const int kv_block_stride,
const int kv_head_stride,
const float kv_scale) {
paged_attention_kernel<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, IS_FP8_KV_CACHE>(
paged_attention_kernel<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, KV_DTYPE>(
/* exp_sums */ nullptr, /* max_logits */ nullptr,
out, q, k_cache, v_cache, num_kv_heads, scale, block_tables, seq_lens,
max_num_blocks_per_seq, alibi_slopes, q_stride, kv_block_stride, kv_head_stride, kv_scale);
@ -477,7 +454,7 @@ template<
int HEAD_SIZE,
int BLOCK_SIZE,
int NUM_THREADS,
bool IS_FP8_KV_CACHE,
vllm::Fp8KVCacheDataType KV_DTYPE,
int PARTITION_SIZE>
__global__ void paged_attention_v2_kernel(
float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions]
@ -496,7 +473,7 @@ __global__ void paged_attention_v2_kernel(
const int kv_block_stride,
const int kv_head_stride,
const float kv_scale) {
paged_attention_kernel<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, IS_FP8_KV_CACHE, PARTITION_SIZE>(
paged_attention_kernel<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, KV_DTYPE, PARTITION_SIZE>(
exp_sums, max_logits, tmp_out, q, k_cache, v_cache, num_kv_heads, scale,
block_tables, seq_lens, max_num_blocks_per_seq, alibi_slopes,
q_stride, kv_block_stride, kv_head_stride, kv_scale);
@ -606,9 +583,9 @@ __global__ void paged_attention_v2_reduce_kernel(
#define LAUNCH_PAGED_ATTENTION_V1(HEAD_SIZE) \
VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize( \
((void*)vllm::paged_attention_v1_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, \
IS_FP8_KV_CACHE>), shared_mem_size); \
KV_DTYPE>), shared_mem_size); \
vllm::paged_attention_v1_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, \
IS_FP8_KV_CACHE><<<grid, block, shared_mem_size, stream>>>( \
KV_DTYPE><<<grid, block, shared_mem_size, stream>>>( \
out_ptr, \
query_ptr, \
key_cache_ptr, \
@ -629,7 +606,7 @@ template<
typename T,
typename CACHE_T,
int BLOCK_SIZE,
bool IS_FP8_KV_CACHE,
vllm::Fp8KVCacheDataType KV_DTYPE,
int NUM_THREADS = 128>
void paged_attention_v1_launcher(
torch::Tensor& out,
@ -706,36 +683,36 @@ void paged_attention_v1_launcher(
}
}
#define CALL_V1_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE) \
paged_attention_v1_launcher<T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE>( \
out, \
query, \
key_cache, \
value_cache, \
num_kv_heads, \
scale, \
block_tables, \
seq_lens, \
max_seq_len, \
alibi_slopes, \
#define CALL_V1_LAUNCHER(T, CACHE_T, BLOCK_SIZE, KV_DTYPE) \
paged_attention_v1_launcher<T, CACHE_T, BLOCK_SIZE, KV_DTYPE>( \
out, \
query, \
key_cache, \
value_cache, \
num_kv_heads, \
scale, \
block_tables, \
seq_lens, \
max_seq_len, \
alibi_slopes, \
kv_scale);
// NOTE(woosuk): To reduce the compilation time, we omitted block sizes
// 1, 2, 4, 64, 128, 256.
#define CALL_V1_LAUNCHER_BLOCK_SIZE(T, CACHE_T, IS_FP8_KV_CACHE) \
switch (block_size) { \
case 8: \
CALL_V1_LAUNCHER(T, CACHE_T, 8, IS_FP8_KV_CACHE); \
break; \
case 16: \
CALL_V1_LAUNCHER(T, CACHE_T, 16, IS_FP8_KV_CACHE); \
break; \
case 32: \
CALL_V1_LAUNCHER(T, CACHE_T, 32, IS_FP8_KV_CACHE); \
break; \
default: \
TORCH_CHECK(false, "Unsupported block size: ", block_size); \
break; \
#define CALL_V1_LAUNCHER_BLOCK_SIZE(T, CACHE_T, KV_DTYPE) \
switch (block_size) { \
case 8: \
CALL_V1_LAUNCHER(T, CACHE_T, 8, KV_DTYPE); \
break; \
case 16: \
CALL_V1_LAUNCHER(T, CACHE_T, 16, KV_DTYPE); \
break; \
case 32: \
CALL_V1_LAUNCHER(T, CACHE_T, 32, KV_DTYPE); \
break; \
default: \
TORCH_CHECK(false, "Unsupported block size: ", block_size); \
break; \
}
void paged_attention_v1(
@ -752,65 +729,44 @@ void paged_attention_v1(
const c10::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype,
float kv_scale) {
if (kv_cache_dtype == "auto") {
if (query.dtype() == at::ScalarType::Float) {
CALL_V1_LAUNCHER_BLOCK_SIZE(float, float, false);
} else if (query.dtype() == at::ScalarType::Half) {
CALL_V1_LAUNCHER_BLOCK_SIZE(uint16_t, uint16_t, false);
} else if (query.dtype() == at::ScalarType::BFloat16) {
CALL_V1_LAUNCHER_BLOCK_SIZE(__nv_bfloat16, __nv_bfloat16, false);
} else {
TORCH_CHECK(false, "Unsupported data type: ", query.dtype());
}
} else if (kv_cache_dtype == "fp8") {
if (query.dtype() == at::ScalarType::Float) {
CALL_V1_LAUNCHER_BLOCK_SIZE(float, uint8_t, true);
} else if (query.dtype() == at::ScalarType::Half) {
CALL_V1_LAUNCHER_BLOCK_SIZE(uint16_t, uint8_t, true);
} else if (query.dtype() == at::ScalarType::BFloat16) {
CALL_V1_LAUNCHER_BLOCK_SIZE(__nv_bfloat16, uint8_t, true);
} else {
TORCH_CHECK(false, "Unsupported data type: ", query.dtype());
}
} else {
TORCH_CHECK(false, "Unsupported data type of kv cache: ", kv_cache_dtype);
}
DISPATCH_BY_KV_CACHE_DTYPE(query.dtype(), kv_cache_dtype, CALL_V1_LAUNCHER_BLOCK_SIZE)
}
#define LAUNCH_PAGED_ATTENTION_V2(HEAD_SIZE) \
vllm::paged_attention_v2_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, \
IS_FP8_KV_CACHE, PARTITION_SIZE> \
<<<grid, block, shared_mem_size, stream>>>( \
exp_sums_ptr, \
max_logits_ptr, \
tmp_out_ptr, \
query_ptr, \
key_cache_ptr, \
value_cache_ptr, \
num_kv_heads, \
scale, \
block_tables_ptr, \
seq_lens_ptr, \
max_num_blocks_per_seq, \
alibi_slopes_ptr, \
q_stride, \
kv_block_stride, \
kv_head_stride, \
kv_scale); \
vllm::paged_attention_v2_reduce_kernel<T, HEAD_SIZE, NUM_THREADS, PARTITION_SIZE> \
<<<reduce_grid, block, reduce_shared_mem_size, stream>>>( \
out_ptr, \
exp_sums_ptr, \
max_logits_ptr, \
tmp_out_ptr, \
seq_lens_ptr, \
#define LAUNCH_PAGED_ATTENTION_V2(HEAD_SIZE) \
vllm::paged_attention_v2_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, \
KV_DTYPE, PARTITION_SIZE> \
<<<grid, block, shared_mem_size, stream>>>( \
exp_sums_ptr, \
max_logits_ptr, \
tmp_out_ptr, \
query_ptr, \
key_cache_ptr, \
value_cache_ptr, \
num_kv_heads, \
scale, \
block_tables_ptr, \
seq_lens_ptr, \
max_num_blocks_per_seq, \
alibi_slopes_ptr, \
q_stride, \
kv_block_stride, \
kv_head_stride, \
kv_scale); \
vllm::paged_attention_v2_reduce_kernel<T, HEAD_SIZE, NUM_THREADS, PARTITION_SIZE> \
<<<reduce_grid, block, reduce_shared_mem_size, stream>>>( \
out_ptr, \
exp_sums_ptr, \
max_logits_ptr, \
tmp_out_ptr, \
seq_lens_ptr, \
max_num_partitions);
template<
typename T,
typename CACHE_T,
int BLOCK_SIZE,
bool IS_FP8_KV_CACHE,
vllm::Fp8KVCacheDataType KV_DTYPE,
int NUM_THREADS = 128,
int PARTITION_SIZE = 512>
void paged_attention_v2_launcher(
@ -897,39 +853,39 @@ void paged_attention_v2_launcher(
}
}
#define CALL_V2_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE) \
paged_attention_v2_launcher<T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE>( \
out, \
exp_sums, \
max_logits, \
tmp_out, \
query, \
key_cache, \
value_cache, \
num_kv_heads, \
scale, \
block_tables, \
seq_lens, \
max_seq_len, \
alibi_slopes, \
#define CALL_V2_LAUNCHER(T, CACHE_T, BLOCK_SIZE, KV_DTYPE) \
paged_attention_v2_launcher<T, CACHE_T, BLOCK_SIZE, KV_DTYPE>( \
out, \
exp_sums, \
max_logits, \
tmp_out, \
query, \
key_cache, \
value_cache, \
num_kv_heads, \
scale, \
block_tables, \
seq_lens, \
max_seq_len, \
alibi_slopes, \
kv_scale);
// NOTE(woosuk): To reduce the compilation time, we omitted block sizes
// 1, 2, 4, 64, 128, 256.
#define CALL_V2_LAUNCHER_BLOCK_SIZE(T, CACHE_T, IS_FP8_KV_CACHE) \
switch (block_size) { \
case 8: \
CALL_V2_LAUNCHER(T, CACHE_T, 8, IS_FP8_KV_CACHE); \
break; \
case 16: \
CALL_V2_LAUNCHER(T, CACHE_T, 16, IS_FP8_KV_CACHE); \
break; \
case 32: \
CALL_V2_LAUNCHER(T, CACHE_T, 32, IS_FP8_KV_CACHE); \
break; \
default: \
TORCH_CHECK(false, "Unsupported block size: ", block_size); \
break; \
#define CALL_V2_LAUNCHER_BLOCK_SIZE(T, CACHE_T, KV_DTYPE) \
switch (block_size) { \
case 8: \
CALL_V2_LAUNCHER(T, CACHE_T, 8, KV_DTYPE); \
break; \
case 16: \
CALL_V2_LAUNCHER(T, CACHE_T, 16, KV_DTYPE); \
break; \
case 32: \
CALL_V2_LAUNCHER(T, CACHE_T, 32, KV_DTYPE); \
break; \
default: \
TORCH_CHECK(false, "Unsupported block size: ", block_size); \
break; \
}
void paged_attention_v2(
@ -949,29 +905,7 @@ void paged_attention_v2(
const c10::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype,
float kv_scale) {
if (kv_cache_dtype == "auto") {
if (query.dtype() == at::ScalarType::Float) {
CALL_V2_LAUNCHER_BLOCK_SIZE(float, float, false);
} else if (query.dtype() == at::ScalarType::Half) {
CALL_V2_LAUNCHER_BLOCK_SIZE(uint16_t, uint16_t, false);
} else if (query.dtype() == at::ScalarType::BFloat16) {
CALL_V2_LAUNCHER_BLOCK_SIZE(__nv_bfloat16, __nv_bfloat16, false);
} else {
TORCH_CHECK(false, "Unsupported data type: ", query.dtype());
}
} else if (kv_cache_dtype == "fp8") {
if (query.dtype() == at::ScalarType::Float) {
CALL_V2_LAUNCHER_BLOCK_SIZE(float, uint8_t, true);
} else if (query.dtype() == at::ScalarType::Half) {
CALL_V2_LAUNCHER_BLOCK_SIZE(uint16_t, uint8_t, true);
} else if (query.dtype() == at::ScalarType::BFloat16) {
CALL_V2_LAUNCHER_BLOCK_SIZE(__nv_bfloat16, uint8_t, true);
} else {
TORCH_CHECK(false, "Unsupported data type: ", query.dtype());
}
} else {
TORCH_CHECK(false, "Unsupported data type of kv cache: ", kv_cache_dtype);
}
DISPATCH_BY_KV_CACHE_DTYPE(query.dtype(), kv_cache_dtype, CALL_V2_LAUNCHER_BLOCK_SIZE)
}
#undef WARP_SIZE

View File

@ -3,14 +3,21 @@
#include "attention_generic.cuh"
#include <stdint.h>
#ifdef ENABLE_FP8_E5M2
#ifdef ENABLE_FP8
#ifndef USE_ROCM
#include <cuda_fp8.h>
#endif
#endif // USE_ROCM
#endif // ENABLE_FP8
namespace vllm {
#if defined(ENABLE_FP8_E5M2) || defined(ENABLE_FP8_E4M3)
// fp8 vector types for quantization of kv cache
enum class Fp8KVCacheDataType {
kAuto = 0,
kFp8E4M3 = 1,
kFp8E5M2 = 2,
};
// fp8 vector types for quantization of kv cache
template<>
struct Vec<uint8_t, 1> {
using Type = uint8_t;
@ -30,6 +37,5 @@ template<>
struct Vec<uint8_t, 8> {
using Type = uint2;
};
#endif // ENABLE_FP8_E5M2
} // namespace vllm
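
For context on the Vec specializations above: they map N packed fp8 bytes to a single load/store type so a kernel can move a whole quantized vector at once. A small, self-contained sketch of the same trait technique; the 2- and 4-element mappings below are illustrative assumptions, not the vLLM headers:

#include <cstdint>
#include <cstdio>

// Primary template: a packed vector of N elements of type T.
template <typename T, int N> struct Vec;

// Packed fp8 storage: N bytes map to an integer type that holds them.
template <> struct Vec<uint8_t, 1> { using Type = uint8_t;  };
template <> struct Vec<uint8_t, 2> { using Type = uint16_t; };  // assumed mapping
template <> struct Vec<uint8_t, 4> { using Type = uint32_t; };  // assumed mapping

int main() {
  // A kernel picks its quantized load type from the trait, e.g. for VEC_SIZE = 4:
  using Quant_vec = Vec<uint8_t, 4>::Type;
  static_assert(sizeof(Quant_vec) == 4, "one 32-bit load carries 4 fp8 values");
  std::printf("Quant_vec is %zu bytes\n", sizeof(Quant_vec));
}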

View File

@ -34,5 +34,7 @@ void reshape_and_cache_flash(
// Just for unittest
void convert_fp8(
torch::Tensor& src_cache,
torch::Tensor& dst_cache);
void convert_fp8(
torch::Tensor& dst_cache,
torch::Tensor& src_cache,
const float scale,
const std::string& kv_cache_dtype);

View File

@ -4,10 +4,11 @@
#include "cuda_compat.h"
#include "dispatch_utils.h"
#if defined(ENABLE_FP8_E5M2)
#include "quantization/fp8_e5m2_kvcache/quant_utils.cuh"
#elif defined(ENABLE_FP8_E4M3)
#include "quantization/fp8/amd_detail/quant_utils.cuh"
#ifdef USE_ROCM
#include "quantization/fp8/amd/quant_utils.cuh"
#else
#include "quantization/fp8/nvidia/quant_utils.cuh"
#endif
#include <algorithm>
@ -149,7 +150,7 @@ void copy_blocks(
namespace vllm {
template<typename scalar_t, typename cache_t, bool is_fp8_kv_cache>
template<typename scalar_t, typename cache_t, Fp8KVCacheDataType kv_dt>
__global__ void reshape_and_cache_kernel(
const scalar_t* __restrict__ key, // [num_tokens, num_heads, head_size]
const scalar_t* __restrict__ value, // [num_tokens, num_heads, head_size]
@ -194,19 +195,12 @@ __global__ void reshape_and_cache_kernel(
+ block_offset;
scalar_t tgt_key = key[src_key_idx];
scalar_t tgt_value = value[src_value_idx];
if constexpr (is_fp8_kv_cache) {
#if defined(ENABLE_FP8_E5M2)
key_cache[tgt_key_idx] = fp8_e5m2_unscaled::vec_conversion<uint8_t, scalar_t>(tgt_key);
value_cache[tgt_value_idx] = fp8_e5m2_unscaled::vec_conversion<uint8_t, scalar_t>(tgt_value);
#elif defined(ENABLE_FP8_E4M3)
key_cache[tgt_key_idx] = fp8_e4m3::scaled_vec_conversion<uint8_t, scalar_t>(tgt_key, kv_scale);
value_cache[tgt_value_idx] = fp8_e4m3::scaled_vec_conversion<uint8_t, scalar_t>(tgt_value, kv_scale);
#else
assert(false);
#endif
} else {
if constexpr (kv_dt == Fp8KVCacheDataType::kAuto) {
key_cache[tgt_key_idx] = tgt_key;
value_cache[tgt_value_idx] = tgt_value;
} else {
key_cache[tgt_key_idx] = fp8::scaled_convert<cache_t, scalar_t, kv_dt>(tgt_key, kv_scale);
value_cache[tgt_value_idx] = fp8::scaled_convert<cache_t, scalar_t, kv_dt>(tgt_value, kv_scale);
}
}
}
@ -248,19 +242,22 @@ __global__ void reshape_and_cache_flash_kernel(
}
} // namespace vllm
#define CALL_RESHAPE_AND_CACHE(KV_T, CACHE_T, IS_FP8_KV_CACHE) \
vllm::reshape_and_cache_kernel<KV_T, CACHE_T, IS_FP8_KV_CACHE><<<grid, block, 0, stream>>>( \
reinterpret_cast<KV_T*>(key.data_ptr()), \
reinterpret_cast<KV_T*>(value.data_ptr()), \
reinterpret_cast<CACHE_T*>(key_cache.data_ptr()), \
reinterpret_cast<CACHE_T*>(value_cache.data_ptr()), \
slot_mapping.data_ptr<int64_t>(), \
key_stride, \
value_stride, \
num_heads, \
head_size, \
block_size, \
x, \
// KV_T is the data type of the key and value tensors.
// CACHE_T is the stored data type of the kv-cache.
// KV_DTYPE is the real data type of the kv-cache.
#define CALL_RESHAPE_AND_CACHE(KV_T, CACHE_T, KV_DTYPE) \
vllm::reshape_and_cache_kernel<KV_T, CACHE_T, KV_DTYPE><<<grid, block, 0, stream>>>( \
reinterpret_cast<KV_T*>(key.data_ptr()), \
reinterpret_cast<KV_T*>(value.data_ptr()), \
reinterpret_cast<CACHE_T*>(key_cache.data_ptr()), \
reinterpret_cast<CACHE_T*>(value_cache.data_ptr()), \
slot_mapping.data_ptr<int64_t>(), \
key_stride, \
value_stride, \
num_heads, \
head_size, \
block_size, \
x, \
kv_scale);
void reshape_and_cache(
@ -285,25 +282,8 @@ void reshape_and_cache(
dim3 block(std::min(num_heads * head_size, 512));
const at::cuda::OptionalCUDAGuard device_guard(device_of(key));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
if (kv_cache_dtype == "auto") {
if (key.dtype() == at::ScalarType::Float) {
CALL_RESHAPE_AND_CACHE(float, float, false);
} else if (key.dtype() == at::ScalarType::Half) {
CALL_RESHAPE_AND_CACHE(uint16_t, uint16_t, false);
} else if (key.dtype() == at::ScalarType::BFloat16) {
CALL_RESHAPE_AND_CACHE(__nv_bfloat16, __nv_bfloat16, false);
}
} else if (kv_cache_dtype == "fp8") {
if (key.dtype() == at::ScalarType::Float) {
CALL_RESHAPE_AND_CACHE(float, uint8_t, true);
} else if (key.dtype() == at::ScalarType::Half) {
CALL_RESHAPE_AND_CACHE(uint16_t, uint8_t, true);
} else if (key.dtype() == at::ScalarType::BFloat16) {
CALL_RESHAPE_AND_CACHE(__nv_bfloat16, uint8_t, true);
}
} else {
TORCH_CHECK(false, "Unsupported data type of kv cache: ", kv_cache_dtype);
}
DISPATCH_BY_KV_CACHE_DTYPE(key.dtype(), kv_cache_dtype, CALL_RESHAPE_AND_CACHE)
}
void reshape_and_cache_flash(
@ -353,35 +333,34 @@ void reshape_and_cache_flash(
namespace vllm {
template<typename Tout, typename Tin>
template<typename Tout, typename Tin, Fp8KVCacheDataType kv_dt>
__global__ void convert_fp8_kernel(
const Tin* __restrict__ src_cache,
Tout* __restrict__ dst_cache,
const float kv_scale,
const int64_t block_stride) {
const int64_t block_idx = blockIdx.x;
for (int i = threadIdx.x; i < block_stride; i += blockDim.x) {
int64_t idx = block_idx * block_stride + i;
#if defined(ENABLE_FP8_E5M2)
dst_cache[idx] = fp8_e5m2_unscaled::vec_conversion<Tout, Tin>(src_cache[idx]);
#elif defined(ENABLE_FP8_E4M3)
dst_cache[idx] = fp8_e4m3::vec_conversion<Tout, Tin>(src_cache[idx]);
#else
assert(false);
#endif
dst_cache[idx] = fp8::scaled_convert<Tout, Tin, kv_dt>(src_cache[idx], kv_scale);
}
}
} // namespace vllm
#define CALL_CONVERT_FP8(Tout, Tin) \
vllm::convert_fp8_kernel<Tout, Tin><<<grid, block, 0, stream>>>( \
reinterpret_cast<Tin*>(src_cache.data_ptr()), \
reinterpret_cast<Tout*>(dst_cache.data_ptr()), \
#define CALL_CONVERT_FP8(Tout, Tin, KV_DTYPE) \
vllm::convert_fp8_kernel<Tout, Tin, KV_DTYPE><<<grid, block, 0, stream>>>( \
reinterpret_cast<Tin*>(src_cache.data_ptr()), \
reinterpret_cast<Tout*>(dst_cache.data_ptr()), \
kv_scale, \
block_stride);
// Only for testing.
void convert_fp8(
torch::Tensor& src_cache,
torch::Tensor& dst_cache)
void convert_fp8(
torch::Tensor& dst_cache,
torch::Tensor& src_cache,
const float kv_scale,
const std::string& kv_cache_dtype)
{
torch::Device src_device = src_cache.device();
torch::Device dst_device = dst_cache.device();
@ -399,17 +378,35 @@ void convert_fp8(
dim3 block(std::min(block_stride, int64_t(512)));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
if (src_cache.dtype() == at::ScalarType::Float) {
CALL_CONVERT_FP8(uint8_t, float);
} else if (src_cache.dtype() == at::ScalarType::Half) {
CALL_CONVERT_FP8(uint8_t, uint16_t);
} else if (src_cache.dtype() == at::ScalarType::BFloat16) {
CALL_CONVERT_FP8(uint8_t, __nv_bfloat16);
} else if (dst_cache.dtype() == at::ScalarType::Float) {
CALL_CONVERT_FP8(float, uint8_t);
} else if (dst_cache.dtype() == at::ScalarType::Half) {
CALL_CONVERT_FP8(uint16_t, uint8_t);
} else if (dst_cache.dtype() == at::ScalarType::BFloat16) {
CALL_CONVERT_FP8(__nv_bfloat16, uint8_t);
if (kv_cache_dtype == "auto") {
if (src_cache.dtype() == at::ScalarType::Float) {
CALL_CONVERT_FP8(uint8_t, float, vllm::Fp8KVCacheDataType::kAuto);
} else if (src_cache.dtype() == at::ScalarType::Half) {
CALL_CONVERT_FP8(uint8_t, uint16_t, vllm::Fp8KVCacheDataType::kAuto);
} else if (src_cache.dtype() == at::ScalarType::BFloat16) {
CALL_CONVERT_FP8(uint8_t, __nv_bfloat16, vllm::Fp8KVCacheDataType::kAuto);
} else if (dst_cache.dtype() == at::ScalarType::Float) {
CALL_CONVERT_FP8(float, uint8_t, vllm::Fp8KVCacheDataType::kAuto);
} else if (dst_cache.dtype() == at::ScalarType::Half) {
CALL_CONVERT_FP8(uint16_t, uint8_t, vllm::Fp8KVCacheDataType::kAuto);
} else if (dst_cache.dtype() == at::ScalarType::BFloat16) {
CALL_CONVERT_FP8(__nv_bfloat16, uint8_t, vllm::Fp8KVCacheDataType::kAuto);
}
} else if (kv_cache_dtype == "fp8" || kv_cache_dtype == "fp8_e4m3") {
if (src_cache.dtype() == at::ScalarType::Float) {
CALL_CONVERT_FP8(uint8_t, float, vllm::Fp8KVCacheDataType::kFp8E4M3);
} else if (src_cache.dtype() == at::ScalarType::Half) {
CALL_CONVERT_FP8(uint8_t, uint16_t, vllm::Fp8KVCacheDataType::kFp8E4M3);
} else if (src_cache.dtype() == at::ScalarType::BFloat16) {
CALL_CONVERT_FP8(uint8_t, __nv_bfloat16, vllm::Fp8KVCacheDataType::kFp8E4M3);
} else if (dst_cache.dtype() == at::ScalarType::Float) {
CALL_CONVERT_FP8(float, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3);
} else if (dst_cache.dtype() == at::ScalarType::Half) {
CALL_CONVERT_FP8(uint16_t, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3);
} else if (dst_cache.dtype() == at::ScalarType::BFloat16) {
CALL_CONVERT_FP8(__nv_bfloat16, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3);
}
} else {
TORCH_CHECK(false, "Unsupported data type: ", kv_cache_dtype);
}
}

View File

@ -5,12 +5,17 @@
#include <hip/hip_bf16.h>
#include <hip/hip_bfloat16.h>
#include "../../../attention/dtype_fp8.cuh"
#include "../../../attention/dtype_float32.cuh"
#include "../../../attention/dtype_bfloat16.cuh"
namespace vllm
{
namespace fp8_e4m3 {
#ifdef USE_ROCM
namespace fp8 {
#ifdef ENABLE_FP8
template <typename Tout, typename Tin>
__inline__ __device__ Tout vec_conversion(const Tin& x)
{
@ -512,6 +517,58 @@ __inline__ __device__ float4 scaled_vec_conversion<float4, uint32_t>(const uint3
float4 res = make_float4(tmp.x.x, tmp.x.y, tmp.y.x, tmp.y.y);
return res;
}
#endif // ENABLE_FP8
template <typename Tout, typename Tin, Fp8KVCacheDataType kv_dt>
__inline__ __device__ Tout convert(const Tin &x) {
#ifdef ENABLE_FP8
if constexpr (kv_dt == Fp8KVCacheDataType::kFp8E4M3) {
return vec_conversion<Tout, Tin>(x);
}
#endif
assert(false);
}
template <typename Tout, typename Tin, Fp8KVCacheDataType kv_dt>
__inline__ __device__ Tout scaled_convert(const Tin &x, const float scale) {
#ifdef ENABLE_FP8
if constexpr (kv_dt == Fp8KVCacheDataType::kFp8E4M3) {
return scaled_vec_conversion<Tout, Tin>(x, scale);
}
#endif
assert(false);
}
// The following macro is used to dispatch the conversion function based on the
// data type of the key and value cache. The FN is a macro that calls a function
// with template<typename scalar_t, typename cache_t, Fp8KVCacheDataType kv_dt>.
#define DISPATCH_BY_KV_CACHE_DTYPE(SRC_DTYPE, KV_DTYPE, FN) \
if (KV_DTYPE == "auto") { \
if (SRC_DTYPE == at::ScalarType::Float) { \
FN(float, float, vllm::Fp8KVCacheDataType::kAuto); \
} else if (SRC_DTYPE == at::ScalarType::Half) { \
FN(uint16_t, uint16_t, vllm::Fp8KVCacheDataType::kAuto); \
} else if (SRC_DTYPE == at::ScalarType::BFloat16) { \
FN(__nv_bfloat16, __nv_bfloat16, vllm::Fp8KVCacheDataType::kAuto); \
} else { \
TORCH_CHECK(false, "Unsupported input type of kv cache: ", SRC_DTYPE); \
} \
} else { \
if (KV_DTYPE == "fp8" || KV_DTYPE == "fp8_e4m3") { \
if (SRC_DTYPE == at::ScalarType::Float) { \
FN(float, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \
} else if (SRC_DTYPE == at::ScalarType::Half) { \
FN(uint16_t, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \
} else if (SRC_DTYPE == at::ScalarType::BFloat16) { \
FN(__nv_bfloat16, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \
} else { \
TORCH_CHECK(false, "Unsupported input type of kv cache: ", SRC_DTYPE); \
} \
} else { \
TORCH_CHECK(false, "Unsupported data type of kv cache: ", KV_DTYPE); \
} \
}
} // fp8
#endif // USE_ROCM
} // namespace vllm
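
The DISPATCH_BY_KV_CACHE_DTYPE macro above turns the runtime pair (source dtype, kv_cache_dtype string) into a compile-time (scalar type, cache type, Fp8KVCacheDataType) triple and hands it to FN, which is itself a macro; this is how paged_attention_v1/v2 and reshape_and_cache now select a kernel instantiation. A stripped-down, self-contained sketch of the technique, with hypothetical names and strings standing in for torch scalar types:

#include <cstdint>
#include <cstdio>
#include <string>

enum class Fp8KVCacheDataType { kAuto, kFp8E4M3, kFp8E5M2 };

// The "kernel launcher" the dispatch macro ultimately expands to.
template <typename SCALAR_T, typename CACHE_T, Fp8KVCacheDataType KV_DTYPE>
void launch() {
  std::printf("launch: scalar=%zu bytes, cache=%zu bytes, dtype=%d\n",
              sizeof(SCALAR_T), sizeof(CACHE_T), static_cast<int>(KV_DTYPE));
}

// FN is itself a macro, so one dispatch table can drive different launchers.
#define CALL_LAUNCH(SCALAR_T, CACHE_T, KV_DTYPE) launch<SCALAR_T, CACHE_T, KV_DTYPE>();

#define DISPATCH_BY_KV_CACHE_DTYPE(SRC_DTYPE, KV_DTYPE_STR, FN)  \
  if (KV_DTYPE_STR == "auto") {                                  \
    if (SRC_DTYPE == "float") {                                  \
      FN(float, float, Fp8KVCacheDataType::kAuto);               \
    } else {                                                     \
      FN(uint16_t, uint16_t, Fp8KVCacheDataType::kAuto);         \
    }                                                            \
  } else {                                                       \
    if (SRC_DTYPE == "float") {                                  \
      FN(float, uint8_t, Fp8KVCacheDataType::kFp8E4M3);          \
    } else {                                                     \
      FN(uint16_t, uint8_t, Fp8KVCacheDataType::kFp8E4M3);       \
    }                                                            \
  }

int main() {
  std::string src_dtype = "half", kv_cache_dtype = "fp8";
  DISPATCH_BY_KV_CACHE_DTYPE(src_dtype, kv_cache_dtype, CALL_LAUNCH)
}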

View File

@ -0,0 +1,568 @@
#pragma once
#include "../../../attention/attention_dtypes.h"
#include <assert.h>
#include <float.h>
#include <stdint.h>
#include <type_traits>
namespace vllm {
#ifndef USE_ROCM
namespace fp8 {
#ifdef ENABLE_FP8
#if 0 // Disable the following code to reduce the binary size.
template <typename Tout, typename Tin>
__inline__ __device__ Tout
vec_conversion(const Tin &x, const __nv_fp8_interpretation_t fp8_type) {
return x;
}
// fp8 -> half
template <>
__inline__ __device__ uint16_t vec_conversion<uint16_t, uint8_t>(
const uint8_t &a, const __nv_fp8_interpretation_t fp8_type) {
__half_raw res = __nv_cvt_fp8_to_halfraw(a, fp8_type);
return res.x;
}
// fp8x2 -> half2
template <>
__inline__ __device__ uint32_t vec_conversion<uint32_t, uint16_t>(
const uint16_t &a, const __nv_fp8_interpretation_t fp8_type) {
union {
uint16_t u16[2];
uint32_t u32;
} tmp;
__half2_raw res = __nv_cvt_fp8x2_to_halfraw2(a, fp8_type);
tmp.u16[0] = res.x;
tmp.u16[1] = res.y;
return tmp.u32;
}
// fp8x4 -> half2x2
template <>
__inline__ __device__ uint2 vec_conversion<uint2, uint32_t>(
const uint32_t &a, const __nv_fp8_interpretation_t fp8_type) {
union {
uint2 u32x2;
uint32_t u32[2];
} tmp;
tmp.u32[0] = vec_conversion<uint32_t, uint16_t>((uint16_t)a, fp8_type);
tmp.u32[1] =
vec_conversion<uint32_t, uint16_t>((uint16_t)(a >> 16U), fp8_type);
return tmp.u32x2;
}
// fp8x8 -> half2x4
template <>
__inline__ __device__ uint4 vec_conversion<uint4, uint2>(
const uint2 &a, const __nv_fp8_interpretation_t fp8_type) {
union {
uint4 u64x2;
uint2 u64[2];
} tmp;
tmp.u64[0] = vec_conversion<uint2, uint32_t>(a.x, fp8_type);
tmp.u64[1] = vec_conversion<uint2, uint32_t>(a.y, fp8_type);
return tmp.u64x2;
}
// fp8 -> __nv_bfloat16
template <>
__inline__ __device__ __nv_bfloat16 vec_conversion<__nv_bfloat16, uint8_t>(
const uint8_t &a, const __nv_fp8_interpretation_t fp8_type) {
// Note there is no direct convert function from fp8 to bf16.
// fp8 -> half
__half_raw res = __nv_cvt_fp8_to_halfraw(a, fp8_type);
// half -> float -> bf16
float tmp = half_to_float(res.x);
return __float2bfloat16(tmp);
}
// fp8x2 -> __nv_bfloat162
template <>
__inline__ __device__ __nv_bfloat162 vec_conversion<__nv_bfloat162, uint16_t>(
const uint16_t &a, const __nv_fp8_interpretation_t fp8_type) {
__nv_bfloat162 res;
res.x = vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)a, fp8_type);
res.y = vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)(a >> 8U), fp8_type);
return res;
}
// fp8x4 -> bf16_4_t
template <>
__inline__ __device__ bf16_4_t vec_conversion<bf16_4_t, uint32_t>(
const uint32_t &a, const __nv_fp8_interpretation_t fp8_type) {
bf16_4_t res;
res.x = vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)a, fp8_type);
res.y =
vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)(a >> 16U), fp8_type);
return res;
}
// fp8x8 -> bf16_8_t
template <>
__inline__ __device__ bf16_8_t vec_conversion<bf16_8_t, uint2>(
const uint2 &a, const __nv_fp8_interpretation_t fp8_type) {
bf16_4_t tmp1, tmp2;
tmp1 = vec_conversion<bf16_4_t, uint32_t>(a.x, fp8_type);
tmp2 = vec_conversion<bf16_4_t, uint32_t>(a.y, fp8_type);
bf16_8_t res;
res.x = tmp1.x;
res.y = tmp1.y;
res.z = tmp2.x;
res.w = tmp2.y;
return res;
}
// fp8 -> float
template <>
__inline__ __device__ float
vec_conversion<float, uint8_t>(const uint8_t &a,
const __nv_fp8_interpretation_t fp8_type) {
// fp8 -> half
uint16_t tmp = vec_conversion<uint16_t, uint8_t>(a, fp8_type);
// half -> float
return half_to_float(tmp);
}
// fp8x2 -> float2
template <>
__inline__ __device__ float2 vec_conversion<float2, uint16_t>(
const uint16_t &a, const __nv_fp8_interpretation_t fp8_type) {
// fp8x2 -> half2
uint32_t tmp = vec_conversion<uint32_t, uint16_t>(a, fp8_type);
// half2 -> float2
return half2_to_float2(tmp);
}
// fp8x4 -> float4
template <>
__inline__ __device__ Float4_ vec_conversion<Float4_, uint32_t>(
const uint32_t &a, const __nv_fp8_interpretation_t fp8_type) {
Float4_ res;
res.x = vec_conversion<float2, uint16_t>((uint16_t)a, fp8_type);
res.y = vec_conversion<float2, uint16_t>((uint16_t)(a >> 16U), fp8_type);
return res;
}
// fp8x8 -> float8
template <>
__inline__ __device__ Float8_ vec_conversion<Float8_, uint2>(
const uint2 &a, const __nv_fp8_interpretation_t fp8_type) {
Float4_ tmp1, tmp2;
tmp1 = vec_conversion<Float4_, uint32_t>(a.x, fp8_type);
tmp2 = vec_conversion<Float4_, uint32_t>(a.y, fp8_type);
Float8_ res;
res.x = tmp1.x;
res.y = tmp1.y;
res.z = tmp2.x;
res.w = tmp2.y;
return res;
}
// half -> fp8
template <>
__inline__ __device__ uint8_t vec_conversion<uint8_t, uint16_t>(
const uint16_t &a, const __nv_fp8_interpretation_t fp8_type) {
__half_raw tmp;
tmp.x = a;
__nv_fp8_storage_t res =
__nv_cvt_halfraw_to_fp8(tmp, __NV_SATFINITE, fp8_type);
return (uint8_t)res;
}
// bf16 -> fp8
template <>
__inline__ __device__ uint8_t vec_conversion<uint8_t, __nv_bfloat16>(
const __nv_bfloat16 &a, const __nv_fp8_interpretation_t fp8_type) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
assert(false);
#else
__nv_fp8_storage_t res = __nv_cvt_bfloat16raw_to_fp8(
__nv_bfloat16_raw(a), __NV_SATFINITE, fp8_type);
return (uint8_t)res;
#endif
}
// float -> fp8
template <>
__inline__ __device__ uint8_t vec_conversion<uint8_t, float>(
const float &a, const __nv_fp8_interpretation_t fp8_type) {
__nv_fp8_storage_t res = __nv_cvt_float_to_fp8(a, __NV_SATFINITE, fp8_type);
return (uint8_t)res;
}
// fp8x4 -> float4
template <>
__inline__ __device__ float4 vec_conversion<float4, uint32_t>(
const uint32_t &a, const __nv_fp8_interpretation_t fp8_type) {
Float4_ tmp = vec_conversion<Float4_, uint32_t>(a, fp8_type);
float4 res = make_float4(tmp.x.x, tmp.x.y, tmp.y.x, tmp.y.y);
return res;
}
template <>
__inline__ __device__ uint32_t vec_conversion<uint32_t, float2>(
const float2 &a, const __nv_fp8_interpretation_t fp8_type) {
union {
half2 float16;
uint32_t uint32;
};
float16 = __float22half2_rn(a);
return uint32;
}
template <>
__inline__ __device__ uint2 vec_conversion<uint2, Float4_>(
const Float4_ &a, const __nv_fp8_interpretation_t fp8_type) {
uint2 b;
float2 val;
val.x = a.x.x;
val.y = a.x.y;
b.x = vec_conversion<uint32_t, float2>(val, fp8_type);
val.x = a.y.x;
val.y = a.y.y;
b.y = vec_conversion<uint32_t, float2>(val, fp8_type);
return b;
}
template <>
__inline__ __device__ float4 vec_conversion<float4, Float4_>(
const Float4_ &a, const __nv_fp8_interpretation_t fp8_type) {
float4 b;
b.x = a.x.x;
b.y = a.x.y;
b.z = a.y.x;
b.w = a.y.y;
return b;
}
template <>
__inline__ __device__ uint4 vec_conversion<uint4, Float8_>(
const Float8_ &a, const __nv_fp8_interpretation_t fp8_type) {
uint4 b;
b.x = vec_conversion<uint32_t, float2>(a.x, fp8_type);
b.y = vec_conversion<uint32_t, float2>(a.y, fp8_type);
b.z = vec_conversion<uint32_t, float2>(a.z, fp8_type);
b.w = vec_conversion<uint32_t, float2>(a.w, fp8_type);
return b;
}
template <>
__inline__ __device__ __nv_bfloat162 vec_conversion<__nv_bfloat162, float2>(
const float2 &a, const __nv_fp8_interpretation_t fp8_type) {
__nv_bfloat162 b;
from_float(b, a);
return b;
}
template <>
__inline__ __device__ bf16_4_t vec_conversion<bf16_4_t, Float4_>(
const Float4_ &a, const __nv_fp8_interpretation_t fp8_type) {
bf16_4_t b;
from_float(b, a);
return b;
}
template <>
__inline__ __device__ bf16_8_t vec_conversion<bf16_8_t, Float8_>(
const Float8_ &a, const __nv_fp8_interpretation_t fp8_type) {
bf16_8_t b;
from_float(b, a);
return b;
}
#endif
/* Scaled and vectorized conversions, for data exchange between the high- and
   low-precision domains.
   Convention for the scale in this API:
     FP8_data = Quantize(High_Precision_data / scale)
   i.e. Quantize(HP / scale) => FP8 and Dequantize(FP8) * scale => HP.
*/
template <typename Tout, typename Tin>
__inline__ __device__ Tout scaled_vec_conversion(
const Tin &x, const float scale, const __nv_fp8_interpretation_t fp8_type) {
return x;
}
// fp8 -> half
template <>
__inline__ __device__ uint16_t scaled_vec_conversion<uint16_t, uint8_t>(
const uint8_t &a, const float scale,
const __nv_fp8_interpretation_t fp8_type) {
__half_raw tmp = __nv_cvt_fp8_to_halfraw(a, fp8_type);
return float_to_half(half_to_float(tmp.x) * scale);
}
// fp8x2 -> half2
template <>
__inline__ __device__ uint32_t scaled_vec_conversion<uint32_t, uint16_t>(
const uint16_t &a, const float scale,
const __nv_fp8_interpretation_t fp8_type) {
union {
uint16_t u16[2];
uint32_t u32;
} tmp;
__half2_raw res = __nv_cvt_fp8x2_to_halfraw2(a, fp8_type);
tmp.u16[0] = float_to_half(half_to_float(res.x) * scale);
tmp.u16[1] = float_to_half(half_to_float(res.y) * scale);
return tmp.u32;
}
// fp8x4 -> half2x2
template <>
__inline__ __device__ uint2 scaled_vec_conversion<uint2, uint32_t>(
const uint32_t &a, const float scale,
const __nv_fp8_interpretation_t fp8_type) {
union {
uint2 u32x2;
uint32_t u32[2];
} tmp;
tmp.u32[0] =
scaled_vec_conversion<uint32_t, uint16_t>((uint16_t)a, scale, fp8_type);
tmp.u32[1] = scaled_vec_conversion<uint32_t, uint16_t>((uint16_t)(a >> 16U),
scale, fp8_type);
return tmp.u32x2;
}
// fp8x8 -> half2x4
template <>
__inline__ __device__ uint4
scaled_vec_conversion<uint4, uint2>(const uint2 &a, const float scale,
const __nv_fp8_interpretation_t fp8_type) {
union {
uint4 u64x2;
uint2 u64[2];
} tmp;
tmp.u64[0] = scaled_vec_conversion<uint2, uint32_t>(a.x, scale, fp8_type);
tmp.u64[1] = scaled_vec_conversion<uint2, uint32_t>(a.y, scale, fp8_type);
return tmp.u64x2;
}
// fp8 -> __nv_bfloat16
template <>
__inline__ __device__ __nv_bfloat16
scaled_vec_conversion<__nv_bfloat16, uint8_t>(
const uint8_t &a, const float scale,
const __nv_fp8_interpretation_t fp8_type) {
// Note there is no direct convert function from fp8 to bf16.
// fp8 -> half
__half_raw res = __nv_cvt_fp8_to_halfraw(a, fp8_type);
// half -> float -> bf16
float tmp = half_to_float(res.x);
return __float2bfloat16(tmp * scale);
}
// fp8x2 -> __nv_bfloat162
template <>
__inline__ __device__ __nv_bfloat162
scaled_vec_conversion<__nv_bfloat162, uint16_t>(
const uint16_t &a, const float scale,
const __nv_fp8_interpretation_t fp8_type) {
__nv_bfloat162 res;
res.x = scaled_vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)a, scale,
fp8_type);
res.y = scaled_vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)(a >> 8U),
scale, fp8_type);
return res;
}
// fp8x4 -> bf16_4_t
template <>
__inline__ __device__ bf16_4_t scaled_vec_conversion<bf16_4_t, uint32_t>(
const uint32_t &a, const float scale,
const __nv_fp8_interpretation_t fp8_type) {
bf16_4_t res;
res.x = scaled_vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)a, scale,
fp8_type);
res.y = scaled_vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)(a >> 16U),
scale, fp8_type);
return res;
}
// fp8x8 -> bf16_8_t
template <>
__inline__ __device__ bf16_8_t scaled_vec_conversion<bf16_8_t, uint2>(
const uint2 &a, const float scale,
const __nv_fp8_interpretation_t fp8_type) {
bf16_4_t tmp1, tmp2;
tmp1 = scaled_vec_conversion<bf16_4_t, uint32_t>(a.x, scale, fp8_type);
tmp2 = scaled_vec_conversion<bf16_4_t, uint32_t>(a.y, scale, fp8_type);
bf16_8_t res;
res.x = tmp1.x;
res.y = tmp1.y;
res.z = tmp2.x;
res.w = tmp2.y;
return res;
}
// fp8 -> float
template <>
__inline__ __device__ float scaled_vec_conversion<float, uint8_t>(
const uint8_t &a, const float scale,
const __nv_fp8_interpretation_t fp8_type) {
// fp8 -> half
__half_raw res = __nv_cvt_fp8_to_halfraw(a, fp8_type);
uint16_t tmp = res.x;
// half -> float
return half_to_float(tmp) * scale;
}
// fp8x2 -> float2
template <>
__inline__ __device__ float2 scaled_vec_conversion<float2, uint16_t>(
const uint16_t &a, const float scale,
const __nv_fp8_interpretation_t fp8_type) {
// fp8x2 -> half2
uint32_t tmp = scaled_vec_conversion<uint32_t, uint16_t>(a, scale, fp8_type);
// half2 -> float2
return half2_to_float2(tmp);
}
// fp8x4 -> float4
template <>
__inline__ __device__ Float4_ scaled_vec_conversion<Float4_, uint32_t>(
const uint32_t &a, const float scale,
const __nv_fp8_interpretation_t fp8_type) {
Float4_ res;
res.x = scaled_vec_conversion<float2, uint16_t>((uint16_t)a, scale, fp8_type);
res.y = scaled_vec_conversion<float2, uint16_t>((uint16_t)(a >> 16U), scale,
fp8_type);
return res;
}
// fp8x8 -> float8
template <>
__inline__ __device__ Float8_ scaled_vec_conversion<Float8_, uint2>(
const uint2 &a, const float scale,
const __nv_fp8_interpretation_t fp8_type) {
Float4_ tmp1, tmp2;
tmp1 = scaled_vec_conversion<Float4_, uint32_t>(a.x, scale, fp8_type);
tmp2 = scaled_vec_conversion<Float4_, uint32_t>(a.y, scale, fp8_type);
Float8_ res;
res.x = tmp1.x;
res.y = tmp1.y;
res.z = tmp2.x;
res.w = tmp2.y;
return res;
}
// half -> fp8
template <>
__inline__ __device__ uint8_t scaled_vec_conversion<uint8_t, uint16_t>(
const uint16_t &a, const float scale,
const __nv_fp8_interpretation_t fp8_type) {
__nv_fp8_storage_t res =
__nv_cvt_float_to_fp8(half_to_float(a) / scale, __NV_SATFINITE, fp8_type);
return (uint8_t)res;
}
// bf16 -> fp8
template <>
__inline__ __device__ uint8_t scaled_vec_conversion<uint8_t, __nv_bfloat16>(
const __nv_bfloat16 &a, const float scale,
const __nv_fp8_interpretation_t fp8_type) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
assert(false);
#else
__nv_fp8_storage_t res = __nv_cvt_float_to_fp8(__bfloat162float(a) / scale,
__NV_SATFINITE, fp8_type);
return (uint8_t)res;
#endif
}
// float -> fp8
template <>
__inline__ __device__ uint8_t scaled_vec_conversion<uint8_t, float>(
const float &a, const float scale,
const __nv_fp8_interpretation_t fp8_type) {
__nv_fp8_storage_t res =
__nv_cvt_float_to_fp8(a / scale, __NV_SATFINITE, fp8_type);
return (uint8_t)res;
}
// fp8x4 -> float4
template <>
__inline__ __device__ float4 scaled_vec_conversion<float4, uint32_t>(
const uint32_t &a, const float scale,
const __nv_fp8_interpretation_t fp8_type) {
Float4_ tmp = scaled_vec_conversion<Float4_, uint32_t>(a, scale, fp8_type);
float4 res = make_float4(tmp.x.x, tmp.x.y, tmp.y.x, tmp.y.y);
return res;
}
#endif // ENABLE_FP8
template <typename Tout, typename Tin, Fp8KVCacheDataType kv_dt>
__inline__ __device__ Tout convert(const Tin &x) {
#if 0 // Disable the following code to reduce the binary size.
if constexpr (kv_dt == Fp8KVCacheDataType::kFp8E4M3) {
return vec_conversion<Tout, Tin>(x, __NV_E4M3);
} else if constexpr (kv_dt == Fp8KVCacheDataType::kFp8E5M2) {
return vec_conversion<Tout, Tin>(x, __NV_E5M2);
}
#endif
assert(false);
}
template <typename Tout, typename Tin, Fp8KVCacheDataType kv_dt>
__inline__ __device__ Tout scaled_convert(const Tin &x, const float scale) {
#ifdef ENABLE_FP8
if constexpr (kv_dt == Fp8KVCacheDataType::kFp8E4M3) {
return scaled_vec_conversion<Tout, Tin>(x, scale, __NV_E4M3);
} else if constexpr (kv_dt == Fp8KVCacheDataType::kFp8E5M2) {
return scaled_vec_conversion<Tout, Tin>(x, scale, __NV_E5M2);
}
#endif
assert(false);
}
// The following macro is used to dispatch the conversion function based on the
// data type of the key and value cache. The FN is a macro that calls a function
// with template<typename scalar_t, typename cache_t, Fp8KVCacheDataType kv_dt>.
#define DISPATCH_BY_KV_CACHE_DTYPE(SRC_DTYPE, KV_DTYPE, FN) \
if (KV_DTYPE == "auto") { \
if (SRC_DTYPE == at::ScalarType::Float) { \
FN(float, float, vllm::Fp8KVCacheDataType::kAuto); \
} else if (SRC_DTYPE == at::ScalarType::Half) { \
FN(uint16_t, uint16_t, vllm::Fp8KVCacheDataType::kAuto); \
} else if (SRC_DTYPE == at::ScalarType::BFloat16) { \
FN(__nv_bfloat16, __nv_bfloat16, vllm::Fp8KVCacheDataType::kAuto); \
} else { \
TORCH_CHECK(false, "Unsupported input type of kv cache: ", SRC_DTYPE); \
} \
} else { \
if (KV_DTYPE == "fp8" || KV_DTYPE == "fp8_e4m3") { \
if (SRC_DTYPE == at::ScalarType::Float) { \
FN(float, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \
} else if (SRC_DTYPE == at::ScalarType::Half) { \
FN(uint16_t, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \
} else if (SRC_DTYPE == at::ScalarType::BFloat16) { \
FN(__nv_bfloat16, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \
} else { \
TORCH_CHECK(false, "Unsupported input type of kv cache: ", SRC_DTYPE); \
} \
} else if (KV_DTYPE == "fp8_e5m2") { \
if (SRC_DTYPE == at::ScalarType::Float) { \
FN(float, uint8_t, vllm::Fp8KVCacheDataType::kFp8E5M2); \
} else if (SRC_DTYPE == at::ScalarType::Half) { \
FN(uint16_t, uint8_t, vllm::Fp8KVCacheDataType::kFp8E5M2); \
} else if (SRC_DTYPE == at::ScalarType::BFloat16) { \
FN(__nv_bfloat16, uint8_t, vllm::Fp8KVCacheDataType::kFp8E5M2); \
} else { \
TORCH_CHECK(false, "Unsupported input type of kv cache: ", SRC_DTYPE); \
} \
} else { \
TORCH_CHECK(false, "Unsupported data type of kv cache: ", KV_DTYPE); \
} \
}
} // namespace fp8
#endif // not USE_ROCM
} // namespace vllm
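
A tiny, self-contained round trip illustrating the scale convention documented near the top of this file (FP8_data = Quantize(HP / scale), HP ~= Dequantize(FP8) * scale); simple integer rounding stands in for the real __nv_cvt_* fp8 intrinsics:

#include <cmath>
#include <cstdint>
#include <cstdio>

// Placeholder "quantize": divide by scale, then round and clamp to one byte.
// The real code calls __nv_cvt_float_to_fp8(a / scale, __NV_SATFINITE, fp8_type).
uint8_t quantize(float hp, float scale) {
  float q = std::round(hp / scale);
  if (q < 0.f) q = 0.f;
  if (q > 255.f) q = 255.f;
  return static_cast<uint8_t>(q);
}

// Placeholder "dequantize": multiply the decoded value back by scale,
// as scaled_vec_conversion does with half_to_float(...) * scale.
float dequantize(uint8_t q, float scale) {
  return static_cast<float>(q) * scale;
}

int main() {
  const float scale = 0.05f;           // per-tensor kv_scale
  const float hp = 3.2f;               // a "high precision" activation
  uint8_t fp8 = quantize(hp, scale);   // FP8_data = Quantize(HP / scale)
  float back = dequantize(fp8, scale); // HP' = Dequantize(FP8) * scale
  std::printf("hp=%.3f  stored=%u  recovered=%.3f\n", hp,
              static_cast<unsigned>(fp8), back);
}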

View File

@ -1,277 +0,0 @@
#pragma once
#include <assert.h>
#include <stdint.h>
#include <float.h>
#include <type_traits>
#include "../../attention/attention_dtypes.h"
#include "../../attention/dtype_float32.cuh"
#include "../../attention/dtype_float16.cuh"
#include "../../attention/dtype_bfloat16.cuh"
namespace vllm {
#ifdef ENABLE_FP8_E5M2
namespace fp8_e5m2_unscaled {
template<typename Tout, typename Tin>
__inline__ __device__ Tout vec_conversion(const Tin& x)
{
return x;
}
// fp8 -> half
template<>
__inline__ __device__ uint16_t vec_conversion<uint16_t, uint8_t>(const uint8_t& a)
{
__half_raw res = __nv_cvt_fp8_to_halfraw(a, __NV_E5M2);
return res.x;
}
// fp8x2 -> half2
template<>
__inline__ __device__ uint32_t vec_conversion<uint32_t, uint16_t>(const uint16_t& a)
{
union {
uint16_t u16[2];
uint32_t u32;
} tmp;
__half2_raw res = __nv_cvt_fp8x2_to_halfraw2(a, __NV_E5M2);
tmp.u16[0] = res.x;
tmp.u16[1] = res.y;
return tmp.u32;
}
// fp8x4 -> half2x2
template<>
__inline__ __device__ uint2 vec_conversion<uint2, uint32_t>(const uint32_t& a)
{
union {
uint2 u32x2;
uint32_t u32[2];
} tmp;
tmp.u32[0] = vec_conversion<uint32_t, uint16_t>((uint16_t)a);
tmp.u32[1] = vec_conversion<uint32_t, uint16_t>((uint16_t)(a >> 16U));
return tmp.u32x2;
}
// fp8x8 -> half2x4
template<>
__inline__ __device__ uint4 vec_conversion<uint4, uint2>(const uint2& a)
{
union {
uint4 u64x2;
uint2 u64[2];
} tmp;
tmp.u64[0] = vec_conversion<uint2, uint32_t>(a.x);
tmp.u64[1] = vec_conversion<uint2, uint32_t>(a.y);
return tmp.u64x2;
}
// fp8 -> __nv_bfloat16
template<>
__inline__ __device__ __nv_bfloat16 vec_conversion<__nv_bfloat16, uint8_t>(const uint8_t& a)
{
// Note there is no direct convert function from fp8 to bf16.
// fp8 -> half
__half_raw res = __nv_cvt_fp8_to_halfraw(a, __NV_E5M2);
// half -> float -> bf16
float tmp = half_to_float(res.x);
return __float2bfloat16(tmp);
}
// fp8x2 -> __nv_bfloat162
template<>
__inline__ __device__ __nv_bfloat162 vec_conversion<__nv_bfloat162, uint16_t>(const uint16_t& a)
{
__nv_bfloat162 res;
res.x = vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)a);
res.y = vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)(a >> 8U));
return res;
}
// fp8x4 -> bf16_4_t
template<>
__inline__ __device__ bf16_4_t vec_conversion<bf16_4_t, uint32_t>(const uint32_t& a)
{
bf16_4_t res;
res.x = vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)a);
res.y = vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)(a >> 16U));
return res;
}
// fp8x8 -> bf16_8_t
template<>
__inline__ __device__ bf16_8_t vec_conversion<bf16_8_t, uint2>(const uint2& a)
{
bf16_4_t tmp1, tmp2;
tmp1 = vec_conversion<bf16_4_t, uint32_t>(a.x);
tmp2 = vec_conversion<bf16_4_t, uint32_t>(a.y);
bf16_8_t res;
res.x = tmp1.x;
res.y = tmp1.y;
res.z = tmp2.x;
res.w = tmp2.y;
return res;
}
// fp8 -> float
template<>
__inline__ __device__ float vec_conversion<float, uint8_t>(const uint8_t& a)
{
// fp8 -> half
uint16_t tmp = vec_conversion<uint16_t, uint8_t>(a);
// half -> float
return half_to_float(tmp);
}
// fp8x2 -> float2
template<>
__inline__ __device__ float2 vec_conversion<float2, uint16_t>(const uint16_t& a)
{
// fp8x2 -> half2
uint32_t tmp = vec_conversion<uint32_t, uint16_t>(a);
// half2 -> float2
return half2_to_float2(tmp);
}
// fp8x4 -> float4
template<>
__inline__ __device__ Float4_ vec_conversion<Float4_, uint32_t>(const uint32_t& a)
{
Float4_ res;
res.x = vec_conversion<float2, uint16_t>((uint16_t)a);
res.y = vec_conversion<float2, uint16_t>((uint16_t)(a >> 16U));
return res;
}
// fp8x8 -> float8
template<>
__inline__ __device__ Float8_ vec_conversion<Float8_, uint2>(const uint2& a)
{
Float4_ tmp1, tmp2;
tmp1 = vec_conversion<Float4_, uint32_t>(a.x);
tmp2 = vec_conversion<Float4_, uint32_t>(a.y);
Float8_ res;
res.x = tmp1.x;
res.y = tmp1.y;
res.z = tmp2.x;
res.w = tmp2.y;
return res;
}
// half -> fp8
template<>
__inline__ __device__ uint8_t vec_conversion<uint8_t, uint16_t>(const uint16_t& a)
{
__half_raw tmp;
tmp.x = a;
__nv_fp8_storage_t res = __nv_cvt_halfraw_to_fp8(tmp, __NV_SATFINITE, __NV_E5M2);
return (uint8_t)res;
}
// bf16 -> fp8
template<>
__inline__ __device__ uint8_t vec_conversion<uint8_t, __nv_bfloat16>(const __nv_bfloat16& a)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
assert(false);
#else
__nv_fp8_storage_t res = __nv_cvt_bfloat16raw_to_fp8(__nv_bfloat16_raw(a), __NV_SATFINITE, __NV_E5M2);
return (uint8_t)res;
#endif
}
// float -> fp8
template<>
__inline__ __device__ uint8_t vec_conversion<uint8_t, float>(const float& a)
{
__nv_fp8_storage_t res = __nv_cvt_float_to_fp8(a, __NV_SATFINITE, __NV_E5M2);
return (uint8_t)res;
}
// fp8x4 -> float4
template<>
__inline__ __device__ float4 vec_conversion<float4, uint32_t>(const uint32_t& a)
{
Float4_ tmp = vec_conversion<Float4_, uint32_t>(a);
float4 res = make_float4(tmp.x.x, tmp.x.y, tmp.y.x, tmp.y.y);
return res;
}
template<>
__inline__ __device__ uint32_t vec_conversion<uint32_t, float2>(const float2& a)
{
union {
half2 float16;
uint32_t uint32;
};
float16 = __float22half2_rn(a);
return uint32;
}
template<>
__inline__ __device__ uint2 vec_conversion<uint2, Float4_>(const Float4_& a)
{
uint2 b;
float2 val;
val.x = a.x.x;
val.y = a.x.y;
b.x = vec_conversion<uint32_t, float2>(val);
val.x = a.y.x;
val.y = a.y.y;
b.y = vec_conversion<uint32_t, float2>(val);
return b;
}
template<>
__inline__ __device__ float4 vec_conversion<float4, Float4_>(const Float4_& a)
{
float4 b;
b.x = a.x.x;
b.y = a.x.y;
b.z = a.y.x;
b.w = a.y.y;
return b;
}
template<>
__inline__ __device__ uint4 vec_conversion<uint4, Float8_>(const Float8_& a)
{
uint4 b;
b.x = vec_conversion<uint32_t, float2>(a.x);
b.y = vec_conversion<uint32_t, float2>(a.y);
b.z = vec_conversion<uint32_t, float2>(a.z);
b.w = vec_conversion<uint32_t, float2>(a.w);
return b;
}
template<>
__inline__ __device__ __nv_bfloat162 vec_conversion<__nv_bfloat162, float2>(const float2 &a) {
__nv_bfloat162 b;
from_float(b, a);
return b;
}
template<>
__inline__ __device__ bf16_4_t vec_conversion<bf16_4_t, Float4_>(const Float4_ &a) {
bf16_4_t b;
from_float(b, a);
return b;
}
template<>
__inline__ __device__ bf16_8_t vec_conversion<bf16_8_t, Float8_>(const Float8_ &a) {
bf16_8_t b;
from_float(b, a);
return b;
}
} // namespace fp8_e5m2_unscaled
#endif // ENABLE_FP8_E5M2
} // namespace vllm

View File

@ -236,14 +236,14 @@ def test_paged_attention(
dequantized_key_cache = torch.empty(size=key_cache_shape,
dtype=dtype,
device=device)
ops.convert_fp8(key_cache, dequantized_key_cache)
ops.convert_fp8(dequantized_key_cache, key_cache)
key_cache = dequantized_key_cache
value_cache_shape = value_cache.shape
dequantized_value_cache = torch.empty(size=value_cache_shape,
dtype=dtype,
device=device)
ops.convert_fp8(value_cache, dequantized_value_cache)
ops.convert_fp8(dequantized_value_cache, value_cache)
value_cache = dequantized_value_cache
ref_output = torch.empty_like(query)

View File

@ -5,8 +5,6 @@ import pytest
import torch
from vllm import _custom_ops as ops
from vllm._C import cache_ops
from vllm.utils import is_hip
COPYING_DIRECTION = [('cuda', 'cpu'), ('cuda', 'cuda'), ('cpu', 'cuda')]
DTYPES = [torch.half, torch.bfloat16, torch.float]
@ -25,6 +23,8 @@ SEEDS = [0]
CUDA_DEVICES = [
f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
]
# We assume fp8 is always enabled for testing.
KV_CACHE_DTYPE = ["auto", "fp8"]
@ -124,8 +124,6 @@ def test_reshape_and_cache(
device: str,
kv_cache_dtype: str,
) -> None:
if not is_hip() and kv_cache_dtype == "fp8":
pytest.skip() # This test is not tuned for e5m2 cuda precision
random.seed(seed)
torch.random.manual_seed(seed)
if torch.cuda.is_available():
@ -149,9 +147,9 @@ def test_reshape_and_cache(
# Clone the KV caches.
if kv_cache_dtype == "fp8":
cloned_key_cache = torch.empty_like(key_cache, dtype=torch.float16)
ops.convert_fp8(key_cache, cloned_key_cache)
ops.convert_fp8(cloned_key_cache, key_cache)
cloned_value_cache = torch.empty_like(value_cache, dtype=torch.float16)
ops.convert_fp8(value_cache, cloned_value_cache)
ops.convert_fp8(cloned_value_cache, value_cache)
else:
cloned_key_cache = key_cache.clone()
cloned_value_cache = value_cache.clone()
@ -165,9 +163,9 @@ def test_reshape_and_cache(
if kv_cache_dtype == "fp8":
result_key_cache = torch.empty_like(key_cache, dtype=torch.float16)
ops.convert_fp8(key_cache, result_key_cache)
ops.convert_fp8(result_key_cache, key_cache)
result_value_cache = torch.empty_like(value_cache, dtype=torch.float16)
ops.convert_fp8(value_cache, result_value_cache)
ops.convert_fp8(result_value_cache, value_cache)
# Run the reference implementation.
reshaped_key = key.reshape(num_tokens, *key_cache[0, :, :, 0, :].shape)
@ -255,8 +253,8 @@ def test_reshape_and_cache_flash(
cloned_value_cache = value_cache.clone()
# Call the reshape_and_cache kernel.
cache_ops.reshape_and_cache_flash(key, value, key_cache, value_cache,
slot_mapping, kv_cache_dtype)
ops.reshape_and_cache_flash(key, value, key_cache, value_cache,
slot_mapping, kv_cache_dtype)
# Run the reference implementation.
block_indicies = torch.div(slot_mapping, block_size, rounding_mode='floor')
@ -299,8 +297,6 @@ def test_swap_blocks(
) -> None:
if kv_cache_dtype == "fp8" and "cpu" in direction:
pytest.skip()
if not is_hip() and kv_cache_dtype == "fp8":
pytest.skip() # This test is not tuned for e5m2 cuda precision
random.seed(seed)
torch.random.manual_seed(seed)
if torch.cuda.is_available():
@ -348,7 +344,6 @@ def test_swap_blocks(
dist_value_caches[0][dst].cpu())
@pytest.mark.skipif(not is_hip(), reason="FP8 conversion test requires e4m3")
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@ -357,7 +352,7 @@ def test_swap_blocks(
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_fp8_conversion(
def test_fp8_e4m3_conversion(
num_heads: int,
head_size: int,
block_size: int,
@ -377,9 +372,9 @@ def test_fp8_conversion(
cache.uniform_(low, high)
cache_fp8 = torch.empty_like(cache, dtype=torch.uint8)
ops.convert_fp8(cache, cache_fp8)
ops.convert_fp8(cache_fp8, cache)
converted_cache = torch.empty_like(cache)
ops.convert_fp8(cache_fp8, converted_cache)
ops.convert_fp8(converted_cache, cache_fp8)
assert torch.allclose(cache, converted_cache, atol=0.001, rtol=0.1)

View File

@ -270,8 +270,11 @@ def swap_blocks(src: torch.Tensor, dst: torch.Tensor,
vllm_cache_ops.swap_blocks(src, dst, block_mapping)
def convert_fp8(output: torch.Tensor, input: torch.Tensor) -> None:
vllm_cache_ops.convert_fp8(output, input)
def convert_fp8(output: torch.Tensor,
input: torch.Tensor,
scale: float = 1.0,
kv_dtype: str = "fp8") -> None:
vllm_cache_ops.convert_fp8(output, input, scale, kv_dtype)
#TODO: cuda_utils, custom_ar

View File

@ -329,7 +329,7 @@ def _generate_random_fp8(
from vllm import _custom_ops as ops
tensor_tmp = torch.empty_like(tensor, dtype=torch.float16)
tensor_tmp.uniform_(low, high)
ops.convert_fp8(tensor_tmp, tensor)
ops.convert_fp8(tensor, tensor_tmp)
del tensor_tmp