diff --git a/.gitignore b/.gitignore index b5195629..b1513ef0 100644 --- a/.gitignore +++ b/.gitignore @@ -181,6 +181,7 @@ _build/ # hip files generated by PyTorch *.hip *_hip* +hip_compat.h # Benchmark dataset *.json diff --git a/CMakeLists.txt b/CMakeLists.txt index 9d90f4e7..6d0cf730 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -19,7 +19,7 @@ set(PYTHON_SUPPORTED_VERSIONS "3.8" "3.9" "3.10" "3.11") set(CUDA_SUPPORTED_ARCHS "7.0;7.5;8.0;8.6;8.9;9.0") # Supported AMD GPU architectures. -set(HIP_SUPPORTED_ARCHS "gfx908;gfx90a;gfx942;gfx1100") +set(HIP_SUPPORTED_ARCHS "gfx906;gfx908;gfx90a;gfx940;gfx941;gfx942;gfx1030;gfx1100") # # Supported/expected torch versions for CUDA/ROCm. diff --git a/benchmarks/benchmark_latency.py b/benchmarks/benchmark_latency.py index da02493b..d4ebd200 100644 --- a/benchmarks/benchmark_latency.py +++ b/benchmarks/benchmark_latency.py @@ -24,6 +24,7 @@ def main(args: argparse.Namespace): dtype=args.dtype, enforce_eager=args.enforce_eager, kv_cache_dtype=args.kv_cache_dtype, + quantization_param_path=args.quantization_param_path, device=args.device, ray_workers_use_nsight=args.ray_workers_use_nsight, enable_chunked_prefill=args.enable_chunked_prefill, @@ -127,10 +128,23 @@ if __name__ == '__main__': parser.add_argument( "--kv-cache-dtype", type=str, - choices=['auto', 'fp8_e5m2'], + choices=['auto', 'fp8'], default='auto', help= - 'Data type for kv cache storage. If "auto", will use model data type.') + 'Data type for kv cache storage. If "auto", will use model data type. ' + 'FP8_E5M2 (without scaling) is only supported on cuda version greater ' + 'than 11.8. On ROCm (AMD GPU), FP8_E4M3 is instead supported for ' + 'common inference criteria.') + parser.add_argument( + '--quantization-param-path', + type=str, + default=None, + help='Path to the JSON file containing the KV cache scaling factors. ' + 'This should generally be supplied, when KV cache dtype is FP8. ' + 'Otherwise, KV cache scaling factors default to 1.0, which may cause ' + 'accuracy issues. FP8_E5M2 (without scaling) is only supported on ' + 'cuda version greater than 11.8. On ROCm (AMD GPU), FP8_E4M3 is ' + 'instead supported for common inference criteria.') parser.add_argument( '--profile', action='store_true', diff --git a/benchmarks/benchmark_throughput.py b/benchmarks/benchmark_throughput.py index 9d84bde1..d3e06cca 100644 --- a/benchmarks/benchmark_throughput.py +++ b/benchmarks/benchmark_throughput.py @@ -72,6 +72,7 @@ def run_vllm( max_model_len: Optional[int], enforce_eager: bool, kv_cache_dtype: str, + quantization_param_path: Optional[str], device: str, enable_prefix_caching: bool, gpu_memory_utilization: float = 0.9, @@ -89,6 +90,7 @@ def run_vllm( gpu_memory_utilization=gpu_memory_utilization, enforce_eager=enforce_eager, kv_cache_dtype=kv_cache_dtype, + quantization_param_path=quantization_param_path, device=device, enable_prefix_caching=enable_prefix_caching, download_dir=download_dir) @@ -217,7 +219,8 @@ def main(args: argparse.Namespace): args.seed, args.n, args.use_beam_search, args.trust_remote_code, args.dtype, args.max_model_len, args.enforce_eager, - args.kv_cache_dtype, args.device, + args.kv_cache_dtype, + args.quantization_param_path, args.device, args.enable_prefix_caching, args.gpu_memory_utilization, args.download_dir) elif args.backend == "hf": @@ -306,10 +309,23 @@ if __name__ == "__main__": parser.add_argument( "--kv-cache-dtype", type=str, - choices=["auto", "fp8_e5m2"], + choices=["auto", "fp8"], default="auto", help= - 'Data type for kv cache storage. 
If "auto", will use model data type.') + 'Data type for kv cache storage. If "auto", will use model data type. ' + 'FP8_E5M2 (without scaling) is only supported on cuda version greater ' + 'than 11.8. On ROCm (AMD GPU), FP8_E4M3 is instead supported for ' + 'common inference criteria.') + parser.add_argument( + '--quantization-param-path', + type=str, + default=None, + help='Path to the JSON file containing the KV cache scaling factors. ' + 'This should generally be supplied, when KV cache dtype is FP8. ' + 'Otherwise, KV cache scaling factors default to 1.0, which may cause ' + 'accuracy issues. FP8_E5M2 (without scaling) is only supported on ' + 'cuda version greater than 11.8. On ROCm (AMD GPU), FP8_E4M3 is ' + 'instead supported for common inference criteria.') parser.add_argument( "--device", type=str, diff --git a/benchmarks/kernels/benchmark_paged_attention.py b/benchmarks/kernels/benchmark_paged_attention.py index f6c8f900..f71d1fca 100644 --- a/benchmarks/kernels/benchmark_paged_attention.py +++ b/benchmarks/kernels/benchmark_paged_attention.py @@ -97,6 +97,9 @@ def main( torch.cuda.cudart().cudaProfilerStart() start_time = time.perf_counter() + # Using default kv_scale + kv_scale = 1.0 + for _ in range(num_iters): if version == "v1": ops.paged_attention_v1( @@ -112,6 +115,7 @@ def main( max_context_len, alibi_slopes, kv_cache_dtype, + kv_scale, ) elif version == "v2": ops.paged_attention_v2( @@ -130,6 +134,7 @@ def main( max_context_len, alibi_slopes, kv_cache_dtype, + kv_scale, ) else: raise ValueError(f"Invalid version: {version}") @@ -179,11 +184,13 @@ if __name__ == '__main__': parser.add_argument( "--kv-cache-dtype", type=str, - choices=["auto", "fp8_e5m2"], + choices=["auto", "fp8"], default="auto", help= - 'Data type for kv cache storage. If "auto", will use model data type.') - parser.add_argument("--device", type=str, choices=["cuda"], default="cuda") + 'Data type for kv cache storage. If "auto", will use model data type. ' + 'FP8_E5M2 (without scaling) is only supported on cuda version greater ' + 'than 11.8. 
On ROCm (AMD GPU), FP8_E4M3 is instead supported for ' + 'common inference criteria.') args = parser.parse_args() print(args) diff --git a/cmake/utils.cmake b/cmake/utils.cmake index c7d3d853..4cb8a69f 100644 --- a/cmake/utils.cmake +++ b/cmake/utils.cmake @@ -117,6 +117,7 @@ function (get_torch_gpu_compiler_flags OUT_GPU_FLAGS GPU_LANG) list(APPEND GPU_FLAGS "-DUSE_ROCM" + "-DENABLE_FP8_E4M3" "-U__HIP_NO_HALF_CONVERSIONS__" "-U__HIP_NO_HALF_OPERATORS__" "-fno-gpu-rdc") diff --git a/csrc/attention/attention_dtypes.h b/csrc/attention/attention_dtypes.h index 61748e6b..64f86381 100644 --- a/csrc/attention/attention_dtypes.h +++ b/csrc/attention/attention_dtypes.h @@ -4,4 +4,4 @@ #include "dtype_float16.cuh" #include "dtype_float32.cuh" #include "dtype_bfloat16.cuh" -#include "dtype_fp8_e5m2.cuh" +#include "dtype_fp8.cuh" diff --git a/csrc/attention/attention_kernels.cu b/csrc/attention/attention_kernels.cu index 5e61668d..f3a5bbfd 100644 --- a/csrc/attention/attention_kernels.cu +++ b/csrc/attention/attention_kernels.cu @@ -22,12 +22,26 @@ #include "attention_dtypes.h" #include "attention_utils.cuh" -#ifdef ENABLE_FP8_E5M2 + +#if defined(ENABLE_FP8_E5M2) #include "../quantization/fp8_e5m2_kvcache/quant_utils.cuh" +#elif defined(ENABLE_FP8_E4M3) +#include "../quantization/fp8/amd_detail/quant_utils.cuh" #endif #include +#ifdef USE_ROCM + #include + typedef __hip_bfloat16 __nv_bfloat16; +#endif + +#ifndef USE_ROCM +#define WARP_SIZE 32 +#else +#define WARP_SIZE warpSize +#endif + #define MAX(a, b) ((a) > (b) ? (a) : (b)) #define MIN(a, b) ((a) < (b) ? (a) : (b)) #define DIVIDE_ROUND_UP(a, b) (((a) + (b) - 1) / (b)) @@ -78,7 +92,7 @@ template< int HEAD_SIZE, int BLOCK_SIZE, int NUM_THREADS, - bool IS_FP8_E5M2_KV_CACHE, + bool IS_FP8_KV_CACHE, int PARTITION_SIZE = 0> // Zero means no partitioning. __device__ void paged_attention_kernel( float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions] @@ -95,7 +109,8 @@ __device__ void paged_attention_kernel( const float* __restrict__ alibi_slopes, // [num_heads] const int q_stride, const int kv_block_stride, - const int kv_head_stride) { + const int kv_head_stride, + const float kv_scale) { const int seq_idx = blockIdx.y; const int partition_idx = blockIdx.z; const int max_num_partitions = gridDim.z; @@ -142,7 +157,7 @@ __device__ void paged_attention_kernel( constexpr int VEC_SIZE = MAX(16 / (THREAD_GROUP_SIZE * sizeof(scalar_t)), 1); using K_vec = typename Vec::Type; using Q_vec = typename Vec::Type; -#ifdef ENABLE_FP8_E5M2 +#if defined(ENABLE_FP8_E5M2) || defined(ENABLE_FP8_E4M3) using Quant_vec = typename Vec::Type; #endif @@ -208,11 +223,16 @@ __device__ void paged_attention_kernel( const int vec_idx = thread_group_offset + j * THREAD_GROUP_SIZE; const int offset1 = (vec_idx * VEC_SIZE) / x; const int offset2 = (vec_idx * VEC_SIZE) % x; - if constexpr (IS_FP8_E5M2_KV_CACHE) { -#ifdef ENABLE_FP8_E5M2 + if constexpr (IS_FP8_KV_CACHE) { +#if defined(ENABLE_FP8_E5M2) Quant_vec k_vec_quant = *reinterpret_cast(k_ptr + offset1 * BLOCK_SIZE * x + offset2); // Vector conversion from Quant_vec to K_vec. k_vecs[j] = fp8_e5m2_unscaled::vec_conversion(k_vec_quant); +#elif defined(ENABLE_FP8_E4M3) + Quant_vec k_vec_quant = *reinterpret_cast(k_ptr + offset1 * BLOCK_SIZE * x + offset2); + // Vector conversion from Quant_vec to K_vec. Use scaled_vec_conversion to convert FP8_E4M3 quantized k + // cache vec to k vec in higher precision (FP16, BFloat16, etc.) 
+ k_vecs[j] = fp8_e4m3::scaled_vec_conversion(k_vec_quant, kv_scale); #else assert(false); #endif @@ -292,7 +312,7 @@ __device__ void paged_attention_kernel( constexpr int V_VEC_SIZE = MIN(16 / sizeof(scalar_t), BLOCK_SIZE); using V_vec = typename Vec::Type; using L_vec = typename Vec::Type; -#ifdef ENABLE_FP8_E5M2 +#if defined(ENABLE_FP8_E5M2) || defined(ENABLE_FP8_E4M3) using V_quant_vec = typename Vec::Type; #endif using Float_L_vec = typename FloatVec::Type; @@ -328,11 +348,16 @@ __device__ void paged_attention_kernel( if (row_idx < HEAD_SIZE) { const int offset = row_idx * BLOCK_SIZE + physical_block_offset; V_vec v_vec; - if constexpr (IS_FP8_E5M2_KV_CACHE) { -#ifdef ENABLE_FP8_E5M2 + if constexpr (IS_FP8_KV_CACHE) { +#if defined(ENABLE_FP8_E5M2) V_quant_vec v_quant_vec = *reinterpret_cast(v_ptr + offset); // Vector conversion from V_quant_vec to V_vec. v_vec = fp8_e5m2_unscaled::vec_conversion(v_quant_vec); +#elif defined(ENABLE_FP8_E4M3) + V_quant_vec v_quant_vec = *reinterpret_cast(v_ptr + offset); + // Vector conversion from V_quant_vec to V_vec. Use scaled_vec_conversion to convert + // FP8_E4M3 quantized v cache vec to v vec in higher precision (FP16, BFloat16, etc.) + v_vec = fp8_e4m3::scaled_vec_conversion(v_quant_vec, kv_scale); #else assert(false); #endif @@ -423,7 +448,7 @@ template< int HEAD_SIZE, int BLOCK_SIZE, int NUM_THREADS, - bool IS_FP8_E5M2_KV_CACHE> + bool IS_FP8_KV_CACHE> __global__ void paged_attention_v1_kernel( scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size] const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] @@ -437,11 +462,12 @@ __global__ void paged_attention_v1_kernel( const float* __restrict__ alibi_slopes, // [num_heads] const int q_stride, const int kv_block_stride, - const int kv_head_stride) { - paged_attention_kernel( + const int kv_head_stride, + const float kv_scale) { + paged_attention_kernel( /* exp_sums */ nullptr, /* max_logits */ nullptr, out, q, k_cache, v_cache, num_kv_heads, scale, block_tables, context_lens, - max_num_blocks_per_seq, alibi_slopes, q_stride, kv_block_stride, kv_head_stride); + max_num_blocks_per_seq, alibi_slopes, q_stride, kv_block_stride, kv_head_stride, kv_scale); } // Grid: (num_heads, num_seqs, max_num_partitions). @@ -451,7 +477,7 @@ template< int HEAD_SIZE, int BLOCK_SIZE, int NUM_THREADS, - bool IS_FP8_E5M2_KV_CACHE, + bool IS_FP8_KV_CACHE, int PARTITION_SIZE> __global__ void paged_attention_v2_kernel( float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions] @@ -468,11 +494,12 @@ __global__ void paged_attention_v2_kernel( const float* __restrict__ alibi_slopes, // [num_heads] const int q_stride, const int kv_block_stride, - const int kv_head_stride) { - paged_attention_kernel( + const int kv_head_stride, + const float kv_scale) { + paged_attention_kernel( exp_sums, max_logits, tmp_out, q, k_cache, v_cache, num_kv_heads, scale, block_tables, context_lens, max_num_blocks_per_seq, alibi_slopes, - q_stride, kv_block_stride, kv_head_stride); + q_stride, kv_block_stride, kv_head_stride, kv_scale); } // Grid: (num_heads, num_seqs). 
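For readers following the kernel changes above: `fp8_e4m3::scaled_vec_conversion` performs a per-tensor scaled dequantization, i.e. each FP8 byte from the KV cache is decoded to the compute precision and multiplied by `kv_scale`. Below is a minimal NumPy sketch of that semantics. The `decode_e4m3fnuz` helper is hypothetical and models the "fnuz" E4M3 variant that the new `hip_float8_impl.h` code targets (no infinities, `0x80` is NaN, exponent bias 8, largest finite value 240.0, as suggested by the ±240 clamp there); it is an illustration, not the kernel's actual code path.

```python
import numpy as np

def decode_e4m3fnuz(byte: int) -> float:
    """Decode one FP8 E4M3 byte in the 'fnuz' encoding (bias 8, 0x80 = NaN, no inf)."""
    if byte == 0x80:
        return float("nan")
    sign = -1.0 if byte & 0x80 else 1.0
    exp = (byte >> 3) & 0xF
    man = byte & 0x7
    if exp == 0:                                  # subnormal: 2**-7 * (mantissa / 8)
        return sign * (man / 8.0) * 2.0 ** -7
    return sign * (1.0 + man / 8.0) * 2.0 ** (exp - 8)

def scaled_dequant(fp8_bytes: np.ndarray, kv_scale: float) -> np.ndarray:
    # Mirrors the kernel: decode each byte, then multiply by the per-tensor scale.
    return np.array([decode_e4m3fnuz(int(b)) for b in fp8_bytes],
                    dtype=np.float32) * kv_scale

print(decode_e4m3fnuz(0x7F))                                    # 240.0, largest finite value
print(scaled_dequant(np.array([0x38, 0x44], dtype=np.uint8), kv_scale=2.0))   # [1.0, 3.0]
```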
@@ -579,9 +606,9 @@ __global__ void paged_attention_v2_reduce_kernel( #define LAUNCH_PAGED_ATTENTION_V1(HEAD_SIZE) \ VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize( \ ((void*)vllm::paged_attention_v1_kernel), shared_mem_size); \ + IS_FP8_KV_CACHE>), shared_mem_size); \ vllm::paged_attention_v1_kernel<<>>( \ + IS_FP8_KV_CACHE><<>>( \ out_ptr, \ query_ptr, \ key_cache_ptr, \ @@ -594,14 +621,15 @@ __global__ void paged_attention_v2_reduce_kernel( alibi_slopes_ptr, \ q_stride, \ kv_block_stride, \ - kv_head_stride); + kv_head_stride, \ + kv_scale); // TODO(woosuk): Tune NUM_THREADS. template< typename T, typename CACHE_T, int BLOCK_SIZE, - bool IS_FP8_E5M2_KV_CACHE, + bool IS_FP8_KV_CACHE, int NUM_THREADS = 128> void paged_attention_v1_launcher( torch::Tensor& out, @@ -613,7 +641,8 @@ void paged_attention_v1_launcher( torch::Tensor& block_tables, torch::Tensor& context_lens, int max_context_len, - const c10::optional& alibi_slopes) { + const c10::optional& alibi_slopes, + float kv_scale) { int num_seqs = query.size(0); int num_heads = query.size(1); int head_size = query.size(2); @@ -677,8 +706,8 @@ void paged_attention_v1_launcher( } } -#define CALL_V1_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_E5M2_KV_CACHE) \ - paged_attention_v1_launcher( \ +#define CALL_V1_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE) \ + paged_attention_v1_launcher( \ out, \ query, \ key_cache, \ @@ -688,20 +717,21 @@ void paged_attention_v1_launcher( block_tables, \ context_lens, \ max_context_len, \ - alibi_slopes); + alibi_slopes, \ + kv_scale); // NOTE(woosuk): To reduce the compilation time, we omitted block sizes // 1, 2, 4, 64, 128, 256. -#define CALL_V1_LAUNCHER_BLOCK_SIZE(T, CACHE_T, IS_FP8_E5M2_KV_CACHE) \ +#define CALL_V1_LAUNCHER_BLOCK_SIZE(T, CACHE_T, IS_FP8_KV_CACHE) \ switch (block_size) { \ case 8: \ - CALL_V1_LAUNCHER(T, CACHE_T, 8, IS_FP8_E5M2_KV_CACHE); \ + CALL_V1_LAUNCHER(T, CACHE_T, 8, IS_FP8_KV_CACHE); \ break; \ case 16: \ - CALL_V1_LAUNCHER(T, CACHE_T, 16, IS_FP8_E5M2_KV_CACHE); \ + CALL_V1_LAUNCHER(T, CACHE_T, 16, IS_FP8_KV_CACHE); \ break; \ case 32: \ - CALL_V1_LAUNCHER(T, CACHE_T, 32, IS_FP8_E5M2_KV_CACHE); \ + CALL_V1_LAUNCHER(T, CACHE_T, 32, IS_FP8_KV_CACHE); \ break; \ default: \ TORCH_CHECK(false, "Unsupported block size: ", block_size); \ @@ -720,7 +750,8 @@ void paged_attention_v1( int block_size, int max_context_len, const c10::optional& alibi_slopes, - const std::string& kv_cache_dtype) { + const std::string& kv_cache_dtype, + float kv_scale) { if (kv_cache_dtype == "auto") { if (query.dtype() == at::ScalarType::Float) { CALL_V1_LAUNCHER_BLOCK_SIZE(float, float, false); @@ -731,7 +762,7 @@ void paged_attention_v1( } else { TORCH_CHECK(false, "Unsupported data type: ", query.dtype()); } - } else if (kv_cache_dtype == "fp8_e5m2") { + } else if (kv_cache_dtype == "fp8") { if (query.dtype() == at::ScalarType::Float) { CALL_V1_LAUNCHER_BLOCK_SIZE(float, uint8_t, true); } else if (query.dtype() == at::ScalarType::Half) { @@ -748,7 +779,7 @@ void paged_attention_v1( #define LAUNCH_PAGED_ATTENTION_V2(HEAD_SIZE) \ vllm::paged_attention_v2_kernel \ + IS_FP8_KV_CACHE, PARTITION_SIZE> \ <<>>( \ exp_sums_ptr, \ max_logits_ptr, \ @@ -764,7 +795,8 @@ void paged_attention_v1( alibi_slopes_ptr, \ q_stride, \ kv_block_stride, \ - kv_head_stride); \ + kv_head_stride, \ + kv_scale); \ vllm::paged_attention_v2_reduce_kernel \ <<>>( \ out_ptr, \ @@ -778,7 +810,7 @@ template< typename T, typename CACHE_T, int BLOCK_SIZE, - bool IS_FP8_E5M2_KV_CACHE, + bool IS_FP8_KV_CACHE, int NUM_THREADS = 128, int 
PARTITION_SIZE = 512> void paged_attention_v2_launcher( @@ -794,7 +826,8 @@ void paged_attention_v2_launcher( torch::Tensor& block_tables, torch::Tensor& context_lens, int max_context_len, - const c10::optional& alibi_slopes) { + const c10::optional& alibi_slopes, + float kv_scale) { int num_seqs = query.size(0); int num_heads = query.size(1); int head_size = query.size(2); @@ -864,8 +897,8 @@ void paged_attention_v2_launcher( } } -#define CALL_V2_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_E5M2_KV_CACHE) \ - paged_attention_v2_launcher( \ +#define CALL_V2_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE) \ + paged_attention_v2_launcher( \ out, \ exp_sums, \ max_logits, \ @@ -878,20 +911,21 @@ void paged_attention_v2_launcher( block_tables, \ context_lens, \ max_context_len, \ - alibi_slopes); + alibi_slopes, \ + kv_scale); // NOTE(woosuk): To reduce the compilation time, we omitted block sizes // 1, 2, 4, 64, 128, 256. -#define CALL_V2_LAUNCHER_BLOCK_SIZE(T, CACHE_T, IS_FP8_E5M2_KV_CACHE) \ +#define CALL_V2_LAUNCHER_BLOCK_SIZE(T, CACHE_T, IS_FP8_KV_CACHE) \ switch (block_size) { \ case 8: \ - CALL_V2_LAUNCHER(T, CACHE_T, 8, IS_FP8_E5M2_KV_CACHE); \ + CALL_V2_LAUNCHER(T, CACHE_T, 8, IS_FP8_KV_CACHE); \ break; \ case 16: \ - CALL_V2_LAUNCHER(T, CACHE_T, 16, IS_FP8_E5M2_KV_CACHE); \ + CALL_V2_LAUNCHER(T, CACHE_T, 16, IS_FP8_KV_CACHE); \ break; \ case 32: \ - CALL_V2_LAUNCHER(T, CACHE_T, 32, IS_FP8_E5M2_KV_CACHE); \ + CALL_V2_LAUNCHER(T, CACHE_T, 32, IS_FP8_KV_CACHE); \ break; \ default: \ TORCH_CHECK(false, "Unsupported block size: ", block_size); \ @@ -913,7 +947,8 @@ void paged_attention_v2( int block_size, int max_context_len, const c10::optional& alibi_slopes, - const std::string& kv_cache_dtype) { + const std::string& kv_cache_dtype, + float kv_scale) { if (kv_cache_dtype == "auto") { if (query.dtype() == at::ScalarType::Float) { CALL_V2_LAUNCHER_BLOCK_SIZE(float, float, false); @@ -924,7 +959,7 @@ void paged_attention_v2( } else { TORCH_CHECK(false, "Unsupported data type: ", query.dtype()); } - } else if (kv_cache_dtype == "fp8_e5m2") { + } else if (kv_cache_dtype == "fp8") { if (query.dtype() == at::ScalarType::Float) { CALL_V2_LAUNCHER_BLOCK_SIZE(float, uint8_t, true); } else if (query.dtype() == at::ScalarType::Half) { diff --git a/csrc/attention/dtype_fp8_e5m2.cuh b/csrc/attention/dtype_fp8.cuh similarity index 89% rename from csrc/attention/dtype_fp8_e5m2.cuh rename to csrc/attention/dtype_fp8.cuh index 0580fbb8..d11dee91 100644 --- a/csrc/attention/dtype_fp8_e5m2.cuh +++ b/csrc/attention/dtype_fp8.cuh @@ -8,7 +8,7 @@ #endif namespace vllm { -#ifdef ENABLE_FP8_E5M2 +#if defined(ENABLE_FP8_E5M2) || defined(ENABLE_FP8_E4M3) // fp8 vector types for quantization of kv cache template<> diff --git a/csrc/cache.h b/csrc/cache.h index 765e231a..718a5f6c 100644 --- a/csrc/cache.h +++ b/csrc/cache.h @@ -21,9 +21,10 @@ void reshape_and_cache( torch::Tensor& key_cache, torch::Tensor& value_cache, torch::Tensor& slot_mapping, - const std::string& kv_cache_dtype); + const std::string& kv_cache_dtype, + const float kv_scale); // Just for unittest -void convert_fp8_e5m2( +void convert_fp8( torch::Tensor& src_cache, torch::Tensor& dst_cache); diff --git a/csrc/cache_kernels.cu b/csrc/cache_kernels.cu index 7254010b..24aaa2ff 100644 --- a/csrc/cache_kernels.cu +++ b/csrc/cache_kernels.cu @@ -4,8 +4,10 @@ #include "cuda_compat.h" #include "dispatch_utils.h" -#ifdef ENABLE_FP8_E5M2 +#if defined(ENABLE_FP8_E5M2) #include "quantization/fp8_e5m2_kvcache/quant_utils.cuh" +#elif defined(ENABLE_FP8_E4M3) 
+#include "quantization/fp8/amd_detail/quant_utils.cuh" #endif #include @@ -151,7 +153,7 @@ void copy_blocks( namespace vllm { -template +template __global__ void reshape_and_cache_kernel( const scalar_t* __restrict__ key, // [num_tokens, num_heads, head_size] const scalar_t* __restrict__ value, // [num_tokens, num_heads, head_size] @@ -163,7 +165,8 @@ __global__ void reshape_and_cache_kernel( const int num_heads, const int head_size, const int block_size, - const int x) { + const int x, + const float kv_scale) { const int64_t token_idx = blockIdx.x; const int64_t slot_idx = slot_mapping[token_idx]; if (slot_idx < 0) { @@ -195,10 +198,13 @@ __global__ void reshape_and_cache_kernel( + block_offset; scalar_t tgt_key = key[src_key_idx]; scalar_t tgt_value = value[src_value_idx]; - if constexpr (is_fp8_e5m2_kv_cache) { -#ifdef ENABLE_FP8_E5M2 + if constexpr (is_fp8_kv_cache) { +#if defined(ENABLE_FP8_E5M2) key_cache[tgt_key_idx] = fp8_e5m2_unscaled::vec_conversion(tgt_key); value_cache[tgt_value_idx] = fp8_e5m2_unscaled::vec_conversion(tgt_value); +#elif defined(ENABLE_FP8_E4M3) + key_cache[tgt_key_idx] = fp8_e4m3::scaled_vec_conversion(tgt_key, kv_scale); + value_cache[tgt_value_idx] = fp8_e4m3::scaled_vec_conversion(tgt_value, kv_scale); #else assert(false); #endif @@ -211,8 +217,8 @@ __global__ void reshape_and_cache_kernel( } // namespace vllm -#define CALL_RESHAPE_AND_CACHE(KV_T, CACHE_T, IS_FP8_E5M2_KV_CACHE) \ - vllm::reshape_and_cache_kernel<<>>( \ +#define CALL_RESHAPE_AND_CACHE(KV_T, CACHE_T, IS_FP8_KV_CACHE) \ + vllm::reshape_and_cache_kernel<<>>( \ reinterpret_cast(key.data_ptr()), \ reinterpret_cast(value.data_ptr()), \ reinterpret_cast(key_cache.data_ptr()), \ @@ -223,7 +229,8 @@ __global__ void reshape_and_cache_kernel( num_heads, \ head_size, \ block_size, \ - x); + x, \ + kv_scale); void reshape_and_cache( torch::Tensor& key, // [num_tokens, num_heads, head_size] @@ -231,7 +238,8 @@ void reshape_and_cache( torch::Tensor& key_cache, // [num_blocks, num_heads, head_size/x, block_size, x] torch::Tensor& value_cache, // [num_blocks, num_heads, head_size, block_size] torch::Tensor& slot_mapping, // [num_tokens] - const std::string& kv_cache_dtype) + const std::string& kv_cache_dtype, + const float kv_scale) { int num_tokens = key.size(0); int num_heads = key.size(1); @@ -254,7 +262,7 @@ void reshape_and_cache( } else if (key.dtype() == at::ScalarType::BFloat16) { CALL_RESHAPE_AND_CACHE(__nv_bfloat16, __nv_bfloat16, false); } - } else if (kv_cache_dtype == "fp8_e5m2") { + } else if (kv_cache_dtype == "fp8") { if (key.dtype() == at::ScalarType::Float) { CALL_RESHAPE_AND_CACHE(float, uint8_t, true); } else if (key.dtype() == at::ScalarType::Half) { @@ -270,15 +278,17 @@ void reshape_and_cache( namespace vllm { template -__global__ void convert_fp8_e5m2_kernel( +__global__ void convert_fp8_kernel( const Tin* __restrict__ src_cache, Tout* __restrict__ dst_cache, const int64_t block_stride) { const int64_t block_idx = blockIdx.x; for (int i = threadIdx.x; i < block_stride; i += blockDim.x) { int64_t idx = block_idx * block_stride + i; -#ifdef ENABLE_FP8_E5M2 +#if defined(ENABLE_FP8_E5M2) dst_cache[idx] = fp8_e5m2_unscaled::vec_conversion(src_cache[idx]); +#elif defined(ENABLE_FP8_E4M3) + dst_cache[idx] = fp8_e4m3::vec_conversion(src_cache[idx]); #else assert(false); #endif @@ -287,16 +297,25 @@ __global__ void convert_fp8_e5m2_kernel( } // namespace vllm -#define CALL_CONVERT_FP8_E5M2(Tout, Tin) \ - vllm::convert_fp8_e5m2_kernel<<>>( \ - reinterpret_cast(src_cache.data_ptr()), \ - 
reinterpret_cast(dst_cache.data_ptr()), \ +#define CALL_CONVERT_FP8(Tout, Tin) \ + vllm::convert_fp8_kernel<<>>( \ + reinterpret_cast(src_cache.data_ptr()), \ + reinterpret_cast(dst_cache.data_ptr()), \ block_stride); -void convert_fp8_e5m2( +void convert_fp8( torch::Tensor& src_cache, torch::Tensor& dst_cache) { + torch::Device src_device = src_cache.device(); + torch::Device dst_device = dst_cache.device(); + TORCH_CHECK(src_device.is_cuda(), "src must be on a GPU") + TORCH_CHECK(dst_device.is_cuda(), "dst must be on a GPU") + TORCH_CHECK( + src_device.index() == dst_device.index(), + "src and dst must be on the same GPU"); + at::cuda::OptionalCUDAGuard device_guard(src_device); + int64_t num_blocks = src_cache.size(0); int64_t block_stride = src_cache.stride(0); @@ -305,16 +324,16 @@ void convert_fp8_e5m2( const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); if (src_cache.dtype() == at::ScalarType::Float) { - CALL_CONVERT_FP8_E5M2(uint8_t, float); + CALL_CONVERT_FP8(uint8_t, float); } else if (src_cache.dtype() == at::ScalarType::Half) { - CALL_CONVERT_FP8_E5M2(uint8_t, uint16_t); + CALL_CONVERT_FP8(uint8_t, uint16_t); } else if (src_cache.dtype() == at::ScalarType::BFloat16) { - CALL_CONVERT_FP8_E5M2(uint8_t, __nv_bfloat16); + CALL_CONVERT_FP8(uint8_t, __nv_bfloat16); } else if (dst_cache.dtype() == at::ScalarType::Float) { - CALL_CONVERT_FP8_E5M2(float, uint8_t); + CALL_CONVERT_FP8(float, uint8_t); } else if (dst_cache.dtype() == at::ScalarType::Half) { - CALL_CONVERT_FP8_E5M2(uint16_t, uint8_t); + CALL_CONVERT_FP8(uint16_t, uint8_t); } else if (dst_cache.dtype() == at::ScalarType::BFloat16) { - CALL_CONVERT_FP8_E5M2(__nv_bfloat16, uint8_t); + CALL_CONVERT_FP8(__nv_bfloat16, uint8_t); } } diff --git a/csrc/ops.h b/csrc/ops.h index d5d6e240..41ecc1e8 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -14,7 +14,8 @@ void paged_attention_v1( int block_size, int max_context_len, const c10::optional& alibi_slopes, - const std::string& kv_cache_dtype); + const std::string& kv_cache_dtype, + float kv_scale); void paged_attention_v2( torch::Tensor& out, @@ -31,7 +32,8 @@ void paged_attention_v2( int block_size, int max_context_len, const c10::optional& alibi_slopes, - const std::string& kv_cache_dtype); + const std::string& kv_cache_dtype, + float kv_scale); void rms_norm( torch::Tensor& out, diff --git a/csrc/pybind.cpp b/csrc/pybind.cpp index a5c6439f..de02afc1 100644 --- a/csrc/pybind.cpp +++ b/csrc/pybind.cpp @@ -91,9 +91,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { &reshape_and_cache, "Reshape the key and value tensors and cache them"); cache_ops.def( - "convert_fp8_e5m2", - &convert_fp8_e5m2, - "Convert the key and value cache to fp8_e5m2 data type"); + "convert_fp8", + &convert_fp8, + "Convert the key and value cache to fp8 data type"); // Cuda utils pybind11::module cuda_utils = m.def_submodule("cuda_utils", "vLLM cuda utils"); diff --git a/csrc/quantization/fp8/amd_detail/hip_float8.h b/csrc/quantization/fp8/amd_detail/hip_float8.h new file mode 100644 index 00000000..87c7c9ce --- /dev/null +++ b/csrc/quantization/fp8/amd_detail/hip_float8.h @@ -0,0 +1,167 @@ +#pragma once + +#ifdef __HIPCC__ +#include +#else +#include +#include +#include +#include +#endif + +#include "hip_float8_impl.h" + +struct alignas(1) hip_fp8 +{ + struct from_bits_t + { + }; + HIP_FP8_HOST_DEVICE static constexpr from_bits_t from_bits() { return from_bits_t(); } + uint8_t data; + + hip_fp8() = default; + HIP_FP8_HOST_DEVICE constexpr hip_fp8(const hip_fp8&) = default; + HIP_FP8_HOST_DEVICE constexpr 
hip_fp8(uint8_t v) = delete; + explicit HIP_FP8_HOST_DEVICE constexpr hip_fp8(uint8_t v, from_bits_t) + : data(v) + { + } + +#ifdef __HIP__MI300__ + // NOTE: ON-DEVICE... always optimal bias + explicit HIP_FP8_DEVICE hip_fp8(float v) + : data(hip_fp8_impl::to_fp8_from_fp32(v)) + { + } + + explicit HIP_FP8_DEVICE hip_fp8(_Float16 v) + : hip_fp8(static_cast(v)) + { + } + + // Host only implementation using s/w simulation + explicit HIP_FP8_HOST +#else // __HIP__MI300__ + // both Host and DEVICE for non-MI300 using s/w simulation + explicit HIP_FP8_HOST_DEVICE +#endif // __HIP__MI300__ + hip_fp8(float v) + { + data = hip_fp8_impl::to_float8<4, 3, float, true /*negative_zero_nan*/, true /*clip*/>(v); + } + + explicit HIP_FP8_HOST_DEVICE hip_fp8(double v) + : hip_fp8(static_cast(v)) + { + } + +#ifdef __HIP__MI300__ + // upcast using device specific intrinsic + explicit inline HIP_FP8_DEVICE operator float() const + { + float fval; + uint32_t i32val = static_cast(data); + + // upcast + asm volatile("v_cvt_f32_fp8 %0, %1 src0_sel:BYTE_0" : "=v"(fval) : "v"(i32val)); + + return fval; + } + + explicit inline HIP_FP8_HOST operator float() const +#else // __HIP__MI300__ + explicit inline HIP_FP8_HOST_DEVICE operator float() const +#endif // __HIP__MI300__ + { + return hip_fp8_impl::from_float8<4, 3, float, true /*negative_zero_nan*/>(data); + } +}; + +namespace std +{ +inline hip_fp8 sin(hip_fp8 a) +{ + return hip_fp8(sinf(float(a))); +} +inline hip_fp8 cos(hip_fp8 a) +{ + return hip_fp8(cosf(float(a))); +} +HIP_FP8_HOST_DEVICE constexpr hip_fp8 real(const hip_fp8& a) +{ + return a; +} +} // namespace std + +// Special operator overloading +inline std::ostream& operator<<(std::ostream& os, const hip_fp8& f8) +{ + return os << float(f8); +} + +// all + operator overloading with mixed types +// mixed types, always converts to f32, does computation in f32, and returns float +inline HIP_FP8_HOST_DEVICE float operator+(const float fa, hip_fp8 b) +{ + return (fa + float(b)); +} + +inline HIP_FP8_HOST_DEVICE float operator+(hip_fp8 a, const float fb) +{ + return (float(a) + fb); +} + +inline HIP_FP8_HOST_DEVICE hip_fp8 operator+(hip_fp8 a, hip_fp8 b) +{ + return hip_fp8(float(a) + float(b)); +} + +inline HIP_FP8_HOST_DEVICE hip_fp8& operator+=(hip_fp8& a, hip_fp8 b) +{ + return a = hip_fp8(float(a) + float(b)); +} + +// overloading multiplication, always returns float, +inline HIP_FP8_HOST_DEVICE float operator*(hip_fp8 a, hip_fp8 b) +{ + return float(a) * float(b); +} + +inline HIP_FP8_HOST_DEVICE float operator*(float a, hip_fp8 b) +{ + return (a * float(b)); +} + +inline HIP_FP8_HOST_DEVICE float operator*(hip_fp8 a, float b) +{ + return (float(a) * b); +} + +inline HIP_FP8_HOST_DEVICE float operator*(int32_t a, hip_fp8 b) +{ + return ((float)a * float(b)); +} + +inline HIP_FP8_HOST_DEVICE float operator*(double a, hip_fp8 b) +{ + return ((float)a * float(b)); +} + +// overloading for compare +inline HIP_FP8_HOST_DEVICE bool operator==(hip_fp8 a, hip_fp8 b) +{ + return (a.data == b.data); +} +inline HIP_FP8_HOST_DEVICE bool operator!=(hip_fp8 a, hip_fp8 b) +{ + return (a.data != b.data); +} + +inline HIP_FP8_HOST_DEVICE bool operator>=(hip_fp8 a, hip_fp8 b) +{ + return static_cast(a) >= static_cast(b); +} +inline HIP_FP8_HOST_DEVICE bool operator>(hip_fp8 a, hip_fp8 b) +{ + return static_cast(a) > static_cast(b); +} diff --git a/csrc/quantization/fp8/amd_detail/hip_float8_impl.h b/csrc/quantization/fp8/amd_detail/hip_float8_impl.h new file mode 100644 index 00000000..e05905b4 --- /dev/null +++ 
b/csrc/quantization/fp8/amd_detail/hip_float8_impl.h @@ -0,0 +1,316 @@ +#pragma once + +#if defined(__HIPCC__) && (defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) +#define __HIP__MI300__ +#endif + +#ifdef __HIPCC__ +#define HIP_FP8_HOST_DEVICE __host__ __device__ +#define HIP_FP8_HOST __host__ +#define HIP_FP8_DEVICE __device__ +#else +#define HIP_FP8_HOST_DEVICE +#define HIP_FP8_HOST +#define HIP_FP8_DEVICE +#endif + +namespace hip_fp8_impl +{ + +#ifdef __HIP__MI300__ +HIP_FP8_DEVICE uint8_t to_fp8_from_fp32(float v) +{ + uint8_t i8data; + union { + float fval; + uint32_t i32val; + uint8_t i8val[4]; // NOTE: not endian independent + } val; + + uint32_t ival = 0; + val.fval = v; + + if ((val.i32val & 0x7F800000) != 0x7F800000) { /// propagate NAN/INF, no clipping + val.fval = __builtin_amdgcn_fmed3f(val.fval, 240.0, -240.0); + } + + ival = __builtin_amdgcn_cvt_pk_fp8_f32(val.fval, val.fval, ival, + false); // false -> WORD0 + val.i32val = ival; + i8data = val.i8val[0]; + + return i8data; +} +#endif // __HIP__MI300__ + +HIP_FP8_HOST inline int clz(uint32_t x) +{ + return __builtin_clz(x); +} +#if defined(__HIPCC__) || defined(__CUDA_ARCH__) +HIP_FP8_DEVICE inline int clz(uint32_t x) +{ + return __clz(x); +} +#endif + +template +HIP_FP8_HOST_DEVICE uint8_t to_float8(T _x, bool stoch = false, uint32_t rng = 0) +{ +#ifdef __HIPCC__ + constexpr bool is_half = std::is_same::value; +#else + constexpr bool is_half = false; +#endif + constexpr bool is_float = std::is_same::value; + static_assert(wm + we == 7, "wm+we==7"); + static_assert(is_half || is_float, "Only half and float can be cast to f8"); + + const int mfmt = (sizeof(T) == 4) ? 23 : 10; + uint32_t x; + if (sizeof(T) == 4) { + x = reinterpret_cast(_x); + } else { + x = reinterpret_cast(_x); + } + + uint32_t head, mantissa; + int exponent, bias; + uint32_t sign; + + if (sizeof(T) == 4) { + head = x & 0xFF800000; + mantissa = x & 0x7FFFFF; + exponent = (head >> 23) & 0xFF; + sign = head >> 31; + bias = 127; + } else { + head = x & 0xFC00; + mantissa = x & 0x3FF; + exponent = (head >> 10) & 0x1F; + sign = head >> 15; + bias = 15; + } + + uint32_t signed_inf = (sign << 7) + (((1 << we) - 1) << wm); + + // Deal with inf and NaNs + if (negative_zero_nan) { + if (sizeof(T) == 4) { + if ((x & 0x7F800000) == 0x7F800000) { + return 0x80; + } + } else { + // if(__hisinf(x) || __hisnan(x)) + if ((x & 0x7C00) == 0x7C00) { + return 0x80; + } + } + } else { + if (sizeof(T) == 4) { + if ((x & 0x7F800000) == 0x7F800000) { + return signed_inf + (mantissa != 0 ? 1 : 0); + } + } else { + if ((x & 0x7C00) == 0x7C00) { + return signed_inf + (mantissa != 0 ? 1 : 0); + } + } + } + if (x == 0) { + return 0; + } + + // First need to check if it is normal or denorm as there is a difference of + // implicit 1 Then need to adjust the exponent to align with the F8 exponent, + // in the meanwhile, shift The mantissa. Then for stochastic rounding, add rng + // to mantissa and truncate. And for RNE, no need to add rng. Then probably + // need to check whether there is carry and adjust exponent and mantissa again + + // For IEEE bias mode, the bias is 2^(k-1) -1 where k is the width of exponent + // bits + const int f8_bias = (1 << (we - 1)) - 1 + (negative_zero_nan ? 
1 : 0); + const int f8_denormal_act_exponent = 1 - f8_bias; // actual exponent of f8 denormal + // act_exponent is the actual exponent of fp32/fp16 (after subtracting bias) + // f8_exponent is the converted f8 exponent with bias encoding + // exponent_diff is the diff between fp32/fp16 exponent and f8 exponent, + // the difference needs to be adjusted and mantissa shifted + int act_exponent, f8_exponent, exponent_diff; + + if (exponent == 0) { // fp32/fp16 is in denormal. + /* fp32 denormal is below 2^-127 so it is usually not a concern here, we +mostly concern fp16 here. In this case, f8 is usually in denormal. But there +could be exceptions. fp16 denormal has exponent bias 15 while bf8 with NANOO has +exponent bias 16. It means that there are some numbers in fp16 denormal but they +are bf8 (NANOO) normals - smallest bf8 (NANOO) normal is 2^-15. fp16 numbers +where exponent==0 (actual exponent -14) and highest bit of mantissa is 1 are bf8 +(NANOO) normal. In this case, the fp16 mantissa should be shift left by 1 */ + act_exponent = exponent - bias + 1; + exponent_diff = f8_denormal_act_exponent - act_exponent; // actual exponent is exponent-bias+1 as it is denormal + } else { // fp32/fp16 is normal with implicit 1 + act_exponent = exponent - bias; + if (act_exponent <= f8_denormal_act_exponent) { + /* This is the case where fp32/fp16 is normal but it is in f8 denormal + range. For example fp8 nanoo mode, denormal exponent is -7, but if the + fp32/fp16 actual exponent is -7, it is actually larger due to the implicit 1, + Therefore it needs to be adjust to -6 and mantissa shift right by 1. + So for fp32/fp16, exponent -8 is the cut point to convert to fp8 nanoo */ + exponent_diff = f8_denormal_act_exponent - act_exponent; + } else { // both fp32/fp16 and f8 are in normal range + exponent_diff = 0; // exponent_diff=0 does not mean there is no difference + // for this case, + // act_exponent could be larger. Just that it does not need shift mantissa + } + mantissa += (1 << mfmt); // Add the implicit 1 into mantissa + } + + bool midpoint = (mantissa & ((1 << (mfmt - wm + exponent_diff)) - 1)) == + static_cast(1 << (mfmt - wm + exponent_diff - 1)); + /* This part is a bit tricky. The judgment of whether it is a tie needs to be + done before we shift right as shift right could rip off some residual part + and make something not midpoint look like midpoint. For example, the fp16 + number 0x1002 (0 00100 0000000010), it is larger than midpoint, but after + shift right by 4 bits, it would look like midpoint. +*/ + + if (exponent_diff > 0) { + mantissa >>= exponent_diff; + } else if (exponent_diff == -1) { + mantissa <<= -exponent_diff; + } + bool implicit_one = mantissa & (1 << mfmt); + // if there is no implicit 1, it means the f8 is denormal and need to adjust + // to denorm exponent + f8_exponent = (act_exponent + exponent_diff) /*actual f8 exponent*/ + f8_bias - (implicit_one ? 0 : 1); + + // Now we have the exponent and mantissa adjusted + uint32_t drop_mask = (1 << (mfmt - wm)) - 1; + bool odd = mantissa & (1 << (mfmt - wm)); // if the least significant bit that + // is not truncated is 1 + mantissa += (stoch ? rng : (midpoint ? (odd ? 
mantissa : mantissa - 1) : mantissa)) & drop_mask; + + // Now we deal with overflow + if (f8_exponent == 0) { + if ((1 << mfmt) & mantissa) { + f8_exponent = 1; // denormal overflow to become normal, promote exponent + } + } else { + if ((1 << (mfmt + 1)) & mantissa) { + mantissa >>= 1; + f8_exponent++; + } + } + + mantissa >>= (mfmt - wm); + + // above range: quantize to maximum possible float of the same sign + const int max_exp = (1 << we) - (negative_zero_nan ? 1 : 2); + if (f8_exponent > max_exp) { + if (clip) { + mantissa = (1 << wm) - 1; + f8_exponent = max_exp; + } else { + return signed_inf; + } + } + + if (f8_exponent == 0 && mantissa == 0) { + return negative_zero_nan ? 0 : (sign << 7); + } + mantissa &= (1 << wm) - 1; + return (sign << 7) | (f8_exponent << wm) | mantissa; +} + +template +inline HIP_FP8_HOST_DEVICE T from_float8(uint8_t x) +{ +#ifdef __HIPCC__ + constexpr bool is_half = std::is_same::value; +#else + constexpr bool is_half = false; +#endif + constexpr bool is_float = std::is_same::value; + static_assert(is_half || is_float, "only half and float are supported"); + + constexpr int weo = is_half ? 5 : 8; + constexpr int wmo = is_half ? 10 : (is_float ? 23 : 7); + + T fInf, fNegInf, fNaN, fNeg0; + +#ifdef __HIPCC__ + if (is_half) { + const uint16_t ihInf = 0x7C00; + const uint16_t ihNegInf = 0xFC00; + const uint16_t ihNaN = 0x7C01; + const uint16_t ihNeg0 = 0x8000; + fInf = reinterpret_cast(ihInf); + fNegInf = reinterpret_cast(ihNegInf); + fNaN = reinterpret_cast(ihNaN); + fNeg0 = reinterpret_cast(ihNeg0); + } else +#endif + if (is_float) { + const uint32_t ifInf = 0x7F800000; + const uint32_t ifNegInf = 0xFF800000; + const uint32_t ifNaN = 0x7F800001; + const uint32_t ifNeg0 = 0x80000000; + fInf = reinterpret_cast(ifInf); + fNegInf = reinterpret_cast(ifNegInf); + fNaN = reinterpret_cast(ifNaN); + fNeg0 = reinterpret_cast(ifNeg0); + } + + if (x == 0) { + return 0; + } + + uint32_t sign = x >> 7; + uint32_t mantissa = x & ((1 << wm) - 1); + int exponent = (x & 0x7F) >> wm; + if (negative_zero_nan) { + if (x == 0x80) { + return fNaN; + } + } else { + if (x == 0x80) { + return fNeg0; + } + if (exponent == ((1 << we) - 1)) { + return (mantissa == 0) ? (sign ? fNegInf : fInf) : fNaN; + } + } + typename std::conditional::type retval; + if (we == 5 && is_half && !negative_zero_nan) { + retval = x << 8; + return reinterpret_cast(retval); + } + + const int exp_low_cutoff = (1 << (weo - 1)) - (1 << (we - 1)) + 1 - (negative_zero_nan ? 
1 : 0); + + // subnormal input + if (exponent == 0) { + // guaranteed mantissa!=0 since cases 0x0 and 0x80 are handled above + int sh = 1 + clz(mantissa) - (32 - wm); + mantissa <<= sh; + exponent += 1 - sh; + mantissa &= ((1 << wm) - 1); + } + exponent += exp_low_cutoff - 1; + mantissa <<= wmo - wm; + + // subnormal output (occurs when T=half, we=5, negative_zero_nan=true) + if (exponent <= 0) { + mantissa |= 1 << wmo; + mantissa >>= 1 - exponent; + exponent = 0; + } + + if (sizeof(T) == 2) { + retval = (sign << 15) | (exponent << 10) | mantissa; + } else { + retval = (sign << 31) | (exponent << 23) | mantissa; + } + return reinterpret_cast(retval); +} + +} // namespace hip_fp8_impl diff --git a/csrc/quantization/fp8/amd_detail/quant_utils.cuh b/csrc/quantization/fp8/amd_detail/quant_utils.cuh new file mode 100644 index 00000000..89416097 --- /dev/null +++ b/csrc/quantization/fp8/amd_detail/quant_utils.cuh @@ -0,0 +1,517 @@ +#pragma once +#include "hip_float8.h" + +#include +#include +#include + +#include "../../../attention/dtype_float32.cuh" +#include "../../../attention/dtype_bfloat16.cuh" + +namespace vllm +{ +namespace fp8_e4m3 { +template +__inline__ __device__ Tout vec_conversion(const Tin& x) +{ + return x; +} + +template +__inline__ __device__ Tout scaled_vec_conversion(const Tin& x, const float scale) +{ + return x; +} + +// fp8 -> half +template <> +__inline__ __device__ uint16_t vec_conversion(const uint8_t& a) +{ + hip_fp8 f8{a, hip_fp8::from_bits()}; + __half_raw res; + res.data = static_cast(f8); + return res.x; +} + +// fp8x2 -> half2 +template <> +__inline__ __device__ uint32_t vec_conversion(const uint16_t& a) +{ +#if defined(__HIP__MI300__) && defined(__HIP_FP8_EXPERIMENTAL_BULK_CONVERT__) + const auto& f2 = __builtin_amdgcn_cvt_pk_f32_fp8(a, 0); + union { + __half2_raw h2r; + uint32_t ui32; + } tmp; + tmp.h2r.x.data = f2[0]; + tmp.h2r.y.data = f2[1]; + return tmp.ui32; +#else + union { + uint16_t u16[2]; + uint32_t u32; + } tmp; + + tmp.u16[0] = vec_conversion(static_cast(a)); + tmp.u16[1] = vec_conversion(static_cast(a >> 8U)); + return tmp.u32; +#endif +} + +// fp8x4 -> half2x2 +template <> +__inline__ __device__ uint2 vec_conversion(const uint32_t& a) +{ + union { + uint2 u32x2; + uint32_t u32[2]; + } tmp; + tmp.u32[0] = vec_conversion((uint16_t)a); + tmp.u32[1] = vec_conversion((uint16_t)(a >> 16U)); + return tmp.u32x2; +} + +// fp8x8 -> half2x4 +template <> +__inline__ __device__ uint4 vec_conversion(const uint2& a) +{ + union { + uint4 u64x2; + uint2 u64[2]; + } tmp; + tmp.u64[0] = vec_conversion(a.x); + tmp.u64[1] = vec_conversion(a.y); + return tmp.u64x2; +} + +using __nv_bfloat16 = __hip_bfloat16; + +// fp8 -> __nv_bfloat16 +template <> +__inline__ __device__ __nv_bfloat16 vec_conversion<__nv_bfloat16, uint8_t>(const uint8_t& a) +{ + hip_fp8 f8{a, hip_fp8::from_bits()}; + float f{f8}; + return __float2bfloat16(f); +} + +using __nv_bfloat162 = __hip_bfloat162; + +// fp8x2 -> __nv_bfloat162 +template <> +__inline__ __device__ __nv_bfloat162 vec_conversion<__nv_bfloat162, uint16_t>(const uint16_t& a) +{ + __nv_bfloat162 res; + res.x = vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)a); + res.y = vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)(a >> 8U)); + return res; +} + +// fp8x4 -> bf16_4_t +template <> +__inline__ __device__ bf16_4_t vec_conversion(const uint32_t& a) +{ + bf16_4_t res; + res.x = vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)a); + res.y = vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)(a >> 16U)); + return res; +} + +// fp8x8 -> 
bf16_8_t +template <> +__inline__ __device__ bf16_8_t vec_conversion(const uint2& a) +{ + bf16_4_t tmp1, tmp2; + tmp1 = vec_conversion(a.x); + tmp2 = vec_conversion(a.y); + bf16_8_t res; + res.x = tmp1.x; + res.y = tmp1.y; + res.z = tmp2.x; + res.w = tmp2.y; + return res; +} + +// fp8 -> float +template <> +__inline__ __device__ float vec_conversion(const uint8_t& a) +{ + hip_fp8 fp8{a, hip_fp8::from_bits()}; + return static_cast(fp8); +} + +// fp8x2 -> float2 +template <> +__inline__ __device__ float2 vec_conversion(const uint16_t& a) +{ +#if defined(__HIP__MI300__) && defined(__HIP_FP8_EXPERIMENTAL_BULK_CONVERT__) + float2 res; + const auto& f2 = __builtin_amdgcn_cvt_pk_f32_fp8(a, 0); + res.x = f2[0]; + res.y = f2[1]; + return res; +#else + float2 res; + res.x = vec_conversion(static_cast(a)); + res.y = vec_conversion(static_cast(a >> 8U)); + return res; +#endif +} + +// fp8x4 -> float4 +template <> +__inline__ __device__ Float4_ vec_conversion(const uint32_t& a) +{ + Float4_ res; + res.x = vec_conversion((uint16_t)a); + res.y = vec_conversion((uint16_t)(a >> 16U)); + return res; +} + +// fp8x8 -> float8 +template <> +__inline__ __device__ Float8_ vec_conversion(const uint2& a) +{ + Float4_ tmp1, tmp2; + tmp1 = vec_conversion(a.x); + tmp2 = vec_conversion(a.y); + Float8_ res; + res.x = tmp1.x; + res.y = tmp1.y; + res.z = tmp2.x; + res.w = tmp2.y; + return res; +} + +// half -> fp8 +template <> +__inline__ __device__ uint8_t vec_conversion(const uint16_t& a) +{ + __half_raw tmp; + tmp.x = a; + + hip_fp8 f8{static_cast(tmp.data)}; + return f8.data; +} + +// bf16 -> fp8 +template <> +__inline__ __device__ uint8_t vec_conversion(const __nv_bfloat16& a) +{ + hip_fp8 res{__bfloat162float(a)}; + return res.data; +} + +// float -> fp8 +template <> +__inline__ __device__ uint8_t vec_conversion(const float& a) +{ + hip_fp8 f8(a); + return f8.data; +} + +// fp8x4 -> float4 +template <> +__inline__ __device__ float4 vec_conversion(const uint32_t& a) +{ + Float4_ tmp = vec_conversion(a); + float4 res = make_float4(tmp.x.x, tmp.x.y, tmp.y.x, tmp.y.y); + return res; +} + +// float2 -> half2 +template <> +__inline__ __device__ uint32_t vec_conversion(const float2& a) +{ + union { + half2 float16; + uint32_t uint32; + }; + + float16 = __float22half2_rn(a); + return uint32; +} + +// Float4 -> half2x2 +template <> +__inline__ __device__ uint2 vec_conversion(const Float4_& a) +{ + uint2 b; + float2 val; + val.x = a.x.x; + val.y = a.x.y; + b.x = vec_conversion(val); + + val.x = a.y.x; + val.y = a.y.y; + b.y = vec_conversion(val); + return b; +} + +// Float4 -> float4 +template <> +__inline__ __device__ float4 vec_conversion(const Float4_& a) +{ + float4 b; + b.x = a.x.x; + b.y = a.x.y; + b.z = a.y.x; + b.w = a.y.y; + return b; +} + +// Float8 -> half2x4 +template <> +__inline__ __device__ uint4 vec_conversion(const Float8_& a) +{ + uint4 b; + b.x = vec_conversion(a.x); + b.y = vec_conversion(a.y); + b.z = vec_conversion(a.z); + b.w = vec_conversion(a.w); + return b; +} + +// float2 -> bfloat162 +template <> +__inline__ __device__ __nv_bfloat162 vec_conversion<__nv_bfloat162, float2>(const float2& a) +{ + __nv_bfloat162 b = __float22bfloat162_rn(a); + return b; +} + +// Float4 -> bfloat162x2 +template <> +__inline__ __device__ bf16_4_t vec_conversion(const Float4_& a) +{ + bf16_4_t b; + b.x = __float22bfloat162_rn(a.x); + b.y = __float22bfloat162_rn(a.y); + return b; +} + +// Float8 -> bfloat162x4 +template <> +__inline__ __device__ bf16_8_t vec_conversion(const Float8_& a) +{ + bf16_8_t b; + b.x = 
__float22bfloat162_rn(a.x); + b.y = __float22bfloat162_rn(a.y); + b.z = __float22bfloat162_rn(a.z); + b.w = __float22bfloat162_rn(a.w); + return b; +} + + +/* Scaled and vectorized conversions, for data exchange between high and low precision domains + + Convention of the scale in API, e.g: FP8_data = Quantization( High_Precision_data / scale ) + s.t. + Quantize(HP / scale) => FP8 + Dequant(FP8) * scale => HP + + */ + +// fp8 -> half +template <> +__inline__ __device__ uint16_t scaled_vec_conversion(const uint8_t& a, const float scale) +{ + hip_fp8 f8{a, hip_fp8::from_bits()}; + __half_raw res; + res.data = static_cast(f8) * scale; + return res.x; +} + +// fp8x2 -> half2 +template <> +__inline__ __device__ uint32_t scaled_vec_conversion(const uint16_t& a, const float scale) +{ +#if defined(__HIP__MI300__) && defined(__HIP_FP8_EXPERIMENTAL_BULK_CONVERT__) + const auto& f2 = __builtin_amdgcn_cvt_pk_f32_fp8(a, 0); + union { + __half2_raw h2r; + uint32_t ui32; + } tmp; + tmp.h2r.x.data = f2[0] * scale; + tmp.h2r.y.data = f2[1] * scale; + return tmp.ui32; +#else + union { + uint16_t u16[2]; + uint32_t u32; + } tmp; + + tmp.u16[0] = scaled_vec_conversion(static_cast(a), scale); + tmp.u16[1] = scaled_vec_conversion(static_cast(a >> 8U), scale); + return tmp.u32; +#endif +} + +// fp8x4 -> half2x2 +template <> +__inline__ __device__ uint2 scaled_vec_conversion(const uint32_t& a, const float scale) +{ + union { + uint2 u32x2; + uint32_t u32[2]; + } tmp; + tmp.u32[0] = scaled_vec_conversion((uint16_t)a, scale); + tmp.u32[1] = scaled_vec_conversion((uint16_t)(a >> 16U), scale); + return tmp.u32x2; +} + +// fp8x8 -> half2x4 +template <> +__inline__ __device__ uint4 scaled_vec_conversion(const uint2& a, const float scale) +{ + union { + uint4 u64x2; + uint2 u64[2]; + } tmp; + tmp.u64[0] = scaled_vec_conversion(a.x, scale); + tmp.u64[1] = scaled_vec_conversion(a.y, scale); + return tmp.u64x2; +} + +using __nv_bfloat16 = __hip_bfloat16; + +// fp8 -> __nv_bfloat16 +template <> +__inline__ __device__ __nv_bfloat16 scaled_vec_conversion<__nv_bfloat16, uint8_t>(const uint8_t& a, const float scale) +{ + hip_fp8 f8{a, hip_fp8::from_bits()}; + float f{f8}; + return __float2bfloat16(f * scale); +} + +using __nv_bfloat162 = __hip_bfloat162; + +// fp8x2 -> __nv_bfloat162 +template <> +__inline__ __device__ __nv_bfloat162 scaled_vec_conversion<__nv_bfloat162, uint16_t>(const uint16_t& a, const float scale) +{ + __nv_bfloat162 res; + res.x = scaled_vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)a, scale); + res.y = scaled_vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)(a >> 8U), scale); + return res; +} + +// fp8x4 -> bf16_4_t +template <> +__inline__ __device__ bf16_4_t scaled_vec_conversion(const uint32_t& a, const float scale) +{ + bf16_4_t res; + res.x = scaled_vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)a, scale); + res.y = scaled_vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)(a >> 16U), scale); + return res; +} + +// fp8x8 -> bf16_8_t +template <> +__inline__ __device__ bf16_8_t scaled_vec_conversion(const uint2& a, const float scale) +{ + bf16_4_t tmp1, tmp2; + tmp1 = scaled_vec_conversion(a.x, scale); + tmp2 = scaled_vec_conversion(a.y, scale); + bf16_8_t res; + res.x = tmp1.x; + res.y = tmp1.y; + res.z = tmp2.x; + res.w = tmp2.y; + return res; +} + +// fp8 -> float +template <> +__inline__ __device__ float scaled_vec_conversion(const uint8_t& a, const float scale) +{ + hip_fp8 fp8{a, hip_fp8::from_bits()}; + return static_cast(fp8) * scale; +} + +// fp8x2 -> float2 +template <> +__inline__ 
__device__ float2 scaled_vec_conversion(const uint16_t& a, const float scale) +{ +#if defined(__HIP__MI300__) && defined(__HIP_FP8_EXPERIMENTAL_BULK_CONVERT__) + float2 res; + const auto& f2 = __builtin_amdgcn_cvt_pk_f32_fp8(a, 0); + res.x = f2[0] * scale; + res.y = f2[1] * scale; + return res; +#else + float2 res; + res.x = scaled_vec_conversion(static_cast(a), scale); + res.y = scaled_vec_conversion(static_cast(a >> 8U), scale); + return res; +#endif +} + +// fp8x4 -> float4 +template <> +__inline__ __device__ Float4_ scaled_vec_conversion(const uint32_t& a, const float scale) +{ + Float4_ res; + res.x = scaled_vec_conversion((uint16_t)a, scale); + res.y = scaled_vec_conversion((uint16_t)(a >> 16U), scale); + return res; +} + +// fp8x8 -> float8 +template <> +__inline__ __device__ Float8_ scaled_vec_conversion(const uint2& a, const float scale) +{ + Float4_ tmp1, tmp2; + tmp1 = scaled_vec_conversion(a.x, scale); + tmp2 = scaled_vec_conversion(a.y, scale); + Float8_ res; + res.x = tmp1.x; + res.y = tmp1.y; + res.z = tmp2.x; + res.w = tmp2.y; + return res; +} + + +/* Quantize(HP / scale) => FP8 */ + +// TODO(Hai): vectorized to add + +// half -> fp8 +template <> +__inline__ __device__ uint8_t scaled_vec_conversion(const uint16_t& a, const float scale) +{ + __half_raw tmp; + tmp.x = a; + + hip_fp8 f8{static_cast(tmp.data)/scale}; + return f8.data; +} + +// bf16 -> fp8 +template <> +__inline__ __device__ uint8_t scaled_vec_conversion(const __nv_bfloat16& a, const float scale) +{ + hip_fp8 res{__bfloat162float(a)/scale}; + return res.data; +} + +// float -> fp8 +template <> +__inline__ __device__ uint8_t scaled_vec_conversion(const float& a, const float scale) +{ + hip_fp8 f8(a/scale); + return f8.data; +} + +// fp8x4 -> float4 +template <> +__inline__ __device__ float4 scaled_vec_conversion(const uint32_t& a, const float scale) +{ + Float4_ tmp = scaled_vec_conversion(a, scale); + float4 res = make_float4(tmp.x.x, tmp.x.y, tmp.y.x, tmp.y.y); + return res; +} + +} +} // namespace vllm diff --git a/docs/source/index.rst b/docs/source/index.rst index 39040920..5d5d5269 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -91,7 +91,8 @@ Documentation :caption: Quantization quantization/auto_awq - quantization/fp8_e5m2_kv_cache + quantization/fp8_e5m2_kvcache + quantization/fp8_e4m3_kvcache .. toctree:: :maxdepth: 2 diff --git a/docs/source/quantization/fp8_e4m3_kvcache.rst b/docs/source/quantization/fp8_e4m3_kvcache.rst new file mode 100644 index 00000000..fd71c00b --- /dev/null +++ b/docs/source/quantization/fp8_e4m3_kvcache.rst @@ -0,0 +1,49 @@ +.. _fp8_e4m3_kvcache: + +FP8 E4M3 KV Cache +================== + +Quantizing the KV cache to FP8 reduces its memory footprint. This increases the number of tokens that can be stored in the cache, +improving throughput. OCP (Open Compute Project www.opencompute.org) specifies two common 8-bit floating point data formats: E5M2 +(5 exponent bits and 2 mantissa bits) and E4M3FN (4 exponent bits and 3 mantissa bits), often shortened as E4M3. One benefit of +the E4M3 format over E5M2 is that floating point numbers are represented in higher precision. However, the small dynamic range of +FP8 E4M3 (±240.0 can be represented) typically necessitates the use of a higher-precision (typically FP32) scaling factor alongside +each quantized tensor. For now, only per-tensor (scalar) scaling factors are supported. Development is ongoing to support scaling +factors of a finer granularity (e.g. per-channel). 
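As a concrete illustration of the per-tensor (scalar) scaling factor described in the documentation above, here is a minimal sketch of quantizing and dequantizing a KV tensor following the convention stated in `quant_utils.cuh`: `Quantize(HP / scale) => FP8` and `Dequant(FP8) * scale => HP`. It uses PyTorch's generic `torch.float8_e4m3fn` dtype (available in PyTorch >= 2.1) purely for demonstration; the ROCm kernels in this change target the "fnuz" variant with a ±240 range, and in vLLM the scale is loaded from the quantization param JSON rather than computed on the fly as done here.

```python
import torch

FP8_E4M3FNUZ_MAX = 240.0   # largest finite magnitude on MI300; standard E4M3FN is 448.0

def quantize_kv(x: torch.Tensor, kv_scale: float) -> torch.Tensor:
    # Quantize(HP / scale) => FP8
    return (x.float() / kv_scale).to(torch.float8_e4m3fn)

def dequantize_kv(q: torch.Tensor, kv_scale: float) -> torch.Tensor:
    # Dequant(FP8) * scale => HP
    return q.float() * kv_scale

k = torch.randn(16, 128) * 5.0                        # stand-in for one K tensor
kv_scale = k.abs().max().item() / FP8_E4M3FNUZ_MAX    # per-tensor scalar scale
k_fp8 = quantize_kv(k, kv_scale)
k_hat = dequantize_kv(k_fp8, kv_scale)
print(f"max abs quantization error: {(k - k_hat).abs().max().item():.4f}")
```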
+ +These scaling factors can be specified by passing an optional quantization param JSON to the LLM engine at load time. If +this JSON is not specified, scaling factors default to 1.0. These scaling factors are typically obtained when running an +unquantized model through a quantizer tool (e.g. AMD quantizer or NVIDIA AMMO). + +To install AMMO (AlgorithMic Model Optimization): + +.. code-block:: console + + $ pip install --no-cache-dir --extra-index-url https://pypi.nvidia.com nvidia-ammo + +Studies have shown that FP8 E4M3 quantization typically only minimally degrades inference accuracy. The most recent silicon +offerings e.g. AMD MI300, NVIDIA Hopper or later support native hardware conversion to and from fp32, fp16, bf16, etc. +Thus, LLM inference is greatly accelerated with minimal accuracy loss. + + +Here is an example of how to enable this feature: + +.. code-block:: python + + # two float8_e4m3fn kv cache scaling factor files are provided under tests/fp8_kv, please refer to + # https://github.com/vllm-project/vllm/blob/main/examples/fp8/README.md to generate kv_cache_scales.json of your own. + + from vllm import LLM, SamplingParams + sampling_params = SamplingParams(temperature=1.3, top_p=0.8) + llm = LLM(model="meta-llama/Llama-2-7b-chat-hf", + kv_cache_dtype="fp8", + quantization_param_path="./tests/fp8_kv/llama2-7b-fp8-kv/kv_cache_scales.json") + prompt = "London is the capital of" + out = llm.generate(prompt, sampling_params)[0].outputs[0].text + print(out) + + # output w/ scaling factors: England, the United Kingdom, and one of the world's leading financial, + # output w/o scaling factors: England, located in the southeastern part of the country. It is known + +Note, current prefix caching doesn't work with FP8 KV cache enabled, forward_prefix kernel should handle different KV and cache type. + diff --git a/docs/source/quantization/fp8_e5m2_kv_cache.rst b/docs/source/quantization/fp8_e5m2_kvcache.rst similarity index 83% rename from docs/source/quantization/fp8_e5m2_kv_cache.rst rename to docs/source/quantization/fp8_e5m2_kvcache.rst index f1eeb595..337252a0 100644 --- a/docs/source/quantization/fp8_e5m2_kv_cache.rst +++ b/docs/source/quantization/fp8_e5m2_kvcache.rst @@ -1,4 +1,4 @@ -.. _fp8_e5m2_kv_cache: +.. _fp8_kv_cache: FP8 E5M2 KV Cache ================== @@ -21,7 +21,7 @@ Here is an example of how to enable this feature: # Create a sampling params object. sampling_params = SamplingParams(temperature=0.8, top_p=0.95) # Create an LLM. - llm = LLM(model="facebook/opt-125m", kv_cache_dtype="fp8_e5m2") + llm = LLM(model="facebook/opt-125m", kv_cache_dtype="fp8") # Generate texts from the prompts. The output is a list of RequestOutput objects # that contain the prompt, generated text, and other information. outputs = llm.generate(prompts, sampling_params) @@ -31,3 +31,6 @@ Here is an example of how to enable this feature: generated_text = output.outputs[0].text print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") + +Note, current prefix caching doesn't work with FP8 KV cache enabled, forward_prefix kernel should handle different KV and cache type. + diff --git a/examples/fp8/README.md b/examples/fp8/README.md new file mode 100644 index 00000000..84ad76c7 --- /dev/null +++ b/examples/fp8/README.md @@ -0,0 +1,96 @@ +# FP8 KV Cache + +This utility extracts the KV cache scaling factors from a quantized HF (Hugging Face) model. The extracted scaling factors are saved to a JSON file, which can later be used by vLLM (variable-length language model) during runtime. 
This tool is particularly useful when the KV cache data type is FP8 and is intended for use on ROCm (AMD GPU) platforms. + +## Prerequisites + +- Python 3.x +- PyTorch +- NumPy +- Hugging Face Transformers +- Hugging Face Hub +- AMMO + +Before incorporating the FP8 datatype for inference workloads, you must adhere to the following steps: +1. Install all necessary prerequisites and dependencies. +2. Convert HF model into a quantized HF model. +3. Extract KV Cache Scaling Factors from quantized HF model. +4. Load KV Cache Scaling Factors into VLLM. + +### 2. Convert HF model into a quantized HF model. +Note: The following steps are adapted from the [TensorRT-LLM repository](https://github.com/NVIDIA/TensorRT-LLM/blob/main/examples/quantization/README.md). + +`quantize.py` (examples/fp8/quantizer/quantize.py) uses the quantization toolkit (AMMO) to calibrate the PyTorch models and export TensorRT-LLM checkpoints. Each TensorRT-LLM checkpoint contains a config file (in .json format) and one or several rank weight files (in .safetensors format). + +The detailed quantization toolkit (AMMO) conversion guide for FP8 can be found at `examples/fp8/quantizer/README.md`. + +### 3. Extract KV Cache Scaling Factors from quantized HF model. +`extract_scales.py` (examples/fp8/extract_scales.py) can be utilized to extract the KV cache scaling factors from your quantized HF model, however at the moment, this tool exclusively supports Llama 2 models. It is also important to note the following: +1. **File Structure**: The utility operates under the assumption that all parameters, including KV cache scaling factors, corresponding to a particular Tensor Parallelism (TP) rank are stored in a single file. These files must adhere to a specific naming convention where the TP rank is immediately identified after a specific keyword (e.g., "rank") in the filename. + +2. **TP Decomposition**: The utility assumes consistency between the TP decomposition employed by the quantizer tool and that used by vLLM. + +3. **AMMO Compatibility**: Currently, the generated KV cache scaling factors for AMMO remain uniform across all TP ranks. + +```python +# prerequisites: +# - Quantized HF LLaMa 2 model +python3 examples/fp8/extract_scales.py --help +Usage: extract_scales.py [-h] --quantized_model QUANTIZED_MODEL [--load_format {auto,safetensors,npz,pt}] [--output_dir OUTPUT_DIR] [--output_name OUTPUT_NAME] [--tp_size TP_SIZE] + +KV Scale Extraction Example + +optional arguments: +--quantized_model: Specify either the local path to, or name of, a quantized HF model. It is expected that the quantization format is FP8_E4M3, for use on ROCm (AMD GPU). +Optional arguments: +--cache_dir: Specify a cache directory to use in the event of a HF model download. (Default: None) +--load_format: Specify the format of the model's tensor files containing the KV cache scaling factors. (Choices: auto, safetensors, npz, pt; Default: auto) +--revision: Specify the model's revision number. (Default: None) +--output_dir: Specify the output directory. By default the KV cache scaling factors will be saved in the model directory. (Default: None) +--output_name: Specify the output filename. (Default: kv_cache_scales.json) +--tp_size: Specify the tensor-parallel (TP) size that the quantized model should correspond to. If specified, during KV cache scaling factor extraction the observed TP size will be checked against this and an error will be raised if there is a mismatch. 
+```
+```console
+Example:
+python3 examples/fp8/extract_scales.py --quantized_model <QUANTIZED_MODEL_DIR> --tp_size <TP_SIZE> --output_dir <OUTPUT_DIR>
+```
+### 4. Load KV Cache Scaling Factors into vLLM.
+This script evaluates the inference throughput of language models using various backends such as vLLM. It measures the time taken to process a given number of prompts and generate sequences for each prompt. The KV cache scaling factors generated in the previous step are integrated into the benchmarking process, allowing them to be used with the FP8 KV cache.
+```console
+# prerequisites:
+# - LLaMa 2 kv_cache_scales.json file
+
+python3 benchmarks/benchmark_throughput.py --help
+usage: benchmark_throughput.py [-h] [--backend {vllm,hf,mii}] [--dataset DATASET] [--input-len INPUT_LEN] [--output-len OUTPUT_LEN] [--model MODEL]
+                               [--tokenizer TOKENIZER] [--quantization {awq,gptq,squeezellm,None}] [--tensor-parallel-size TENSOR_PARALLEL_SIZE] [--n N]
+                               [--use-beam-search] [--num-prompts NUM_PROMPTS] [--seed SEED] [--hf-max-batch-size HF_MAX_BATCH_SIZE] [--trust-remote-code]
+                               [--max-model-len MAX_MODEL_LEN] [--dtype {auto,half,float16,bfloat16,float,float32}] [--enforce-eager] [--kv-cache-dtype {auto,fp8}]
+                               [--quantization-param-path KV_CACHE_quantization_param_path]
+
+Benchmark Throughput Example
+optional arguments:
+  -h, --help            show this help message and exit
+  --backend {vllm,hf,mii}
+  --dataset DATASET     Path to the dataset.
+  --input-len INPUT_LEN     Input prompt length for each request
+  --output-len OUTPUT_LEN   Output length for each request. Overrides the output length from the dataset.
+  --model MODEL
+  --tokenizer TOKENIZER
+  --quantization {awq,gptq,squeezellm,None}, -q {awq,gptq,squeezellm,None}
+  --tensor-parallel-size TENSOR_PARALLEL_SIZE, -tp TENSOR_PARALLEL_SIZE
+  --n N                 Number of generated sequences per prompt.
+  --use-beam-search
+  --num-prompts NUM_PROMPTS     Number of prompts to process.
+  --seed SEED
+  --hf-max-batch-size HF_MAX_BATCH_SIZE     Maximum batch size for HF backend.
+  --trust-remote-code   trust remote code from huggingface
+  --max-model-len MAX_MODEL_LEN     Maximum length of a sequence (including prompt and output). If None, will be derived from the model.
+  --dtype {auto,half,float16,bfloat16,float,float32}    data type for model weights and activations. The "auto" option will use FP16 precision for FP32 and FP16 models, and BF16 precision for BF16 models.
+  --enforce-eager       enforce eager execution
+  --kv-cache-dtype {auto,fp8}   Data type for kv cache storage. If "auto", will use model data type. FP8_E5M2 (without scaling) is only supported on cuda version greater than 11.8. On ROCm (AMD GPU), FP8_E4M3 is instead supported for common inference criteria.
+  --quantization-param-path QUANT_PARAM_JSON    Path to the JSON file containing the KV cache scaling factors. This should generally be supplied, when KV cache dtype is FP8. Otherwise, KV cache scaling factors default to 1.0, which may cause accuracy issues. FP8_E5M2 (without scaling) is only supported on cuda version greater than 11.8. On ROCm (AMD GPU), FP8_E4M3 is instead supported for common inference criteria.
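+# For illustration only: a hypothetical invocation that reuses the example scales file shipped under
+# tests/fp8_kv (the model path, sequence lengths and TP size below are placeholder values, not a
+# recommended configuration):
+# python3 benchmarks/benchmark_throughput.py --model meta-llama/Llama-2-7b-chat-hf --input-len 128 \
+#     --output-len 128 -tp 1 --kv-cache-dtype fp8 \
+#     --quantization-param-path ./tests/fp8_kv/llama2-7b-fp8-kv/kv_cache_scales.json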
+``` +``` +Example: +python3 benchmarks/benchmark_throughput.py --input-len --output-len -tp --kv-cache-dtype fp8 --quantization-param-path --model +```python diff --git a/examples/fp8/extract_scales.py b/examples/fp8/extract_scales.py new file mode 100644 index 00000000..5e5b3126 --- /dev/null +++ b/examples/fp8/extract_scales.py @@ -0,0 +1,367 @@ +import argparse +import glob +import json +import os +from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple + +import numpy as np +import torch +from safetensors.torch import safe_open + +from vllm.model_executor.layers.quantization.schema import QuantParamSchema + + +# Adapted from vllm/model_executor/weight_utils.py +# The main differences are that we add the NPZ format and simplify +# its functionality drastically for our purposes (e.g. we assume that +# the quantized model exists locally and there is no need to download it) +def _prepare_hf_weights( + quantized_model_dir: str, + load_format: str = "auto", + fall_back_to_pt: bool = True, +) -> Tuple[str, List[str], bool]: + if not os.path.isdir(quantized_model_dir): + raise FileNotFoundError( + f"The quantized model directory `{quantized_model_dir}` " + "does not exist.") + use_safetensors = False + # Some quantized models use .pt files for storing the weights. + if load_format == "auto": + allow_patterns = ["*.safetensors", "*.bin"] + elif load_format == "safetensors": + use_safetensors = True + allow_patterns = ["*.safetensors"] + elif load_format == "pt": + allow_patterns = ["*.pt"] + elif load_format == "npz": + allow_patterns = ["*.npz"] + else: + raise ValueError(f"Unknown load_format: {load_format}") + if fall_back_to_pt: + allow_patterns += ["*.pt"] + + hf_weights_files: List[str] = [] + for pattern in allow_patterns: + hf_weights_files += glob.glob( + os.path.join(quantized_model_dir, pattern)) + if len(hf_weights_files) > 0: + if pattern == "*.safetensors": + use_safetensors = True + break + + if not use_safetensors: + # Exclude files that are not needed for inference. + # https://github.com/huggingface/transformers/blob/v4.34.0/src/transformers/trainer.py#L227-L233 + blacklist = [ + "training_args.bin", + "optimizer.bin", + "optimizer.pt", + "scheduler.pt", + "scaler.pt", + ] + hf_weights_files = [ + f for f in hf_weights_files + if not any(f.endswith(x) for x in blacklist) + ] + + if len(hf_weights_files) == 0: + raise RuntimeError( + f"Cannot find any model weights with `{quantized_model_dir}`") + + return hf_weights_files, use_safetensors + + +# Adapted from vllm/model_executor/weight_utils.py +def _hf_tensorfile_iterator(filename: str, load_format: str, + use_safetensors: bool): + if load_format == "npz": + assert not use_safetensors + with np.load(filename) as data: + for name in data.files: + param = torch.from_numpy(data[name]) + yield name, param + elif use_safetensors: + with safe_open(filename, framework="pt") as f: + for name in f.keys(): # NOQA: SIM118 + param = f.get_tensor(name) + yield name, param + else: + state = torch.load(filename, map_location="cpu") + for name, param in state.items(): + yield name, param + del state + torch.cuda.empty_cache() + + +def _kv_scales_extractor( + hf_tensor_files: Iterable[str], + use_safetensors: bool, + rank_keyword: str = "rank", + expected_tp_size: Optional[int] = None) -> Dict[int, Dict[int, float]]: + """ + Given a list of files containing tensor data, attempt to extract KV cache + scales from these files. Intended as a helper function taking in the output + from _prepare_hf_weights. 
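+    As a purely illustrative example, with the default rank_keyword "rank",
+    quantizer outputs named like "rank0.safetensors" and "rank1.safetensors"
+    (cf. examples/fp8/quantizer/README.md) would be mapped to TP ranks 0 and 1
+    respectively.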
+ Args: + rank_keyword Matches the number immediately after this keyword in the + tensor filename to determine the TP rank corresponding + to said tensor file + expected_tp_size If specified, the TP size of the tensor files is checked + against this and an error is raised if they don't match. + Returns a dictionary mapping TP ranks to their relevant KV cache scales. + The per-rank scales are themselves represented as a dictionary of layer + indices to the respective per-layer scale. + """ + for char in rank_keyword: + assert not char.isdecimal( + ), f"Rank keyword {rank_keyword} contains a numeric character!" + rank_scales_map = {} + for tensor_file in hf_tensor_files: + try: + rank_idx = tensor_file.find(rank_keyword) + if rank_idx != -1: + start_idx = rank_idx + len(rank_keyword) + stop_idx = start_idx + while stop_idx < len( + tensor_file) and tensor_file[stop_idx].isdecimal(): + stop_idx += 1 + if stop_idx == start_idx: + raise RuntimeError("Did not find rank # in filename.") + rank = int(tensor_file[start_idx:stop_idx]) + elif len(hf_tensor_files) == 1: + # Since there is only one tensor file, we can assume + # that it's intended for TP rank 0 + rank = 0 + else: + raise RuntimeError( + f"Filename does not contain '{rank_keyword}'.") + except RuntimeError: + print("Unable to determine TP rank " + f"corresponding to file '{tensor_file}'") + raise + + if rank not in rank_scales_map: + layer_scales_map = {} + rank_scales_map[rank] = layer_scales_map + else: + raise RuntimeError( + f"Tensor file '{tensor_file}' shares TP rank {rank} " + "with another tensor file.") + + module_delimiter = ":" if args.load_format == "npz" else "." + for name, param in _hf_tensorfile_iterator(tensor_file, + args.load_format, + use_safetensors): + if "kv_cache_scaling_factor" in name: + nums = [ + int(s) for s in name.split(module_delimiter) + if s.isdecimal() + ] + assert len( + nums) == 1, f"Could not determine layer idx for {name}" + layer_idx = nums[0] + assert layer_idx not in layer_scales_map, f"Duplicate scaling"\ + f" factor corresponding to layer {layer_idx}" + try: + layer_scales_map[layer_idx] = param.item() + except RuntimeError: + print( + "This utility supports only per-tensor scalar scales " + f"for now. The tensor\n {name} = {param} \nis an " + "invalid scale factor.") + raise + + if all( + len(layer_scales_map) == 0 + for layer_scales_map in rank_scales_map.values()): + # Note: this is true even if the rank_scales_map is empty + print("WARNING: No KV cache scale factors found. No output saved.") + return None + empirical_tp_world_size = max(rank_scales_map.keys()) + 1 + if expected_tp_size is not None: + assert expected_tp_size == empirical_tp_world_size, \ + f"User expected TP world size = {expected_tp_size} " \ + "from model but tool is expecting TP world size = " \ + f"{empirical_tp_world_size} from model instead." + for i in range(empirical_tp_world_size): + assert i in rank_scales_map, "Expected TP world size = "\ + f"{empirical_tp_world_size} but did not find KV " \ + f"cache scaling factors for TP rank {i}" + print(f"Found TP world size = {empirical_tp_world_size} " + "when extracting KV cache scales!") + return rank_scales_map + + +def _metadata_extractor(quantized_model_dir: str, + metadata_extract_fns: \ + Dict[str, Callable[[Dict[str, Any]], Any]]) \ + -> Dict[str, Any]: + """ + Given a directory containing quantized model files, this function + aims to extract metadata from the JSON files within this directory. 
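+    (In this pipeline these JSON files are typically the checkpoint config
+    files emitted by the quantizer, e.g. a file such as llama_tp1.json.)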
+ Each JSON file is expected to represent a dictionary in JSON + format (referred to as a "JSON-dictionary"). Metadata extraction is + defined by a dictionary called metadata_extract_fns, where each + metadata field name is mapped to an extraction function. + + These extraction functions are designed to take a JSON-dictionary + as their only argument and return the corresponding metadata. + While extraction functions are permitted to raise exceptions, they + should only raise a KeyError or ValueError if the metadata field + cannot be extracted from the current JSON-dictionary, yet there's + a possibility of finding it in another JSON-dictionary. + + The function returns a dictionary that maps metadata fields to + their extracted data. The keys of this dictionary correspond exactly + to those in metadata_extract_fns. If any fields fail to be extracted, + their corresponding values are set to None, and a warning is printed. + """ + if not os.path.isdir(quantized_model_dir): + raise FileNotFoundError( + f"The quantized model directory `{quantized_model_dir}` " + "does not exist.") + metadata_files = glob.glob(os.path.join(quantized_model_dir, "*.json")) + + result = {} + for file in metadata_files: + with open(file) as f: + try: + metadata = json.load(f) + except json.JSONDecodeError: + print(f"Could not parse `{file}` as a valid metadata file," + " skipping it.") + continue + if not isinstance(metadata, dict): + print(f"The file `{file}` does not correspond to a " + "JSON-serialized dictionary, skipping it.") + continue + for metadata_name, extract_fn in metadata_extract_fns.items(): + try: + metadata_info = extract_fn(metadata) + if metadata_name not in result: + result[metadata_name] = metadata_info + elif metadata_info != result[metadata_name]: + raise RuntimeError( + "Metadata mismatch! Originally found " + f"{metadata_name} = {result[metadata_name]} but " + f"now found {metadata_name} = {metadata_info} in " + f"`{file}`") + except KeyError: + # It is possible that a given file does not contain some + # of our selected metadata as it could be located in some + # other metadata file. + # 'EFINAE': extract_fn failure is not an error. + pass + except ValueError: + # See above. + pass + + # Warn if we cannot find any of the requested metadata + for metadata_name in metadata_extract_fns: + if metadata_name not in result: + print("WARNING: Unable to find requested metadata field " + f"`{metadata_name}`, setting it to None.") + result[metadata_name] = None + + return result + + +def main(args): + metadata_extract_fns = { + "model_type": lambda json_dict: json_dict["layers"][0]["decoder_type"], + "tp_size": lambda json_dict: int(json_dict["tensor_parallel"]), + "model_dtype": lambda json_dict: json_dict["dtype"] + } + recovered_metadata = _metadata_extractor(args.quantized_model, + metadata_extract_fns) + if args.tp_size is not None: + metadata_tp_size = recovered_metadata["tp_size"] + if metadata_tp_size is not None: + assert args.tp_size == metadata_tp_size, \ + f"User expected TP world size = {args.tp_size} " \ + f"but found TP world size = {metadata_tp_size} from metadata!" + expected_tp_size = args.tp_size or recovered_metadata["tp_size"] + rank_keyword = "rank" + hf_tensor_files, use_safetensors = _prepare_hf_weights( + args.quantized_model, args.load_format) + rank_scales_map = _kv_scales_extractor(hf_tensor_files, use_safetensors, + rank_keyword, expected_tp_size) + # Postprocess: formatting to the current schema. 
Consider pulling it + # out into a dedicated function should it ever become more complicated. + rank_scales_map = { + rank: {k: scale[k] + for k in sorted(scale.keys())} + for rank, scale in rank_scales_map.items() + } + # TODO: Expand this with activation and weights scaling factors when + # they are used in the future + schema = QuantParamSchema( + model_type=recovered_metadata["model_type"], + kv_cache={ + "dtype": ("float8_e4m3fn" if len(rank_scales_map) > 0 else + recovered_metadata["model_dtype"]), + "scaling_factor": + rank_scales_map + }, + ) + + if args.output_dir is None: + output_file = os.path.join(args.quantized_model, args.output_name) + else: + if not os.path.isdir(args.output_dir): + os.makedirs(args.output_dir, exist_ok=True) + output_file = os.path.join(args.output_dir, args.output_name) + + with open(output_file, 'w') as f: + f.write(schema.model_dump_json(indent=4)) + print(f"Completed! KV cache scaling factors saved to {output_file}") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="This simple utility extracts the " + "KV cache scaling factors from a quantized HF model " + "and saves them to a JSON file compatible with later " + "use by vLLM (pass this file to the appropriate " + "runtime typically using the argument " + "--quantization-param-path ). This is only used " + "if the KV cache dtype is FP8 and on ROCm (AMD GPU).") + parser.add_argument( + "--quantized_model", + help="Specify the directory containing a single quantized HF model. " + "It is expected that the quantization format is FP8_E4M3, for use " + "on ROCm (AMD GPU).", + required=True) + parser.add_argument( + "--load_format", + help="Optionally specify the format of the model's tensor files " + "containing the KV cache scaling factors.", + choices=["auto", "safetensors", "npz", "pt"], + default="auto") + parser.add_argument( + "--output_dir", + help="Optionally specify the output directory. By default the " + "KV cache scaling factors will be saved in the model directory, " + "however you can override this behavior here.", + default=None) + parser.add_argument( + "--output_name", + help="Optionally specify the output filename.", + # TODO: Change this once additional scaling factors are enabled + default="kv_cache_scales.json") + parser.add_argument( + "--tp_size", + help="Optionally specify the tensor-parallel (TP) size that the " + "quantized model should correspond to. If specified, during KV " + "cache scaling factor extraction the observed TP size will be " + "checked against this and an error will be raised if there is " + "a mismatch. If not specified, the quantized model's expected " + "TP size is instead inferred from the largest TP rank observed. 
" + "The expected TP size is cross-checked against the TP ranks " + "observed in the quantized model and an error is raised if any " + "discrepancies are found.", + default=None, + type=int) + args = parser.parse_args() + + main(args) diff --git a/examples/fp8/quantizer/README.md b/examples/fp8/quantizer/README.md new file mode 100644 index 00000000..8f89a74a --- /dev/null +++ b/examples/fp8/quantizer/README.md @@ -0,0 +1,32 @@ +### Quantizer Utilities +`quantize.py`: NVIDIA Quantization utilities using AMMO, ported from TensorRT-LLM: +`https://github.com/NVIDIA/TensorRT-LLM/blob/main/examples/quantization/quantize.py` + +### Prerequisite + +#### AMMO (AlgorithMic Model Optimization) Installation: nvidia-ammo 0.7.1 or later +`pip install --no-cache-dir --extra-index-url https://pypi.nvidia.com nvidia-ammo` + +#### AMMO Download (code and docs) +`https://developer.nvidia.com/downloads/assets/cuda/files/nvidia-ammo/nvidia_ammo-0.5.0.tar.gz` +`https://developer.nvidia.com/downloads/assets/cuda/files/nvidia-ammo/nvidia_ammo-0.7.1.tar.gz` + +### Usage + +#### Run on H100 system for speed if FP8; number of GPUs depends on the model size + +#### Example: quantize Llama2-7b model from HF to FP8 with FP8 KV Cache: +`python quantize.py --model_dir ./ll2-7b --dtype float16 --qformat fp8 --kv_cache_dtype fp8 --output_dir ./ll2_7b_fp8 --calib_size 512 --tp_size 1` + +Outputs: model structure, quantized model & parameters (with scaling factors) are in JSON and Safetensors (npz is generated only for the reference) +``` +# ll ./ll2_7b_fp8/ +total 19998244 +drwxr-xr-x 2 root root 4096 Feb 7 01:08 ./ +drwxrwxr-x 8 1060 1061 4096 Feb 7 01:08 ../ +-rw-r--r-- 1 root root 176411 Feb 7 01:08 llama_tp1.json +-rw-r--r-- 1 root root 13477087480 Feb 7 01:09 llama_tp1_rank0.npz +-rw-r--r-- 1 root root 7000893272 Feb 7 01:08 rank0.safetensors +# +``` + diff --git a/examples/fp8/quantizer/quantize.py b/examples/fp8/quantizer/quantize.py new file mode 100644 index 00000000..1ff56706 --- /dev/null +++ b/examples/fp8/quantizer/quantize.py @@ -0,0 +1,369 @@ +# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+""" +Adapted from examples/quantization/hf_ptq.py +""" + +import argparse +import copy +import json +import random +import time + +import ammo.torch.quantization as atq +import numpy as np +import torch +from ammo.torch.export import export_model_config +from datasets import load_dataset +from torch.utils.data import DataLoader +from transformers import AutoModelForCausalLM, AutoTokenizer + +RAND_SEED = 1234 +MAX_SEQ_LEN = 2048 + +EMPTY_CFG = { + "quant_cfg": { + "*weight_quantizer": { + "enable": False, + }, + "*input_quantizer": { + "enable": False + }, + "*lm_head*": { + "enable": False + }, + "*output_layer*": { + "enable": False + }, + "default": { + "enable": False + }, + }, + "algorithm": "max", +} + +KV_CACHE_CFG = { + "*.query_key_value.output_quantizer": { + "num_bits": 8, + "axis": None, + "enable": True + }, + "*.Wqkv.output_quantizer": { + "num_bits": 8, + "axis": None, + "enable": True + }, + "*.W_pack.output_quantizer": { + "num_bits": 8, + "axis": None, + "enable": True + }, + "*.c_attn.output_quantizer": { + "num_bits": 8, + "axis": None, + "enable": True + }, + "*.k_proj.output_quantizer": { + "num_bits": 8, + "axis": None, + "enable": True + }, + "*.v_proj.output_quantizer": { + "num_bits": 8, + "axis": None, + "enable": True + }, +} + +QUANT_CFG_CHOICES = { + "int8_sq": atq.INT8_SMOOTHQUANT_CFG, + "fp8": atq.FP8_DEFAULT_CFG, + "int4_awq": atq.INT4_AWQ_CFG, + "w4a8_awq": atq.W4A8_AWQ_BETA_CFG, + "int8_wo": EMPTY_CFG, + "int4_wo": EMPTY_CFG, + "full_prec": EMPTY_CFG, +} + +MODEL_NAME_PATTERN_MAP = { + "GPT2": "gpt2", + "Xverse": "llama", + "Llama": "llama", + "Mistral": "llama", + "GPTJ": "gptj", + "FalconForCausalLM": "falcon", + "RWForCausalLM": "falcon", + "baichuan": "baichuan", + "MPT": "mpt", + "Bloom": "bloom", + "ChatGLM": "chatglm", + "QWen": "qwen", +} + + +def get_tokenizer(ckpt_path, max_seq_len=MAX_SEQ_LEN, model_type=None): + print(f"Initializing tokenizer from {ckpt_path}") + tokenizer = AutoTokenizer.from_pretrained( + ckpt_path, + model_max_length=max_seq_len, + padding_side="left", + trust_remote_code=True, + ) + if model_type and model_type == "qwen": + # qwen use token id 151643 as pad and eos tokens + tokenizer.pad_token = tokenizer.convert_ids_to_tokens(151643) + tokenizer.eos_token = tokenizer.convert_ids_to_tokens(151643) + + # can't set attribute 'pad_token' for "" + if tokenizer.pad_token != "": + tokenizer.pad_token = tokenizer.eos_token + if tokenizer.pad_token is None: + tokenizer.pad_token = tokenizer.eos_token + assert tokenizer.pad_token is not None, f"Pad token for {model_type} cannot be set!" 
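+    # At this point a pad token is guaranteed to be set (falling back to the
+    # EOS token above); the calibration dataloader relies on it when padding
+    # batches to max_length.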
+ + return tokenizer + + +def get_model(ckpt_path, dtype="fp16", device="cuda"): + print(f"Initializing model from {ckpt_path}") + if dtype == "bf16" or dtype == "bfloat16": + dtype = torch.bfloat16 + elif dtype == "fp16" or dtype == "float16": + dtype = torch.float16 + elif dtype == "fp32" or dtype == "float32": + dtype = torch.float32 + else: + raise NotImplementedError(f"Unknown dtype {dtype}") + + # model_kwargs = {"torch_dtype": dtype} + model_kwargs = {"torch_dtype": "auto"} + + model = AutoModelForCausalLM.from_pretrained(ckpt_path, + device_map="auto", + **model_kwargs, + trust_remote_code=True) + model.eval() + + model_dtype = next(model.parameters()).dtype + if dtype != model_dtype: + print( + f"[TensorRT-LLM][WARNING] The manually set model data type is {dtype}, " + f"but the data type of the HuggingFace model is {model_dtype}.") + + return model + + +def get_model_type(model): + for k, v in MODEL_NAME_PATTERN_MAP.items(): + if k.lower() in type(model).__name__.lower(): + return v + return None + + +def get_calib_dataloader(data="cnn_dailymail", + tokenizer=None, + batch_size=1, + calib_size=512, + block_size=512, + device=None): + print("Loading calibration dataset") + if data == "pileval": + dataset = load_dataset( + "json", + data_files="https://the-eye.eu/public/AI/pile/val.jsonl.zst", + split="train") + dataset = dataset["text"][:calib_size] + elif data == "cnn_dailymail": + dataset = load_dataset("cnn_dailymail", name="3.0.0", split="train") + dataset = dataset["article"][:calib_size] + else: + raise NotImplementedError + + batch_encoded = tokenizer.batch_encode_plus(dataset, + return_tensors="pt", + padding="max_length", + truncation=True, + max_length=block_size) + if device: + batch_encoded = batch_encoded.to(device) + batch_encoded = batch_encoded["input_ids"] + + calib_dataloader = DataLoader(batch_encoded, + batch_size=batch_size, + shuffle=False) + + return calib_dataloader + + +def quantize_model(model, quant_cfg, calib_dataloader=None): + + def calibrate_loop(): + if calib_dataloader is None: + return + """Adjusts weights and scaling factors based on selected algorithms.""" + for idx, data in enumerate(calib_dataloader): + print(f"Calibrating batch {idx}") + model(data) + + print("Starting quantization...") + start_time = time.time() + atq.quantize(model, quant_cfg, forward_loop=calibrate_loop) + end_time = time.time() + print("Quantization done. Total time used: {:.2f} s.".format(end_time - + start_time)) + + return model + + +def main(args): + if not torch.cuda.is_available(): + raise EnvironmentError("GPU is required for inference.") + + random.seed(RAND_SEED) + np.random.seed(RAND_SEED) + + model = get_model(args.model_dir, args.dtype, args.device) + model_type = get_model_type(model) + tokenizer = get_tokenizer(args.model_dir, model_type=model_type) + + if args.qformat in ["full_prec", "int8_wo", "int4_wo" + ] and args.kv_cache_dtype is None: + print(f"No quantization applied, export {args.dtype} model") + else: + if "awq" in args.qformat: + if args.calib_size > 32: + print( + f"AWQ calibration could take longer with calib_size = {args.calib_size}, Using" + " calib_size=32 instead") + args.calib_size = 32 + print( + "\nAWQ calibration could take longer than other calibration methods. Please" + " increase the batch size to speed up the calibration process. 
Batch size can be" + " set by adding the argument --batch_size to the command line.\n" + ) + + calib_dataloader = get_calib_dataloader( + tokenizer=tokenizer, + batch_size=args.batch_size, + calib_size=args.calib_size, + device=args.device, + ) + + if args.qformat in QUANT_CFG_CHOICES: + quant_cfg = QUANT_CFG_CHOICES[args.qformat] + else: + raise ValueError( + f"Unsupported quantization format: {args.qformat}") + + if "awq" in args.qformat: + quant_cfg = copy.deepcopy(QUANT_CFG_CHOICES[args.qformat]) + weight_quantizer = quant_cfg["quant_cfg"][ + "*weight_quantizer"] # type: ignore + if isinstance(weight_quantizer, list): + weight_quantizer = weight_quantizer[0] + weight_quantizer["block_sizes"][-1] = args.awq_block_size + + if args.kv_cache_dtype is not None: + if args.kv_cache_dtype == "fp8": + for value in KV_CACHE_CFG.values(): + value.update({"num_bits": (4, 3)}) # type: ignore + quant_cfg["quant_cfg"].update(KV_CACHE_CFG) # type: ignore + + print(quant_cfg) + + model = quantize_model(model, quant_cfg, calib_dataloader) + + with torch.inference_mode(): + if model_type is None: + print( + f"Unknown model type {type(model).__name__}. Continue exporting..." + ) + model_type = f"unknown:{type(model).__name__}" + + export_path = args.output_dir + start_time = time.time() + + if args.qformat == "int4_awq" and model_type == "qwen": + torch.save(model.state_dict(), export_path) + else: + export_npz = (model_type not in [ + 'gptj', 'falcon', 'chatglm', 'mpt', 'llama', 'baichuan' + ]) + + # export safetensors + export_model_config( + model, + model_type, + getattr(torch, args.dtype), + export_dir=export_path, + inference_tensor_parallel=args.tp_size, + inference_pipeline_parallel=args.pp_size, + # export_tensorrt_llm_config=(not export_npz), + export_tensorrt_llm_config=False, + export_npz=export_npz) + + # Workaround for wo quantization + if args.qformat in ["int8_wo", "int4_wo", "full_prec"]: + with open(f"{export_path}/config.json", 'r') as f: + tensorrt_llm_config = json.load(f) + if args.qformat == "int8_wo": + tensorrt_llm_config["quantization"]["quant_algo"] = 'W8A16' + elif args.qformat == "int4_wo": + tensorrt_llm_config["quantization"]["quant_algo"] = 'W4A16' + else: + tensorrt_llm_config["quantization"]["quant_algo"] = None + with open(f"{export_path}/config.json", "w") as f: + json.dump(tensorrt_llm_config, f, indent=4) + + end_time = time.time() + print("Quantized model exported to {} \nTotal time used {:.2f} s.". 
+ format(export_path, end_time - start_time)) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument("--model_dir", + help="Specify where the HuggingFace model is", + required=True) + parser.add_argument("--device", default="cuda") + parser.add_argument("--dtype", help="Model data type.", default="float16") + parser.add_argument( + "--qformat", + help="Quantization format.", + default="full_prec", + choices=[ + "fp8", "int8_sq", "int4_awq", "w4a8_awq", "int8_wo", "int4_wo", + "full_prec" + ], + ) + parser.add_argument("--batch_size", + help="Batch size for calibration.", + type=int, + default=1) + parser.add_argument("--calib_size", + help="Number of samples for calibration.", + type=int, + default=512) + parser.add_argument("--output_dir", default="exported_model") + parser.add_argument("--tp_size", type=int, default=1) + parser.add_argument("--pp_size", type=int, default=1) + parser.add_argument("--awq_block_size", type=int, default=128) + parser.add_argument("--kv_cache_dtype", + help="KV Cache dtype.", + default=None, + choices=["int8", "fp8", None]) + args = parser.parse_args() + + main(args) diff --git a/pyproject.toml b/pyproject.toml index 9d042601..b7ad8b8c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -13,6 +13,10 @@ build-backend = "setuptools.build_meta" [tool.ruff] # Allow lines to be as long as 80. line-length = 80 +exclude = [ + # External file, leaving license intact + "examples/fp8/quantizer/quantize.py" +] [tool.ruff.lint] select = [ diff --git a/tests/fp8_kv/llama2-70b-fp8-kv/kv_cache_scales.json b/tests/fp8_kv/llama2-70b-fp8-kv/kv_cache_scales.json new file mode 100644 index 00000000..a548f0a9 --- /dev/null +++ b/tests/fp8_kv/llama2-70b-fp8-kv/kv_cache_scales.json @@ -0,0 +1,90 @@ +{ + "model_type": "llama", + "kv_cache": { + "dtype": "float8_e4m3fn", + "scaling_factor": { + "0": { + "0": 0.0230364128947258, + "1": 0.01979283057153225, + "2": 0.0241350457072258, + "3": 0.0308314748108387, + "4": 0.0430733822286129, + "5": 0.0370396226644516, + "6": 0.0306222103536129, + "7": 0.0357491634786129, + "8": 0.0358189195394516, + "9": 0.0443289652466774, + "10": 0.0433175228536129, + "11": 0.0416782945394516, + "12": 0.0366908498108387, + "13": 0.0432477705180645, + "14": 0.0410505048930645, + "15": 0.0457589291036129, + "16": 0.0418526791036129, + "17": 0.0432477705180645, + "18": 0.0469447560608387, + "19": 0.0514787957072258, + "20": 0.0541294664144516, + "21": 0.0587681382894516, + "22": 0.0625, + "23": 0.0585588738322258, + "24": 0.0600237175822258, + "25": 0.0588030144572258, + "26": 0.0531180277466774, + "27": 0.06396484375, + "28": 0.0603027381002903, + "29": 0.0582101047039032, + "30": 0.0625348836183548, + "31": 0.0585588738322258, + "32": 0.0582798570394516, + "33": 0.0575125589966774, + "34": 0.0590820349752903, + "35": 0.0614188089966774, + "36": 0.0631975457072258, + "37": 0.0615931935608387, + "38": 0.0601283498108387, + "39": 0.0571986623108387, + "40": 0.0670340433716774, + "41": 0.0523507259786129, + "42": 0.0547223798930645, + "43": 0.0631975457072258, + "44": 0.0663713738322258, + "45": 0.0603376142680645, + "46": 0.0652204304933548, + "47": 0.0734514519572258, + "48": 0.0693708211183548, + "49": 0.0725446492433548, + "50": 0.0627790242433548, + "51": 0.0691266804933548, + "52": 0.0688825398683548, + "53": 0.068429134786129, + "54": 0.0605119988322258, + "55": 0.0799386203289032, + "56": 0.0853097140789032, + "57": 0.0661969929933548, + "58": 0.0689871683716774, + "59": 0.0724051371216774, + "60": 
0.0541643425822258, + "61": 0.0626743882894516, + "62": 0.0628487765789032, + "63": 0.0607212632894516, + "64": 0.0589076466858387, + "65": 0.0451660193502903, + "66": 0.0453055277466774, + "67": 0.0414341539144516, + "68": 0.0385044664144516, + "69": 0.0414341539144516, + "70": 0.0466308631002903, + "71": 0.0399693101644516, + "72": 0.0437011756002903, + "73": 0.0434221550822258, + "74": 0.0428989976644516, + "75": 0.0401785746216774, + "76": 0.0431082621216774, + "77": 0.0484444759786129, + "78": 0.0417829267680645, + "79": 0.0418178029358387 + } + } + } +} \ No newline at end of file diff --git a/tests/fp8_kv/llama2-7b-fp8-kv/kv_cache_scales.json b/tests/fp8_kv/llama2-7b-fp8-kv/kv_cache_scales.json new file mode 100644 index 00000000..bb734039 --- /dev/null +++ b/tests/fp8_kv/llama2-7b-fp8-kv/kv_cache_scales.json @@ -0,0 +1,42 @@ +{ + "model_type": "llama", + "kv_cache": { + "dtype": "float8_e4m3fn", + "scaling_factor": { + "0": { + "0": 0.0152239128947258, + "1": 0.0188860222697258, + "2": 0.0354178324341774, + "3": 0.0376674123108387, + "4": 0.0418526791036129, + "5": 0.0433175228536129, + "6": 0.0397600457072258, + "7": 0.0424455925822258, + "8": 0.0415387861430645, + "9": 0.0408412404358387, + "10": 0.0395856611430645, + "11": 0.0377371683716774, + "12": 0.0400739423930645, + "13": 0.040771484375, + "14": 0.0393415205180645, + "15": 0.0369001142680645, + "16": 0.03857421875, + "17": 0.0387486070394516, + "18": 0.0403180830180645, + "19": 0.0396205373108387, + "20": 0.0375627800822258, + "21": 0.0407366082072258, + "22": 0.0432477705180645, + "23": 0.0377022884786129, + "24": 0.0399693101644516, + "25": 0.0374581478536129, + "26": 0.0413295216858387, + "27": 0.0442243330180645, + "28": 0.0424804724752903, + "29": 0.0456891767680645, + "30": 0.0409109964966774, + "31": 0.0482352152466774 + } + } + } +} diff --git a/tests/kernels/test_attention.py b/tests/kernels/test_attention.py index b03fecff..03ea7292 100644 --- a/tests/kernels/test_attention.py +++ b/tests/kernels/test_attention.py @@ -32,7 +32,7 @@ HEAD_SIZES = [64, 80, 96, 112, 128, 256 BLOCK_SIZES = [16, 32] USE_ALIBI = [False, True] -KV_CACHE_DTYPE = ["auto", "fp8_e5m2"] +KV_CACHE_DTYPE = ["auto", "fp8"] SEEDS = [0] CUDA_DEVICES = [ f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2) @@ -172,6 +172,9 @@ def test_paged_attention( device) key_cache, value_cache = key_caches[0], value_caches[0] + # Using default kv_scale + kv_scale = 1.0 + # Call the paged attention kernel. output = torch.empty_like(query) if version == "v1": @@ -188,6 +191,7 @@ def test_paged_attention( max_context_len, alibi_slopes, kv_cache_dtype, + kv_scale, ) elif version == "v2": num_partitions = ((max_context_len + PARTITION_SIZE - 1) // @@ -219,12 +223,13 @@ def test_paged_attention( max_context_len, alibi_slopes, kv_cache_dtype, + kv_scale, ) else: raise AssertionError(f"Unknown version: {version}") # Run the reference implementation. - if kv_cache_dtype == "fp8_e5m2": + if kv_cache_dtype == "fp8": # Convert cache data back to dtype. 
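+        # The fp8 cache is stored in uint8 tensors, so dequantize it back to
+        # the original dtype with convert_fp8 before running the reference
+        # attention implementation below.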
x = 16 // torch.tensor([], dtype=dtype).element_size() key_cache_shape = (NUM_BLOCKS, num_kv_heads, head_size // x, @@ -232,14 +237,14 @@ def test_paged_attention( dequantized_key_cache = torch.empty(size=key_cache_shape, dtype=dtype, device=device) - cache_ops.convert_fp8_e5m2(key_cache, dequantized_key_cache) + cache_ops.convert_fp8(key_cache, dequantized_key_cache) key_cache = dequantized_key_cache value_cache_shape = value_cache.shape dequantized_value_cache = torch.empty(size=value_cache_shape, dtype=dtype, device=device) - cache_ops.convert_fp8_e5m2(value_cache, dequantized_value_cache) + cache_ops.convert_fp8(value_cache, dequantized_value_cache) value_cache = dequantized_value_cache ref_output = torch.empty_like(query) @@ -263,7 +268,8 @@ def test_paged_attention( # NOTE(zhaoyang): FP8 KV Cache will introduce quantization error, # so we use a relaxed tolerance for the test. - if kv_cache_dtype == "fp8_e5m2": + atol, rtol = 1e-3, 1e-5 + if kv_cache_dtype == "fp8": atol, rtol = 1e-2, 1e-5 assert torch.allclose(output, ref_output, atol=atol, rtol=rtol) diff --git a/tests/kernels/test_cache.py b/tests/kernels/test_cache.py index 0cdb92f2..4141aaca 100644 --- a/tests/kernels/test_cache.py +++ b/tests/kernels/test_cache.py @@ -5,6 +5,7 @@ import pytest import torch from vllm._C import cache_ops +from vllm.utils import is_hip COPYING_DIRECTION = [('cuda', 'cpu'), ('cuda', 'cuda'), ('cpu', 'cuda')] DTYPES = [torch.half, torch.bfloat16, torch.float] @@ -23,7 +24,7 @@ SEEDS = [0] CUDA_DEVICES = [ f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2) ] -KV_CACHE_DTYPE = ["auto", "fp8_e5m2"] +KV_CACHE_DTYPE = ["auto", "fp8"] @pytest.mark.parametrize("num_mappings", NUM_MAPPINGS) @@ -105,6 +106,7 @@ def test_copy_blocks( @pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("seed", SEEDS) @pytest.mark.parametrize("device", CUDA_DEVICES) +@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE) @torch.inference_mode() def test_reshape_and_cache( kv_cache_factory, @@ -116,7 +118,10 @@ def test_reshape_and_cache( dtype: torch.dtype, seed: int, device: str, + kv_cache_dtype: str, ) -> None: + if not is_hip() and kv_cache_dtype == "fp8": + pytest.skip() # This test is not tuned for e5m2 cuda precision random.seed(seed) torch.random.manual_seed(seed) if torch.cuda.is_available(): @@ -132,17 +137,33 @@ def test_reshape_and_cache( # Create the KV caches. key_caches, value_caches = kv_cache_factory(num_blocks, block_size, 1, - num_heads, head_size, dtype, - None, seed, device) + num_heads, head_size, + kv_cache_dtype, dtype, seed, + device) key_cache, value_cache = key_caches[0], value_caches[0] # Clone the KV caches. - cloned_key_cache = key_cache.clone() - cloned_value_cache = value_cache.clone() + if kv_cache_dtype == "fp8": + cloned_key_cache = torch.empty_like(key_cache, dtype=torch.float16) + cache_ops.convert_fp8(key_cache, cloned_key_cache) + cloned_value_cache = torch.empty_like(value_cache, dtype=torch.float16) + cache_ops.convert_fp8(value_cache, cloned_value_cache) + else: + cloned_key_cache = key_cache.clone() + cloned_value_cache = value_cache.clone() + + # Using default kv_scale + kv_scale = 1.0 # Call the reshape_and_cache kernel. 
cache_ops.reshape_and_cache(key, value, key_cache, value_cache, - slot_mapping, "auto") + slot_mapping, kv_cache_dtype, kv_scale) + + if kv_cache_dtype == "fp8": + result_key_cache = torch.empty_like(key_cache, dtype=torch.float16) + cache_ops.convert_fp8(key_cache, result_key_cache) + result_value_cache = torch.empty_like(value_cache, dtype=torch.float16) + cache_ops.convert_fp8(value_cache, result_value_cache) # Run the reference implementation. reshaped_key = key.reshape(num_tokens, *key_cache[0, :, :, 0, :].shape) @@ -156,8 +177,18 @@ def test_reshape_and_cache( cloned_key_cache[block_idx, :, :, block_offset, :] = reshaped_key[i] cloned_value_cache[block_idx, :, :, block_offset] = value[i] - assert torch.allclose(key_cache, cloned_key_cache) - assert torch.allclose(value_cache, cloned_value_cache) + if kv_cache_dtype == "fp8": + assert torch.allclose(result_key_cache, + cloned_key_cache, + atol=0.001, + rtol=0.1) + assert torch.allclose(result_value_cache, + cloned_value_cache, + atol=0.001, + rtol=0.1) + else: + assert torch.allclose(key_cache, cloned_key_cache) + assert torch.allclose(value_cache, cloned_value_cache) @pytest.mark.parametrize("direction", COPYING_DIRECTION) @@ -169,6 +200,7 @@ def test_reshape_and_cache( @pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("seed", SEEDS) @pytest.mark.parametrize("device", CUDA_DEVICES) +@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE) @torch.inference_mode() def test_swap_blocks( kv_cache_factory, @@ -181,7 +213,12 @@ def test_swap_blocks( dtype: torch.dtype, seed: int, device: str, + kv_cache_dtype: str, ) -> None: + if kv_cache_dtype == "fp8" and "cpu" in direction: + pytest.skip() + if not is_hip() and kv_cache_dtype == "fp8": + pytest.skip() # This test is not tuned for e5m2 cuda precision random.seed(seed) torch.random.manual_seed(seed) if torch.cuda.is_available(): @@ -202,13 +239,13 @@ def test_swap_blocks( # Create the KV caches on the first device. src_key_caches, src_value_caches = kv_cache_factory( - num_blocks, block_size, 1, num_heads, head_size, dtype, None, seed, - src_device) + num_blocks, block_size, 1, num_heads, head_size, kv_cache_dtype, dtype, + seed, src_device) # Create the KV caches on the second device. 
dist_key_caches, dist_value_caches = kv_cache_factory( - num_blocks, block_size, 1, num_heads, head_size, dtype, None, seed, - dst_device) + num_blocks, block_size, 1, num_heads, head_size, kv_cache_dtype, dtype, + seed, dst_device) src_key_caches_clone = src_key_caches[0].clone() src_value_caches_clone = src_value_caches[0].clone() @@ -223,3 +260,40 @@ def test_swap_blocks( dist_key_caches[0][dst].cpu()) assert torch.allclose(src_value_caches_clone[src].cpu(), dist_value_caches[0][dst].cpu()) + + +@pytest.mark.skipif(not is_hip(), reason="FP8 conversion test requires e4m3") +@pytest.mark.parametrize("num_heads", NUM_HEADS) +@pytest.mark.parametrize("head_size", HEAD_SIZES) +@pytest.mark.parametrize("block_size", BLOCK_SIZES) +@pytest.mark.parametrize("num_blocks", NUM_BLOCKS) +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("seed", SEEDS) +@pytest.mark.parametrize("device", CUDA_DEVICES) +@torch.inference_mode() +def test_fp8_conversion( + num_heads: int, + head_size: int, + block_size: int, + num_blocks: int, + dtype: torch.dtype, + seed: int, + device: str, +) -> None: + random.seed(seed) + torch.random.manual_seed(seed) + torch.cuda.manual_seed(seed) + + low = -224.0 + high = 224.0 + shape = (num_blocks, num_heads, head_size, block_size) + cache = torch.empty(shape, dtype=dtype, device=device) + cache.uniform_(low, high) + + cache_fp8 = torch.empty_like(cache, dtype=torch.uint8) + cache_ops.convert_fp8(cache, cache_fp8) + + converted_cache = torch.empty_like(cache) + cache_ops.convert_fp8(cache_fp8, converted_cache) + + assert torch.allclose(cache, converted_cache, atol=0.001, rtol=0.1) diff --git a/vllm/attention/backends/abstract.py b/vllm/attention/backends/abstract.py index a7e0ab92..a03cf2dd 100644 --- a/vllm/attention/backends/abstract.py +++ b/vllm/attention/backends/abstract.py @@ -81,5 +81,6 @@ class AttentionImpl(ABC): value: torch.Tensor, kv_cache: torch.Tensor, attn_metadata: AttentionMetadata, + kv_scale: float, ) -> torch.Tensor: raise NotImplementedError diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index e50d5237..4e0d9d14 100644 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -156,6 +156,7 @@ class FlashAttentionImpl(AttentionImpl): value: torch.Tensor, kv_cache: torch.Tensor, attn_metadata: FlashAttentionMetadata, + kv_scale: float, ) -> torch.Tensor: """Forward pass with FlashAttention and PagedAttention. @@ -184,7 +185,8 @@ class FlashAttentionImpl(AttentionImpl): PagedAttention.write_to_paged_cache(key, value, key_cache, value_cache, attn_metadata.slot_mapping, - attn_metadata.kv_cache_dtype) + attn_metadata.kv_cache_dtype, + kv_scale) if attn_metadata.is_prompt: # Prompt run. @@ -207,6 +209,9 @@ class FlashAttentionImpl(AttentionImpl): ) else: # prefix-enabled attention + # TODO(Hai) this triton kernel has regression issue (broke) to + # deal with different data types between KV and FP8 KV cache, + # to be addressed separately. output = PagedAttention.forward_prefix( query, key, @@ -233,6 +238,7 @@ class FlashAttentionImpl(AttentionImpl): self.num_kv_heads, self.scale, self.alibi_slopes, + kv_scale, ) # Reshape the output tensor. 
diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index 8e510f97..d349c3ef 100644 --- a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -178,6 +178,7 @@ class XFormersImpl(AttentionImpl): value: torch.Tensor, kv_cache: Optional[torch.Tensor], attn_metadata: XFormersMetadata, + kv_scale: float, ) -> torch.Tensor: """Forward pass with xFormers and PagedAttention. @@ -205,7 +206,8 @@ class XFormersImpl(AttentionImpl): PagedAttention.write_to_paged_cache(key, value, key_cache, value_cache, attn_metadata.slot_mapping, - attn_metadata.kv_cache_dtype) + attn_metadata.kv_cache_dtype, + kv_scale) if attn_metadata.is_prompt: # Prompt run. @@ -259,6 +261,9 @@ class XFormersImpl(AttentionImpl): query, key, value, attn_metadata) else: # prefix-enabled attention + # TODO(Hai) this triton kernel has regression issue (broke) to + # deal with different data types between KV and FP8 KV cache, + # to be addressed separately. output = PagedAttention.forward_prefix( query, key, @@ -285,6 +290,7 @@ class XFormersImpl(AttentionImpl): self.num_kv_heads, self.scale, self.alibi_slopes, + kv_scale, ) # Reshape the output tensor. diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py index 2e0aa18e..9856654f 100644 --- a/vllm/attention/layer.py +++ b/vllm/attention/layer.py @@ -42,5 +42,7 @@ class Attention(nn.Module): value: torch.Tensor, kv_cache: Optional[torch.Tensor], attn_metadata: AttentionMetadata, + kv_scale: float = 1.0, ) -> torch.Tensor: - return self.impl.forward(query, key, value, kv_cache, attn_metadata) + return self.impl.forward(query, key, value, kv_cache, attn_metadata, + kv_scale) diff --git a/vllm/attention/ops/paged_attn.py b/vllm/attention/ops/paged_attn.py index 5901af4f..ec2c18dc 100644 --- a/vllm/attention/ops/paged_attn.py +++ b/vllm/attention/ops/paged_attn.py @@ -73,6 +73,7 @@ class PagedAttention: value_cache: torch.Tensor, slot_mapping: torch.Tensor, kv_cache_dtype: str, + kv_scale: float, ) -> None: cache_ops.reshape_and_cache( key, @@ -81,6 +82,7 @@ class PagedAttention: value_cache, slot_mapping.flatten(), kv_cache_dtype, + kv_scale, ) @staticmethod @@ -95,6 +97,7 @@ class PagedAttention: num_kv_heads: int, scale: float, alibi_slopes: Optional[torch.Tensor], + kv_scale, ) -> torch.Tensor: output = torch.empty_like(query) @@ -126,6 +129,7 @@ class PagedAttention: max_context_len, alibi_slopes, kv_cache_dtype, + kv_scale, ) else: # Run PagedAttention V2. @@ -157,6 +161,7 @@ class PagedAttention: max_context_len, alibi_slopes, kv_cache_dtype, + kv_scale, ) return output diff --git a/vllm/config.py b/vllm/config.py index ef680c69..e27c8eb4 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -60,6 +60,11 @@ class ModelConfig: output). If None, will be derived from the model. quantization: Quantization method that was used to quantize the model weights. If None, we assume the model weights are not quantized. + quantization_param_path: Path to JSON file containing scaling factors. + Used to load KV cache scaling factors into the model when KV cache + type is FP8_E4M3 on ROCm (AMD GPU). In the future these will also + be used to load activation and weight scaling factors when the + model dtype is FP8_E4M3 on ROCm. enforce_eager: Whether to enforce eager execution. If True, we will disable CUDA graph and always execute the model in eager mode. If False, we will use CUDA graph and eager execution in hybrid. 
@@ -83,6 +88,7 @@ class ModelConfig: tokenizer_revision: Optional[str] = None, max_model_len: Optional[int] = None, quantization: Optional[str] = None, + quantization_param_path: Optional[str] = None, enforce_eager: bool = False, max_context_len_to_capture: Optional[int] = None, max_logprobs: int = 5, @@ -98,6 +104,7 @@ class ModelConfig: self.code_revision = code_revision self.tokenizer_revision = tokenizer_revision self.quantization = quantization + self.quantization_param_path = quantization_param_path self.enforce_eager = enforce_eager self.max_context_len_to_capture = max_context_len_to_capture self.max_logprobs = max_logprobs @@ -369,21 +376,20 @@ class CacheConfig: def _verify_cache_dtype(self) -> None: if self.cache_dtype == "auto": pass - elif self.cache_dtype == "fp8_e5m2": - if is_hip(): - raise NotImplementedError( - "FP8_E5M2 KV Cache on AMD GPU has not been supported yet.") - nvcc_cuda_version = get_nvcc_cuda_version() - if nvcc_cuda_version and nvcc_cuda_version < Version("11.8"): - raise ValueError( - "FP8 is not supported when cuda version is lower than 11.8." - ) + elif self.cache_dtype == "fp8": + if not is_hip(): + nvcc_cuda_version = get_nvcc_cuda_version() + if nvcc_cuda_version < Version("11.8"): + raise ValueError( + "FP8 is not supported when cuda version is" + "lower than 11.8.") logger.info( - "Using fp8_e5m2 data type to store kv cache. It reduces " - "the GPU memory footprint and boosts the performance. " - "But it may cause slight accuracy drop. " - "Currently we only support fp8 without scaling factors and " - "make e5m2 as a default format.") + "Using fp8 data type to store kv cache. It reduces the GPU " + "memory footprint and boosts the performance. " + "But it may cause slight accuracy drop without scaling " + "factors. FP8_E5M2 (without scaling) is only supported on " + "cuda version greater than 11.8. On ROCm (AMD GPU), FP8_E4M3 " + "is instead supported for common inference criteria.") else: raise ValueError(f"Unknown kv cache dtype: {self.cache_dtype}") diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 9c60a936..a6197942 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -21,6 +21,7 @@ class EngineArgs: load_format: str = 'auto' dtype: str = 'auto' kv_cache_dtype: str = 'auto' + quantization_param_path: Optional[str] = None seed: int = 0 max_model_len: Optional[int] = None worker_use_ray: bool = False @@ -159,11 +160,23 @@ class EngineArgs: parser.add_argument( '--kv-cache-dtype', type=str, - choices=['auto', 'fp8_e5m2'], + choices=['auto', 'fp8'], default=EngineArgs.kv_cache_dtype, help='Data type for kv cache storage. If "auto", will use model ' - 'data type. Note FP8 is not supported when cuda version is ' - 'lower than 11.8.') + 'data type. FP8_E5M2 (without scaling) is only supported on cuda ' + 'version greater than 11.8. On ROCm (AMD GPU), FP8_E4M3 is instead ' + 'supported for common inference criteria. ') + parser.add_argument( + '--quantization-param-path', + type=str, + default=None, + help='Path to the JSON file containing the KV cache ' + 'scaling factors. This should generally be supplied, when ' + 'KV cache dtype is FP8. Otherwise, KV cache scaling factors ' + 'default to 1.0, which may cause accuracy issues. ' + 'FP8_E5M2 (without scaling) is only supported on cuda version' + 'greater than 11.8. On ROCm (AMD GPU), FP8_E4M3 is instead ' + 'supported for common inference criteria. 
') parser.add_argument('--max-model-len', type=int, default=EngineArgs.max_model_len, @@ -408,8 +421,8 @@ class EngineArgs: self.trust_remote_code, self.download_dir, self.load_format, self.dtype, self.seed, self.revision, self.code_revision, self.tokenizer_revision, self.max_model_len, self.quantization, - self.enforce_eager, self.max_context_len_to_capture, - self.max_logprobs) + self.quantization_param_path, self.enforce_eager, + self.max_context_len_to_capture, self.max_logprobs) cache_config = CacheConfig(self.block_size, self.gpu_memory_utilization, self.swap_space, self.kv_cache_dtype, diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 2da2c79e..5c343921 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -97,6 +97,7 @@ class LLMEngine: f"quantization={model_config.quantization}, " f"enforce_eager={model_config.enforce_eager}, " f"kv_cache_dtype={cache_config.cache_dtype}, " + f"quantization_param_path={model_config.quantization_param_path}, " f"device_config={device_config.device}, " f"seed={model_config.seed})") # TODO(woosuk): Print more configs in debug mode. diff --git a/vllm/model_executor/layers/quantization/schema.py b/vllm/model_executor/layers/quantization/schema.py new file mode 100644 index 00000000..a26c5247 --- /dev/null +++ b/vllm/model_executor/layers/quantization/schema.py @@ -0,0 +1,84 @@ +""" +This file contains the Pydantic schemas for various quantization-related +parameters. When a relevant quantization technique is specified, these +parameters are loaded in the form of a JSON alongside the model weights +and augment the model with additional information needed for use of that +technique. The format of this JSON should be specified by one or more +schemas contained here. + +For example, when the KV cache is quantized to FP8-E4M3 (currently only +possible on ROCm), the model can be optionally augmented with KV cache +scaling factors. +""" + +from typing import Dict, Optional + +from pydantic import BaseModel, ConfigDict, ValidationInfo, model_validator + + +class KVCacheQuantSchema(BaseModel): + dtype: str + # Each key is a TP rank. Each value is a dictionary mapping a TP rank's + # layer indices to their per-tensor KV cache scaling factor. + # TODO: Consider pulling this and its validation methods out into its + # own schema class (tricky as its members are variable) + scaling_factor: Dict[int, Dict[int, float]] + + @model_validator(mode="after") + def check_is_fp8(self) -> "KVCacheQuantSchema": + assert self.dtype == "float8_e4m3fn", ( + "Loaded scaling factors intended for KV cache dtype = " + f"{self.dtype} rather than float8_e4m3fn!") + return self + + @model_validator(mode="after") + def check_tp_ranks(self, info: ValidationInfo) -> "KVCacheQuantSchema": + context = info.context + if context: + tp_size = context["tp_size"] + num_hidden_layers = context["num_hidden_layers"] + assert len(self.scaling_factor) == tp_size, ( + f"Loaded dictionary has TP size {len(self.scaling_factor)} " + f"but LLM engine is currently running with TP size {tp_size}.") + for tp_rank, layer_maps in self.scaling_factor.items(): + assert len(layer_maps) == num_hidden_layers, ( + f"KV cache scales map for TP rank {tp_rank} is malformed. 
" + f"Expected {num_hidden_layers} layers, got " + f"{len(layer_maps)}.") + for i in range(tp_size): + assert i in self.scaling_factor, ( + f"KV cache scales map for TP rank {i} not found.") + return self + + @model_validator(mode="after") + def check_current_rank(self, info: ValidationInfo) -> "KVCacheQuantSchema": + context = info.context + if context: + tp_rank = context["tp_rank"] + num_hidden_layers = context["num_hidden_layers"] + layer_scales_map = self.scaling_factor[tp_rank] + for i in range(num_hidden_layers): + assert i in layer_scales_map, ( + f"Could not find KV cache scales for layer {i} in " + f"TP rank {tp_rank}.") + return self + + +class QuantParamSchema(BaseModel): + # TODO: Generalize and extend with more fields + # (e.g. weights/activations params) once functionality is enabled + model_config = ConfigDict(protected_namespaces=()) + model_type: Optional[str] + kv_cache: KVCacheQuantSchema + + @model_validator(mode="after") + def check_model_type(self, info: ValidationInfo) -> "QuantParamSchema": + context = info.context + if context: + model_type = context.get("model_type", None) + if model_type is not None: + assert model_type == self.model_type, ( + f"Model type is {model_type} but loaded " + f"scaling factors belonging to different " + f"model type {self.model_type}!") + return self diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index 57857deb..ef19c41e 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -41,11 +41,13 @@ from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.parallel_utils.parallel_state import ( - get_tensor_model_parallel_world_size) + get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.weight_utils import (default_weight_loader, - hf_model_weights_iterator) + hf_model_weights_iterator, + kv_cache_scales_loader) from vllm.sequence import SamplerOutput +from vllm.utils import is_hip class LlamaMLP(nn.Module): @@ -115,6 +117,15 @@ class LlamaAttention(nn.Module): self.rope_theta = rope_theta self.max_position_embeddings = max_position_embeddings + # This will be overwritten by model initialization if we are using it. + # N.B. currently we only support per tensor scalar scaling factors + # & only applicable to ROCm (AMD GPU). + # The scaling factor convention we are assuming is + # quantized_value * scaling_factor ~= true_value + # which is consistent with the practice of setting + # scaling_factor = tensor_amax / FPtype_max + self.kv_scale = 1.0 + self.qkv_proj = QKVParallelLinear( hidden_size, self.head_dim, @@ -153,7 +164,8 @@ class LlamaAttention(nn.Module): qkv, _ = self.qkv_proj(hidden_states) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) q, k = self.rotary_emb(positions, q, k) - attn_output = self.attn(q, k, v, kv_cache, attn_metadata) + attn_output = self.attn(q, k, v, kv_cache, attn_metadata, + self.kv_scale) output, _ = self.o_proj(attn_output) return output @@ -402,3 +414,27 @@ class LlamaForCausalLM(nn.Module): weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) + + # If this function is called, it should always initialize KV cache scale + # factors (or else raise an exception). 
Thus, handled exceptions should + # make sure to leave KV cache scale factors in a known good (dummy) state + def load_kv_cache_scales(self, quantization_param_path: str) -> None: + tp_size = get_tensor_model_parallel_world_size() + tp_rank = get_tensor_model_parallel_rank() + for layer_idx, scaling_factor in kv_cache_scales_loader( + quantization_param_path, tp_rank, tp_size, + self.config.num_hidden_layers, + self.config.__class__.model_type): + layer_self_attn = self.model.layers[layer_idx].self_attn + + if is_hip(): + # The scaling factor convention we are assuming is + # quantized_value * scaling_factor ~= true_value + # which is consistent with the practice of setting + # scaling_factor = tensor_amax / FPtype_max + scaling_factor *= 2 + if hasattr(layer_self_attn, "kv_scale"): + layer_self_attn.kv_scale = scaling_factor + else: + raise RuntimeError("Self attention has no KV cache scaling " + "factor attribute!") diff --git a/vllm/model_executor/weight_utils.py b/vllm/model_executor/weight_utils.py index 9181f298..2c0dd8ce 100644 --- a/vllm/model_executor/weight_utils.py +++ b/vllm/model_executor/weight_utils.py @@ -5,7 +5,7 @@ import hashlib import json import os from collections import defaultdict -from typing import Any, Iterator, List, Optional, Tuple +from typing import Any, Iterable, Iterator, List, Optional, Tuple import filelock import numpy as np @@ -18,6 +18,7 @@ from vllm.config import ModelConfig from vllm.logger import init_logger from vllm.model_executor.layers.quantization import (QuantizationConfig, get_quantization_config) +from vllm.model_executor.layers.quantization.schema import QuantParamSchema logger = init_logger(__name__) @@ -275,6 +276,46 @@ def hf_model_weights_iterator( torch.cuda.empty_cache() +def kv_cache_scales_loader( + filename: str, tp_rank: int, tp_size: int, num_hidden_layers: int, + model_type: Optional[str]) -> Iterable[Tuple[int, float]]: + """ + A simple utility to read in KV cache scaling factors that have been + previously serialized to disk. Used by the model to populate the appropriate + KV cache scaling factors. The serialization should represent a dictionary + whose keys are the TP ranks and values are another dictionary mapping layers + to their KV cache scaling factors. 
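+    For illustration, a serialized file for a single TP rank is expected to
+    look roughly like the examples under tests/fp8_kv, e.g.:
+        {"model_type": "llama",
+         "kv_cache": {"dtype": "float8_e4m3fn",
+                      "scaling_factor": {"0": {"0": 0.0152, "1": 0.0189, ...}}}}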
diff --git a/vllm/model_executor/weight_utils.py b/vllm/model_executor/weight_utils.py
index 9181f298..2c0dd8ce 100644
--- a/vllm/model_executor/weight_utils.py
+++ b/vllm/model_executor/weight_utils.py
@@ -5,7 +5,7 @@ import hashlib
 import json
 import os
 from collections import defaultdict
-from typing import Any, Iterator, List, Optional, Tuple
+from typing import Any, Iterable, Iterator, List, Optional, Tuple
 
 import filelock
 import numpy as np
@@ -18,6 +18,7 @@ from vllm.config import ModelConfig
 from vllm.logger import init_logger
 from vllm.model_executor.layers.quantization import (QuantizationConfig,
                                                      get_quantization_config)
+from vllm.model_executor.layers.quantization.schema import QuantParamSchema
 
 logger = init_logger(__name__)
 
@@ -275,6 +276,46 @@ def hf_model_weights_iterator(
             torch.cuda.empty_cache()
 
 
+def kv_cache_scales_loader(
+        filename: str, tp_rank: int, tp_size: int, num_hidden_layers: int,
+        model_type: Optional[str]) -> Iterable[Tuple[int, float]]:
+    """
+    A simple utility to read in KV cache scaling factors that have been
+    previously serialized to disk. Used by the model to populate the appropriate
+    KV cache scaling factors. The serialization should represent a dictionary
+    whose keys are the TP ranks and values are another dictionary mapping layers
+    to their KV cache scaling factors.
+    Keep this function in sync with the output of examples/fp8/extract_scales.py
+    """
+    try:
+        with open(filename) as f:
+            context = {
+                "model_type": model_type,
+                "num_hidden_layers": num_hidden_layers,
+                "tp_rank": tp_rank,
+                "tp_size": tp_size,
+            }
+            schema_dct = json.load(f)
+            schema = QuantParamSchema.model_validate(schema_dct,
+                                                     context=context)
+            layer_scales_map = schema.kv_cache.scaling_factor[tp_rank]
+            return layer_scales_map.items()
+
+    except FileNotFoundError:
+        logger.error(f"File or directory '{filename}' not found.")
+    except json.JSONDecodeError:
+        logger.error(f"Error decoding JSON in file '{filename}'.")
+    except Exception as e:
+        logger.error(f"An error occurred while reading '{filename}': {e}")
+    # This section is reached if and only if any of the excepts are hit
+    # Return an empty iterable (list) => no KV cache scales are loaded
+    # which ultimately defaults to 1.0 scales
+    logger.warning("Defaulting to KV cache scaling factors = 1.0 "
+                   f"for all layers in TP rank {tp_rank} "
+                   "as an error occurred during loading.")
+    return []
+
+
 def convert_pyslice_to_tensor(x: Any) -> torch.Tensor:
     """convert PySafeSlice object from safetensors to torch.Tensor
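A usage sketch for the loader above; the file path and layer count are placeholders. Note that on any load failure the function returns an empty iterable, so every layer silently falls back to the default scale of 1.0:

from vllm.model_executor.weight_utils import kv_cache_scales_loader

scales = dict(
    kv_cache_scales_loader(
        "kv_cache_scales.json",  # hypothetical path
        tp_rank=0,
        tp_size=1,
        num_hidden_layers=32,
        model_type="llama"))
for layer_idx in range(32):
    kv_scale = scales.get(layer_idx, 1.0)  # fall back to the 1.0 default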
diff --git a/vllm/utils.py b/vllm/utils.py
index 3b229f11..380ffe76 100644
--- a/vllm/utils.py
+++ b/vllm/utils.py
@@ -25,7 +25,7 @@ STR_DTYPE_TO_TORCH_DTYPE = {
     "half": torch.half,
     "bfloat16": torch.bfloat16,
     "float": torch.float,
-    "fp8_e5m2": torch.uint8,
+    "fp8": torch.uint8,
 }
 
 
@@ -266,7 +266,7 @@ def get_nvcc_cuda_version() -> Optional[Version]:
     return nvcc_cuda_version
 
 
-def _generate_random_fp8_e5m2(
+def _generate_random_fp8(
     tensor: torch.tensor,
     low: float,
     high: float,
@@ -282,7 +282,7 @@ def _generate_random_fp8_e5m2(
     from vllm._C import cache_ops
     tensor_tmp = torch.empty_like(tensor, dtype=torch.float16)
     tensor_tmp.uniform_(low, high)
-    cache_ops.convert_fp8_e5m2(tensor_tmp, tensor)
+    cache_ops.convert_fp8(tensor_tmp, tensor)
     del tensor_tmp
 
 
@@ -311,7 +311,7 @@ def create_kv_caches_with_random(
         raise ValueError(f"Invalid model dtype: {model_dtype}")
     elif cache_dtype in ["half", "bfloat16", "float"]:
         torch_dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_dtype]
-    elif cache_dtype == "fp8_e5m2":
+    elif cache_dtype == "fp8":
         torch_dtype = torch.uint8
     else:
         raise ValueError(f"Invalid kv cache dtype: {cache_dtype}")
@@ -328,10 +328,10 @@ def create_kv_caches_with_random(
         key_cache = torch.empty(size=key_cache_shape,
                                 dtype=torch_dtype,
                                 device=device)
-        if cache_dtype == 'fp8_e5m2':
-            _generate_random_fp8_e5m2(key_cache, -scale, scale)
-        elif torch_dtype in [torch.half, torch.bfloat16, torch.float]:
+        if cache_dtype in ["auto", "half", "bfloat16", "float"]:
             key_cache.uniform_(-scale, scale)
+        elif cache_dtype == 'fp8':
+            _generate_random_fp8(key_cache, -scale, scale)
         else:
             raise ValueError(
                 f"Does not support key cache of type {cache_dtype}")
@@ -343,10 +343,10 @@ def create_kv_caches_with_random(
         value_cache = torch.empty(size=value_cache_shape,
                                   dtype=torch_dtype,
                                   device=device)
-        if cache_dtype == 'fp8_e5m2':
-            _generate_random_fp8_e5m2(value_cache, -scale, scale)
-        elif torch_dtype in [torch.half, torch.bfloat16, torch.float]:
+        if cache_dtype in ["auto", "half", "bfloat16", "float"]:
             value_cache.uniform_(-scale, scale)
+        elif cache_dtype == 'fp8':
+            _generate_random_fp8(value_cache, -scale, scale)
         else:
             raise ValueError(
                 f"Does not support value cache of type {cache_dtype}")
diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py
index 31fa5247..86ca6f9c 100644
--- a/vllm/worker/model_runner.py
+++ b/vllm/worker/model_runner.py
@@ -23,7 +23,7 @@ from vllm.model_executor.parallel_utils.parallel_state import (
 from vllm.sampling_params import SamplingParams, SamplingType
 from vllm.sequence import (MultiModalData, SamplerOutput, SequenceData,
                            SequenceGroupMetadata)
-from vllm.utils import (CudaMemoryProfiler, async_tensor_h2d,
+from vllm.utils import (CudaMemoryProfiler, async_tensor_h2d, is_hip,
                         is_pin_memory_available, make_tensor_with_pad,
                         maybe_expand_dim)
 
@@ -120,6 +120,26 @@ class ModelRunner:
                 self.model.embedding_padding_modules)
             self.model = self.lora_manager.create_lora_manager(self.model)
 
+        if self.kv_cache_dtype == "fp8" and is_hip():
+            # Currently scaled KV cache is only enabled on ROCm
+            if self.model_config.quantization_param_path is not None:
+                if callable(getattr(self.model, "load_kv_cache_scales", None)):
+                    self.model.load_kv_cache_scales(
+                        self.model_config.quantization_param_path)
+                else:
+                    raise RuntimeError("Using FP8 KV cache and scaling "
+                                       "factors provided but model "
+                                       f"{self.model.__class__} does not "
+                                       "support loading scaling factors.")
+            else:
+                logger.warn("Using FP8 KV cache but no scaling factors "
+                            "provided. Defaulting to scaling factors of 1.0. "
+                            "This may lead to less accurate results!")
+        elif self.model_config.quantization_param_path is not None:
+            logger.warn("KV cache scaling factors provided, "
+                        "but the KV cache data type is not FP8. "
+                        "KV cache scaling factors will not be used.")
+
     def set_block_size(self, block_size: int) -> None:
         self.block_size = block_size
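Condensed into a standalone sketch, the dispatch that this ModelRunner change implements looks roughly like the following. The names are stand-ins for the corresponding config fields, not vLLM API:

def maybe_load_kv_cache_scales(model, kv_cache_dtype: str,
                               quantization_param_path, on_rocm: bool) -> None:
    # Scaled FP8 KV cache is currently only wired up on ROCm (AMD GPU).
    if kv_cache_dtype == "fp8" and on_rocm:
        if quantization_param_path is None:
            # No scales supplied: every attention layer keeps kv_scale = 1.0,
            # possibly at the cost of accuracy.
            return
        if not callable(getattr(model, "load_kv_cache_scales", None)):
            raise RuntimeError("FP8 KV cache scaling factors were provided, "
                               "but the model cannot load them.")
        model.load_kv_cache_scales(quantization_param_path)
    elif quantization_param_path is not None:
        # Scales supplied but the KV cache dtype is not FP8: they go unused.
        pass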