From 9090bf02e74334a8020b454814e0d00fa780fd79 Mon Sep 17 00:00:00 2001 From: zhaoyang-star Date: Mon, 29 Jan 2024 08:43:54 +0800 Subject: [PATCH] Support FP8-E5M2 KV Cache (#2279) Co-authored-by: zhaoyang Co-authored-by: Zhuohan Li --- benchmarks/benchmark_latency.py | 8 + benchmarks/benchmark_throughput.py | 12 +- .../kernels/benchmark_paged_attention.py | 33 ++- csrc/attention/attention_dtypes.h | 1 + csrc/attention/attention_kernels.cu | 259 ++++++++++------ csrc/attention/dtype_fp8_e5m2.cuh | 35 +++ csrc/cache.h | 8 +- csrc/cache_kernels.cu | 135 +++++++-- csrc/dispatch_utils.h | 10 + csrc/ops.h | 6 +- csrc/pybind.cpp | 4 + .../fp8_e5m2_kvcache/quant_utils.cuh | 278 ++++++++++++++++++ .../source/quantization/fp8_e5m2_kv_cache.rst | 32 ++ setup.py | 3 + tests/kernels/conftest.py | 41 +-- tests/kernels/test_attention.py | 41 ++- tests/kernels/test_cache.py | 10 +- vllm/config.py | 29 +- vllm/engine/arg_utils.py | 11 +- vllm/engine/llm_engine.py | 6 + vllm/model_executor/input_metadata.py | 6 +- vllm/model_executor/layers/attention.py | 3 + vllm/utils.py | 110 ++++++- vllm/worker/cache_engine.py | 15 +- vllm/worker/model_runner.py | 7 + vllm/worker/worker.py | 5 +- 26 files changed, 912 insertions(+), 196 deletions(-) create mode 100644 csrc/attention/dtype_fp8_e5m2.cuh create mode 100644 csrc/quantization/fp8_e5m2_kvcache/quant_utils.cuh create mode 100644 docs/source/quantization/fp8_e5m2_kv_cache.rst diff --git a/benchmarks/benchmark_latency.py b/benchmarks/benchmark_latency.py index d75d690c..71731343 100644 --- a/benchmarks/benchmark_latency.py +++ b/benchmarks/benchmark_latency.py @@ -24,6 +24,7 @@ def main(args: argparse.Namespace): trust_remote_code=args.trust_remote_code, dtype=args.dtype, enforce_eager=args.enforce_eager, + kv_cache_dtype=args.kv_cache_dtype, ) sampling_params = SamplingParams( @@ -117,6 +118,13 @@ if __name__ == '__main__': parser.add_argument('--enforce-eager', action='store_true', help='enforce eager mode and disable CUDA graph') + parser.add_argument( + "--kv-cache-dtype", + type=str, + choices=['auto', 'fp8_e5m2'], + default='auto', + help= + 'Data type for kv cache storage. If "auto", will use model data type.') parser.add_argument( '--profile', action='store_true', diff --git a/benchmarks/benchmark_throughput.py b/benchmarks/benchmark_throughput.py index 3aac479c..d45d3330 100644 --- a/benchmarks/benchmark_throughput.py +++ b/benchmarks/benchmark_throughput.py @@ -71,6 +71,7 @@ def run_vllm( dtype: str, max_model_len: Optional[int], enforce_eager: bool, + kv_cache_dtype: str, ) -> float: from vllm import LLM, SamplingParams llm = LLM( @@ -83,6 +84,7 @@ def run_vllm( dtype=dtype, max_model_len=max_model_len, enforce_eager=enforce_eager, + kv_cache_dtype=kv_cache_dtype, ) # Add the requests to the engine. @@ -206,7 +208,8 @@ def main(args: argparse.Namespace): args.quantization, args.tensor_parallel_size, args.seed, args.n, args.use_beam_search, args.trust_remote_code, args.dtype, - args.max_model_len, args.enforce_eager) + args.max_model_len, args.enforce_eager, + args.kv_cache_dtype) elif args.backend == "hf": assert args.tensor_parallel_size == 1 elapsed_time = run_hf(requests, args.model, tokenizer, args.n, @@ -284,6 +287,13 @@ if __name__ == "__main__": parser.add_argument("--enforce-eager", action="store_true", help="enforce eager execution") + parser.add_argument( + "--kv-cache-dtype", + type=str, + choices=["auto", "fp8_e5m2"], + default="auto", + help= + 'Data type for kv cache storage. 
If "auto", will use model data type.') args = parser.parse_args() if args.tokenizer is None: args.tokenizer = args.model diff --git a/benchmarks/kernels/benchmark_paged_attention.py b/benchmarks/kernels/benchmark_paged_attention.py index 935393e9..56fe1b92 100644 --- a/benchmarks/kernels/benchmark_paged_attention.py +++ b/benchmarks/kernels/benchmark_paged_attention.py @@ -1,9 +1,11 @@ +from typing import Optional import argparse import random import time import torch +from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, create_kv_caches_with_random from vllm._C import ops NUM_BLOCKS = 1024 @@ -23,6 +25,7 @@ def main( dtype: torch.dtype, seed: int, do_profile: bool, + kv_cache_dtype: Optional[str] = None, ) -> None: random.seed(seed) torch.random.manual_seed(seed) @@ -59,15 +62,10 @@ def main( block_tables = torch.tensor(block_tables, dtype=torch.int, device="cuda") # Create the KV cache. - x = 16 // torch.tensor([], dtype=dtype).element_size() - key_cache_shape = (NUM_BLOCKS, num_kv_heads, head_size // x, block_size, x) - key_cache = torch.empty(size=key_cache_shape, dtype=dtype, device="cuda") - key_cache.uniform_(-scale, scale) - value_cache_shape = (NUM_BLOCKS, num_kv_heads, head_size, block_size) - value_cache = torch.empty(size=value_cache_shape, - dtype=dtype, - device="cuda") - value_cache.uniform_(-scale, scale) + key_caches, value_caches = create_kv_caches_with_random( + NUM_BLOCKS, block_size, 1, num_kv_heads, head_size, kv_cache_dtype, + dtype) + key_cache, value_cache = key_caches[0], value_caches[0] # Prepare for the paged attention kernel. output = torch.empty_like(query) @@ -106,6 +104,7 @@ def main( block_size, max_context_len, alibi_slopes, + kv_cache_dtype, ) elif version == "v2": ops.paged_attention_v2( @@ -123,6 +122,7 @@ def main( block_size, max_context_len, alibi_slopes, + kv_cache_dtype, ) else: raise ValueError(f"Invalid version: {version}") @@ -168,16 +168,18 @@ if __name__ == '__main__': default="half") parser.add_argument("--seed", type=int, default=0) parser.add_argument("--profile", action="store_true") + parser.add_argument( + "--kv-cache-dtype", + type=str, + choices=["auto", "fp8_e5m2"], + default="auto", + help= + 'Data type for kv cache storage. 
If "auto", will use model data type.') args = parser.parse_args() print(args) if args.num_query_heads % args.num_kv_heads != 0: raise ValueError("num_query_heads must be divisible by num_kv_heads") - dtype_to_torch_dtype = { - "half": torch.half, - "bfloat16": torch.bfloat16, - "float": torch.float, - } main( version=args.version, num_seqs=args.batch_size, @@ -187,7 +189,8 @@ if __name__ == '__main__': head_size=args.head_size, block_size=args.block_size, use_alibi=args.use_alibi, - dtype=dtype_to_torch_dtype[args.dtype], + dtype=STR_DTYPE_TO_TORCH_DTYPE[args.dtype], seed=args.seed, do_profile=args.profile, + kv_cache_dtype=args.kv_cache_dtype, ) diff --git a/csrc/attention/attention_dtypes.h b/csrc/attention/attention_dtypes.h index 88b4edde..61748e6b 100644 --- a/csrc/attention/attention_dtypes.h +++ b/csrc/attention/attention_dtypes.h @@ -4,3 +4,4 @@ #include "dtype_float16.cuh" #include "dtype_float32.cuh" #include "dtype_bfloat16.cuh" +#include "dtype_fp8_e5m2.cuh" diff --git a/csrc/attention/attention_kernels.cu b/csrc/attention/attention_kernels.cu index 9dcacfbe..a5ddeac7 100644 --- a/csrc/attention/attention_kernels.cu +++ b/csrc/attention/attention_kernels.cu @@ -25,6 +25,7 @@ #include "attention_dtypes.h" #include "attention_utils.cuh" +#include "../quantization/fp8_e5m2_kvcache/quant_utils.cuh" #include @@ -79,17 +80,19 @@ inline __device__ float block_sum(float* red_smem, float sum) { // Grid: (num_heads, num_seqs, max_num_partitions). template< typename scalar_t, + typename cache_t, int HEAD_SIZE, int BLOCK_SIZE, int NUM_THREADS, + bool IS_FP8_E5M2_KV_CACHE, int PARTITION_SIZE = 0> // Zero means no partitioning. __device__ void paged_attention_kernel( float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions] float* __restrict__ max_logits, // [num_seqs, num_heads, max_num_partitions] scalar_t* __restrict__ out, // [num_seqs, num_heads, max_num_partitions, head_size] const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] - const scalar_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, head_size/x, block_size, x] - const scalar_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, head_size, block_size] + const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, head_size/x, block_size, x] + const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, head_size, block_size] const int num_kv_heads, // [num_heads] const float scale, const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] @@ -145,6 +148,9 @@ __device__ void paged_attention_kernel( constexpr int VEC_SIZE = MAX(16 / (THREAD_GROUP_SIZE * sizeof(scalar_t)), 1); using K_vec = typename Vec::Type; using Q_vec = typename Vec::Type; +#ifdef ENABLE_FP8_E5M2 + using Quant_vec = typename Vec::Type; +#endif constexpr int NUM_ELEMS_PER_THREAD = HEAD_SIZE / THREAD_GROUP_SIZE; constexpr int NUM_VECS_PER_THREAD = NUM_ELEMS_PER_THREAD / VEC_SIZE; @@ -176,7 +182,7 @@ __device__ void paged_attention_kernel( // x == THREAD_GROUP_SIZE * VEC_SIZE // Each thread group fetches x elements from the key at a time. - constexpr int x = 16 / sizeof(scalar_t); + constexpr int x = 16 / sizeof(cache_t); float qk_max = -FLT_MAX; // Iterate over the key blocks. 
@@ -202,13 +208,23 @@ __device__ void paged_attention_kernel( #pragma unroll for (int j = 0; j < NUM_VECS_PER_THREAD; j++) { - const scalar_t* k_ptr = k_cache + physical_block_number * kv_block_stride - + kv_head_idx * kv_head_stride - + physical_block_offset * x; + const cache_t* k_ptr = k_cache + physical_block_number * kv_block_stride + + kv_head_idx * kv_head_stride + + physical_block_offset * x; const int vec_idx = thread_group_offset + j * THREAD_GROUP_SIZE; const int offset1 = (vec_idx * VEC_SIZE) / x; const int offset2 = (vec_idx * VEC_SIZE) % x; - k_vecs[j] = *reinterpret_cast(k_ptr + offset1 * BLOCK_SIZE * x + offset2); + if constexpr (IS_FP8_E5M2_KV_CACHE) { +#ifdef ENABLE_FP8_E5M2 + Quant_vec k_vec_quant = *reinterpret_cast(k_ptr + offset1 * BLOCK_SIZE * x + offset2); + // Vector conversion from Quant_vec to K_vec. + k_vecs[j] = fp8_e5m2_unscaled::vec_conversion(k_vec_quant); +#else + assert(false); +#endif + } else { + k_vecs[j] = *reinterpret_cast(k_ptr + offset1 * BLOCK_SIZE * x + offset2); + } } // Compute dot product. @@ -282,6 +298,9 @@ __device__ void paged_attention_kernel( constexpr int V_VEC_SIZE = MIN(16 / sizeof(scalar_t), BLOCK_SIZE); using V_vec = typename Vec::Type; using L_vec = typename Vec::Type; +#ifdef ENABLE_FP8_E5M2 + using V_quant_vec = typename Vec::Type; +#endif using Float_L_vec = typename FloatVec::Type; constexpr int NUM_V_VECS_PER_ROW = BLOCK_SIZE / V_VEC_SIZE; @@ -307,14 +326,25 @@ __device__ void paged_attention_kernel( L_vec logits_vec; from_float(logits_vec, *reinterpret_cast(logits + token_idx - start_token_idx)); - const scalar_t* v_ptr = v_cache + physical_block_number * kv_block_stride - + kv_head_idx * kv_head_stride; + const cache_t* v_ptr = v_cache + physical_block_number * kv_block_stride + + kv_head_idx * kv_head_stride; #pragma unroll for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER; if (row_idx < HEAD_SIZE) { const int offset = row_idx * BLOCK_SIZE + physical_block_offset; - V_vec v_vec = *reinterpret_cast(v_ptr + offset); + V_vec v_vec; + if constexpr (IS_FP8_E5M2_KV_CACHE) { +#ifdef ENABLE_FP8_E5M2 + V_quant_vec v_quant_vec = *reinterpret_cast(v_ptr + offset); + // Vector conversion from V_quant_vec to V_vec. + v_vec = fp8_e5m2_unscaled::vec_conversion(v_quant_vec); +#else + assert(false); +#endif + } else { + v_vec = *reinterpret_cast(v_ptr + offset); + } if (block_idx == num_context_blocks - 1) { // NOTE(woosuk): When v_vec contains the tokens that are out of the context, // we should explicitly zero out the values since they may contain NaNs. @@ -395,14 +425,16 @@ __device__ void paged_attention_kernel( // Grid: (num_heads, num_seqs, 1). 
template< typename scalar_t, + typename cache_t, int HEAD_SIZE, int BLOCK_SIZE, - int NUM_THREADS> + int NUM_THREADS, + bool IS_FP8_E5M2_KV_CACHE> __global__ void paged_attention_v1_kernel( scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size] const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] - const scalar_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, head_size/x, block_size, x] - const scalar_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, head_size, block_size] + const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, head_size/x, block_size, x] + const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, head_size, block_size] const int num_kv_heads, // [num_heads] const float scale, const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] @@ -412,7 +444,7 @@ __global__ void paged_attention_v1_kernel( const int q_stride, const int kv_block_stride, const int kv_head_stride) { - paged_attention_kernel( + paged_attention_kernel( /* exp_sums */ nullptr, /* max_logits */ nullptr, out, q, k_cache, v_cache, num_kv_heads, scale, block_tables, context_lens, max_num_blocks_per_seq, alibi_slopes, q_stride, kv_block_stride, kv_head_stride); @@ -421,17 +453,19 @@ __global__ void paged_attention_v1_kernel( // Grid: (num_heads, num_seqs, max_num_partitions). template< typename scalar_t, + typename cache_t, int HEAD_SIZE, int BLOCK_SIZE, int NUM_THREADS, + bool IS_FP8_E5M2_KV_CACHE, int PARTITION_SIZE> __global__ void paged_attention_v2_kernel( float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions] float* __restrict__ max_logits, // [num_seqs, num_heads, max_num_partitions] scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads, max_num_partitions, head_size] const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] - const scalar_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, head_size/x, block_size, x] - const scalar_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, head_size, block_size] + const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, head_size/x, block_size, x] + const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, head_size, block_size] const int num_kv_heads, // [num_heads] const float scale, const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] @@ -441,7 +475,7 @@ __global__ void paged_attention_v2_kernel( const int q_stride, const int kv_block_stride, const int kv_head_stride) { - paged_attention_kernel( + paged_attention_kernel( exp_sums, max_logits, tmp_out, q, k_cache, v_cache, num_kv_heads, scale, block_tables, context_lens, max_num_blocks_per_seq, alibi_slopes, q_stride, kv_block_stride, kv_head_stride); @@ -550,10 +584,10 @@ __global__ void paged_attention_v2_reduce_kernel( #define LAUNCH_PAGED_ATTENTION_V1(HEAD_SIZE) \ VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize( \ - ((void*)vllm::paged_attention_v1_kernel), \ - shared_mem_size); \ - vllm::paged_attention_v1_kernel \ - <<>>( \ + ((void*)vllm::paged_attention_v1_kernel), shared_mem_size); \ + vllm::paged_attention_v1_kernel<<>>( \ out_ptr, \ query_ptr, \ key_cache_ptr, \ @@ -571,7 +605,9 @@ __global__ void paged_attention_v2_reduce_kernel( // TODO(woosuk): Tune NUM_THREADS. 
template< typename T, + typename CACHE_T, int BLOCK_SIZE, + bool IS_FP8_E5M2_KV_CACHE, int NUM_THREADS = 128> void paged_attention_v1_launcher( torch::Tensor& out, @@ -602,8 +638,8 @@ void paged_attention_v1_launcher( T* out_ptr = reinterpret_cast(out.data_ptr()); T* query_ptr = reinterpret_cast(query.data_ptr()); - T* key_cache_ptr = reinterpret_cast(key_cache.data_ptr()); - T* value_cache_ptr = reinterpret_cast(value_cache.data_ptr()); + CACHE_T* key_cache_ptr = reinterpret_cast(key_cache.data_ptr()); + CACHE_T* value_cache_ptr = reinterpret_cast(value_cache.data_ptr()); int* block_tables_ptr = block_tables.data_ptr(); int* context_lens_ptr = context_lens.data_ptr(); @@ -647,35 +683,35 @@ void paged_attention_v1_launcher( } } -#define CALL_V1_LAUNCHER(T, BLOCK_SIZE) \ - paged_attention_v1_launcher( \ - out, \ - query, \ - key_cache, \ - value_cache, \ - num_kv_heads, \ - scale, \ - block_tables, \ - context_lens, \ - max_context_len, \ +#define CALL_V1_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_E5M2_KV_CACHE) \ + paged_attention_v1_launcher( \ + out, \ + query, \ + key_cache, \ + value_cache, \ + num_kv_heads, \ + scale, \ + block_tables, \ + context_lens, \ + max_context_len, \ alibi_slopes); // NOTE(woosuk): To reduce the compilation time, we omitted block sizes // 1, 2, 4, 64, 128, 256. -#define CALL_V1_LAUNCHER_BLOCK_SIZE(T) \ - switch (block_size) { \ - case 8: \ - CALL_V1_LAUNCHER(T, 8); \ - break; \ - case 16: \ - CALL_V1_LAUNCHER(T, 16); \ - break; \ - case 32: \ - CALL_V1_LAUNCHER(T, 32); \ - break; \ - default: \ - TORCH_CHECK(false, "Unsupported block size: ", block_size); \ - break; \ +#define CALL_V1_LAUNCHER_BLOCK_SIZE(T, CACHE_T, IS_FP8_E5M2_KV_CACHE) \ + switch (block_size) { \ + case 8: \ + CALL_V1_LAUNCHER(T, CACHE_T, 8, IS_FP8_E5M2_KV_CACHE); \ + break; \ + case 16: \ + CALL_V1_LAUNCHER(T, CACHE_T, 16, IS_FP8_E5M2_KV_CACHE); \ + break; \ + case 32: \ + CALL_V1_LAUNCHER(T, CACHE_T, 32, IS_FP8_E5M2_KV_CACHE); \ + break; \ + default: \ + TORCH_CHECK(false, "Unsupported block size: ", block_size); \ + break; \ } void paged_attention_v1( @@ -689,20 +725,36 @@ void paged_attention_v1( torch::Tensor& context_lens, // [num_seqs] int block_size, int max_context_len, - const c10::optional& alibi_slopes) { - if (query.dtype() == at::ScalarType::Float) { - CALL_V1_LAUNCHER_BLOCK_SIZE(float); - } else if (query.dtype() == at::ScalarType::Half) { - CALL_V1_LAUNCHER_BLOCK_SIZE(uint16_t); - } else if (query.dtype() == at::ScalarType::BFloat16) { - CALL_V1_LAUNCHER_BLOCK_SIZE(__nv_bfloat16); + const c10::optional& alibi_slopes, + const std::string& kv_cache_dtype) { + if (kv_cache_dtype == "auto") { + if (query.dtype() == at::ScalarType::Float) { + CALL_V1_LAUNCHER_BLOCK_SIZE(float, float, false); + } else if (query.dtype() == at::ScalarType::Half) { + CALL_V1_LAUNCHER_BLOCK_SIZE(uint16_t, uint16_t, false); + } else if (query.dtype() == at::ScalarType::BFloat16) { + CALL_V1_LAUNCHER_BLOCK_SIZE(__nv_bfloat16, __nv_bfloat16, false); + } else { + TORCH_CHECK(false, "Unsupported data type: ", query.dtype()); + } + } else if (kv_cache_dtype == "fp8_e5m2") { + if (query.dtype() == at::ScalarType::Float) { + CALL_V1_LAUNCHER_BLOCK_SIZE(float, uint8_t, true); + } else if (query.dtype() == at::ScalarType::Half) { + CALL_V1_LAUNCHER_BLOCK_SIZE(uint16_t, uint8_t, true); + } else if (query.dtype() == at::ScalarType::BFloat16) { + CALL_V1_LAUNCHER_BLOCK_SIZE(__nv_bfloat16, uint8_t, true); + } else { + TORCH_CHECK(false, "Unsupported data type: ", query.dtype()); + } } else { - TORCH_CHECK(false, 
"Unsupported data type: ", query.dtype()); + TORCH_CHECK(false, "Unsupported data type of kv cache: ", kv_cache_dtype); } } #define LAUNCH_PAGED_ATTENTION_V2(HEAD_SIZE) \ - vllm::paged_attention_v2_kernel \ + vllm::paged_attention_v2_kernel \ <<>>( \ exp_sums_ptr, \ max_logits_ptr, \ @@ -730,7 +782,9 @@ void paged_attention_v1( template< typename T, + typename CACHE_T, int BLOCK_SIZE, + bool IS_FP8_E5M2_KV_CACHE, int NUM_THREADS = 128, int PARTITION_SIZE = 512> void paged_attention_v2_launcher( @@ -768,8 +822,8 @@ void paged_attention_v2_launcher( float* max_logits_ptr = reinterpret_cast(max_logits.data_ptr()); T* tmp_out_ptr = reinterpret_cast(tmp_out.data_ptr()); T* query_ptr = reinterpret_cast(query.data_ptr()); - T* key_cache_ptr = reinterpret_cast(key_cache.data_ptr()); - T* value_cache_ptr = reinterpret_cast(value_cache.data_ptr()); + CACHE_T* key_cache_ptr = reinterpret_cast(key_cache.data_ptr()); + CACHE_T* value_cache_ptr = reinterpret_cast(value_cache.data_ptr()); int* block_tables_ptr = block_tables.data_ptr(); int* context_lens_ptr = context_lens.data_ptr(); @@ -816,38 +870,38 @@ void paged_attention_v2_launcher( } } -#define CALL_V2_LAUNCHER(T, BLOCK_SIZE) \ - paged_attention_v2_launcher( \ - out, \ - exp_sums, \ - max_logits, \ - tmp_out, \ - query, \ - key_cache, \ - value_cache, \ - num_kv_heads, \ - scale, \ - block_tables, \ - context_lens, \ - max_context_len, \ +#define CALL_V2_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_E5M2_KV_CACHE) \ + paged_attention_v2_launcher( \ + out, \ + exp_sums, \ + max_logits, \ + tmp_out, \ + query, \ + key_cache, \ + value_cache, \ + num_kv_heads, \ + scale, \ + block_tables, \ + context_lens, \ + max_context_len, \ alibi_slopes); // NOTE(woosuk): To reduce the compilation time, we omitted block sizes // 1, 2, 4, 64, 128, 256. 
-#define CALL_V2_LAUNCHER_BLOCK_SIZE(T) \ - switch (block_size) { \ - case 8: \ - CALL_V2_LAUNCHER(T, 8); \ - break; \ - case 16: \ - CALL_V2_LAUNCHER(T, 16); \ - break; \ - case 32: \ - CALL_V2_LAUNCHER(T, 32); \ - break; \ - default: \ - TORCH_CHECK(false, "Unsupported block size: ", block_size); \ - break; \ +#define CALL_V2_LAUNCHER_BLOCK_SIZE(T, CACHE_T, IS_FP8_E5M2_KV_CACHE) \ + switch (block_size) { \ + case 8: \ + CALL_V2_LAUNCHER(T, CACHE_T, 8, IS_FP8_E5M2_KV_CACHE); \ + break; \ + case 16: \ + CALL_V2_LAUNCHER(T, CACHE_T, 16, IS_FP8_E5M2_KV_CACHE); \ + break; \ + case 32: \ + CALL_V2_LAUNCHER(T, CACHE_T, 32, IS_FP8_E5M2_KV_CACHE); \ + break; \ + default: \ + TORCH_CHECK(false, "Unsupported block size: ", block_size); \ + break; \ } void paged_attention_v2( @@ -864,15 +918,30 @@ void paged_attention_v2( torch::Tensor& context_lens, // [num_seqs] int block_size, int max_context_len, - const c10::optional& alibi_slopes) { - if (query.dtype() == at::ScalarType::Float) { - CALL_V2_LAUNCHER_BLOCK_SIZE(float); - } else if (query.dtype() == at::ScalarType::Half) { - CALL_V2_LAUNCHER_BLOCK_SIZE(uint16_t); - } else if (query.dtype() == at::ScalarType::BFloat16) { - CALL_V2_LAUNCHER_BLOCK_SIZE(__nv_bfloat16); + const c10::optional& alibi_slopes, + const std::string& kv_cache_dtype) { + if (kv_cache_dtype == "auto") { + if (query.dtype() == at::ScalarType::Float) { + CALL_V2_LAUNCHER_BLOCK_SIZE(float, float, false); + } else if (query.dtype() == at::ScalarType::Half) { + CALL_V2_LAUNCHER_BLOCK_SIZE(uint16_t, uint16_t, false); + } else if (query.dtype() == at::ScalarType::BFloat16) { + CALL_V2_LAUNCHER_BLOCK_SIZE(__nv_bfloat16, __nv_bfloat16, false); + } else { + TORCH_CHECK(false, "Unsupported data type: ", query.dtype()); + } + } else if (kv_cache_dtype == "fp8_e5m2") { + if (query.dtype() == at::ScalarType::Float) { + CALL_V2_LAUNCHER_BLOCK_SIZE(float, uint8_t, true); + } else if (query.dtype() == at::ScalarType::Half) { + CALL_V2_LAUNCHER_BLOCK_SIZE(uint16_t, uint8_t, true); + } else if (query.dtype() == at::ScalarType::BFloat16) { + CALL_V2_LAUNCHER_BLOCK_SIZE(__nv_bfloat16, uint8_t, true); + } else { + TORCH_CHECK(false, "Unsupported data type: ", query.dtype()); + } } else { - TORCH_CHECK(false, "Unsupported data type: ", query.dtype()); + TORCH_CHECK(false, "Unsupported data type of kv cache: ", kv_cache_dtype); } } diff --git a/csrc/attention/dtype_fp8_e5m2.cuh b/csrc/attention/dtype_fp8_e5m2.cuh new file mode 100644 index 00000000..0580fbb8 --- /dev/null +++ b/csrc/attention/dtype_fp8_e5m2.cuh @@ -0,0 +1,35 @@ +#pragma once + +#include "attention_generic.cuh" + +#include +#ifdef ENABLE_FP8_E5M2 +#include +#endif + +namespace vllm { +#ifdef ENABLE_FP8_E5M2 +// fp8 vector types for quantization of kv cache + +template<> +struct Vec { + using Type = uint8_t; +}; + +template<> +struct Vec { + using Type = uint16_t; +}; + +template<> +struct Vec { + using Type = uint32_t; +}; + +template<> +struct Vec { + using Type = uint2; +}; +#endif // ENABLE_FP8_E5M2 + +} // namespace vllm diff --git a/csrc/cache.h b/csrc/cache.h index b26faad2..21c71830 100644 --- a/csrc/cache.h +++ b/csrc/cache.h @@ -20,7 +20,8 @@ void reshape_and_cache( torch::Tensor& value, torch::Tensor& key_cache, torch::Tensor& value_cache, - torch::Tensor& slot_mapping); + torch::Tensor& slot_mapping, + const std::string& kv_cache_dtype); void gather_cached_kv( torch::Tensor& key, @@ -28,3 +29,8 @@ void gather_cached_kv( torch::Tensor& key_cache, torch::Tensor& value_cache, torch::Tensor& slot_mapping); + +// Just for 
unittest +void convert_fp8_e5m2( + torch::Tensor& src_cache, + torch::Tensor& dst_cache); diff --git a/csrc/cache_kernels.cu b/csrc/cache_kernels.cu index b7523cb4..fe0159e4 100644 --- a/csrc/cache_kernels.cu +++ b/csrc/cache_kernels.cu @@ -4,6 +4,7 @@ #include "cuda_compat.h" #include "dispatch_utils.h" +#include "quantization/fp8_e5m2_kvcache/quant_utils.cuh" #include #include @@ -131,7 +132,7 @@ void copy_blocks( dim3 block(std::min(1024, numel_per_block)); const at::cuda::OptionalCUDAGuard device_guard(cache_device); const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - VLLM_DISPATCH_FLOATING_TYPES( + VLLM_DISPATCH_FLOATING_AND_BYTE_TYPES( key_caches[0].scalar_type(), "copy_blocks_kernel", ([&] { vllm::copy_blocks_kernel<<>>( key_cache_ptrs_tensor.data_ptr(), @@ -143,12 +144,12 @@ void copy_blocks( namespace vllm { -template +template __global__ void reshape_and_cache_kernel( const scalar_t* __restrict__ key, // [num_tokens, num_heads, head_size] const scalar_t* __restrict__ value, // [num_tokens, num_heads, head_size] - scalar_t* __restrict__ key_cache, // [num_blocks, num_heads, head_size/x, block_size, x] - scalar_t* __restrict__ value_cache, // [num_blocks, num_heads, head_size, block_size] + cache_t* __restrict__ key_cache, // [num_blocks, num_heads, head_size/x, block_size, x] + cache_t* __restrict__ value_cache, // [num_blocks, num_heads, head_size, block_size] const int64_t* __restrict__ slot_mapping, // [num_tokens] const int key_stride, const int value_stride, @@ -185,19 +186,45 @@ __global__ void reshape_and_cache_kernel( + head_idx * head_size * block_size + head_offset * block_size + block_offset; - key_cache[tgt_key_idx] = key[src_key_idx]; - value_cache[tgt_value_idx] = value[src_value_idx]; + scalar_t tgt_key = key[src_key_idx]; + scalar_t tgt_value = value[src_value_idx]; + if constexpr (is_fp8_e5m2_kv_cache) { +#ifdef ENABLE_FP8_E5M2 + key_cache[tgt_key_idx] = fp8_e5m2_unscaled::vec_conversion(tgt_key); + value_cache[tgt_value_idx] = fp8_e5m2_unscaled::vec_conversion(tgt_value); +#else + assert(false); +#endif + } else { + key_cache[tgt_key_idx] = tgt_key; + value_cache[tgt_value_idx] = tgt_value; + } } } } // namespace vllm +#define CALL_RESHAPE_AND_CACHE(KV_T, CACHE_T, IS_FP8_E5M2_KV_CACHE) \ + vllm::reshape_and_cache_kernel<<>>( \ + reinterpret_cast(key.data_ptr()), \ + reinterpret_cast(value.data_ptr()), \ + reinterpret_cast(key_cache.data_ptr()), \ + reinterpret_cast(value_cache.data_ptr()), \ + slot_mapping.data_ptr(), \ + key_stride, \ + value_stride, \ + num_heads, \ + head_size, \ + block_size, \ + x); + void reshape_and_cache( torch::Tensor& key, // [num_tokens, num_heads, head_size] torch::Tensor& value, // [num_tokens, num_heads, head_size] torch::Tensor& key_cache, // [num_blocks, num_heads, head_size/x, block_size, x] torch::Tensor& value_cache, // [num_blocks, num_heads, head_size, block_size] - torch::Tensor& slot_mapping) // [num_tokens] + torch::Tensor& slot_mapping, // [num_tokens] + const std::string& kv_cache_dtype) { int num_tokens = key.size(0); int num_heads = key.size(1); @@ -212,23 +239,25 @@ void reshape_and_cache( dim3 block(std::min(num_heads * head_size, 512)); const at::cuda::OptionalCUDAGuard device_guard(device_of(key)); const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - VLLM_DISPATCH_FLOATING_TYPES( - key.scalar_type(), - "reshape_and_cache_kernel", - [&] { - vllm::reshape_and_cache_kernel<<>>( - key.data_ptr(), - value.data_ptr(), - key_cache.data_ptr(), - value_cache.data_ptr(), - slot_mapping.data_ptr(), - 
key_stride, - value_stride, - num_heads, - head_size, - block_size, - x); - }); + if (kv_cache_dtype == "auto") { + if (key.dtype() == at::ScalarType::Float) { + CALL_RESHAPE_AND_CACHE(float, float, false); + } else if (key.dtype() == at::ScalarType::Half) { + CALL_RESHAPE_AND_CACHE(uint16_t, uint16_t, false); + } else if (key.dtype() == at::ScalarType::BFloat16) { + CALL_RESHAPE_AND_CACHE(__nv_bfloat16, __nv_bfloat16, false); + } + } else if (kv_cache_dtype == "fp8_e5m2") { + if (key.dtype() == at::ScalarType::Float) { + CALL_RESHAPE_AND_CACHE(float, uint8_t, true); + } else if (key.dtype() == at::ScalarType::Half) { + CALL_RESHAPE_AND_CACHE(uint16_t, uint8_t, true); + } else if (key.dtype() == at::ScalarType::BFloat16) { + CALL_RESHAPE_AND_CACHE(__nv_bfloat16, uint8_t, true); + } + } else { + TORCH_CHECK(false, "Unsupported data type of kv cache: ", kv_cache_dtype); + } } namespace vllm { @@ -256,12 +285,12 @@ __global__ void gather_cached_kv_kernel( for (int i = threadIdx.x; i < num_tokens; i += blockDim.x) { const int tgt_key_idx = token_idx * key_stride + i; const int tgt_value_idx = token_idx * value_stride + i; - + const int head_idx = i / head_size; const int head_offset = i % head_size; const int x_idx = head_offset / x; // the offset of the [head_size/x] dimension const int x_offset = head_offset % x; - + const int src_key_idx = block_idx * num_heads * (head_size / x) * block_size * x + head_idx * (head_size / x) * block_size * x + x_idx * block_size * x @@ -373,7 +402,7 @@ void gather_cached_kv( dim3 block(std::min(num_heads * head_size, 512)); const at::cuda::OptionalCUDAGuard device_guard(device_of(key)); const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - VLLM_DISPATCH_FLOATING_TYPES( + VLLM_DISPATCH_FLOATING_AND_BYTE_TYPES( key.scalar_type(), "gather_cached_kv_kernel_optimized", [&] { @@ -391,3 +420,55 @@ void gather_cached_kv( x); }); } + +namespace vllm { + +template +__global__ void convert_fp8_e5m2_kernel( + const Tin* __restrict__ src_cache, + Tout* __restrict__ dst_cache, + const int64_t block_stride) { + const int64_t block_idx = blockIdx.x; + for (int i = threadIdx.x; i < block_stride; i += blockDim.x) { + int64_t idx = block_idx * block_stride + i; +#ifdef ENABLE_FP8_E5M2 + dst_cache[idx] = fp8_e5m2_unscaled::vec_conversion(src_cache[idx]); +#else + assert(false); +#endif + } +} + +} // namespace vllm + +#define CALL_CONVERT_FP8_E5M2(Tout, Tin) \ + vllm::convert_fp8_e5m2_kernel<<>>( \ + reinterpret_cast(src_cache.data_ptr()), \ + reinterpret_cast(dst_cache.data_ptr()), \ + block_stride); + +void convert_fp8_e5m2( + torch::Tensor& src_cache, + torch::Tensor& dst_cache) +{ + int64_t num_blocks = src_cache.size(0); + int64_t block_stride = src_cache.stride(0); + + dim3 grid(num_blocks); + dim3 block(std::min(block_stride, int64_t(512))); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + if (src_cache.dtype() == at::ScalarType::Float) { + CALL_CONVERT_FP8_E5M2(uint8_t, float); + } else if (src_cache.dtype() == at::ScalarType::Half) { + CALL_CONVERT_FP8_E5M2(uint8_t, uint16_t); + } else if (src_cache.dtype() == at::ScalarType::BFloat16) { + CALL_CONVERT_FP8_E5M2(uint8_t, __nv_bfloat16); + } else if (dst_cache.dtype() == at::ScalarType::Float) { + CALL_CONVERT_FP8_E5M2(float, uint8_t); + } else if (dst_cache.dtype() == at::ScalarType::Half) { + CALL_CONVERT_FP8_E5M2(uint16_t, uint8_t); + } else if (dst_cache.dtype() == at::ScalarType::BFloat16) { + CALL_CONVERT_FP8_E5M2(__nv_bfloat16, uint8_t); + } +} diff --git a/csrc/dispatch_utils.h 
b/csrc/dispatch_utils.h index 0ae9cd64..85fdfc09 100644 --- a/csrc/dispatch_utils.h +++ b/csrc/dispatch_utils.h @@ -14,3 +14,13 @@ #define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \ AT_DISPATCH_SWITCH( \ TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__)) + +#define VLLM_DISPATCH_CASE_FLOATING_AND_BYTE_TYPES(...) \ + AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__) + +#define VLLM_DISPATCH_FLOATING_AND_BYTE_TYPES(TYPE, NAME, ...) \ + AT_DISPATCH_SWITCH( \ + TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_AND_BYTE_TYPES(__VA_ARGS__)) diff --git a/csrc/ops.h b/csrc/ops.h index 6e996fd0..ce77dd47 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -13,7 +13,8 @@ void paged_attention_v1( torch::Tensor& context_lens, int block_size, int max_context_len, - const c10::optional& alibi_slopes); + const c10::optional& alibi_slopes, + const std::string& kv_cache_dtype); void paged_attention_v2( torch::Tensor& out, @@ -29,7 +30,8 @@ void paged_attention_v2( torch::Tensor& context_lens, int block_size, int max_context_len, - const c10::optional& alibi_slopes); + const c10::optional& alibi_slopes, + const std::string& kv_cache_dtype); void rms_norm( torch::Tensor& out, diff --git a/csrc/pybind.cpp b/csrc/pybind.cpp index f94efadf..db2da8f0 100644 --- a/csrc/pybind.cpp +++ b/csrc/pybind.cpp @@ -75,6 +75,10 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { "gather_cached_kv", &gather_cached_kv, "Gather key and value from the cache into contiguous QKV tensors"); + cache_ops.def( + "convert_fp8_e5m2", + &convert_fp8_e5m2, + "Convert the key and value cache to fp8_e5m2 data type"); // Cuda utils pybind11::module cuda_utils = m.def_submodule("cuda_utils", "vLLM cuda utils"); diff --git a/csrc/quantization/fp8_e5m2_kvcache/quant_utils.cuh b/csrc/quantization/fp8_e5m2_kvcache/quant_utils.cuh new file mode 100644 index 00000000..c3b0d311 --- /dev/null +++ b/csrc/quantization/fp8_e5m2_kvcache/quant_utils.cuh @@ -0,0 +1,278 @@ +#pragma once + +#include +#include +#include +#include +#include "../../attention/attention_dtypes.h" +#include "../../attention/dtype_float32.cuh" +#include "../../attention/dtype_float16.cuh" +#include "../../attention/dtype_bfloat16.cuh" + +#pragma once + +namespace vllm { +#ifdef ENABLE_FP8_E5M2 +namespace fp8_e5m2_unscaled { + +template +__inline__ __device__ Tout vec_conversion(const Tin& x) +{ + return x; +} + +// fp8 -> half +template<> +__inline__ __device__ uint16_t vec_conversion(const uint8_t& a) +{ + __half_raw res = __nv_cvt_fp8_to_halfraw(a, __NV_E5M2); + return res.x; +} + +// fp8x2 -> half2 +template<> +__inline__ __device__ uint32_t vec_conversion(const uint16_t& a) +{ + union { + uint16_t u16[2]; + uint32_t u32; + } tmp; + __half2_raw res = __nv_cvt_fp8x2_to_halfraw2(a, __NV_E5M2); + tmp.u16[0] = res.x; + tmp.u16[1] = res.y; + return tmp.u32; +} + +// fp8x4 -> half2x2 +template<> +__inline__ __device__ uint2 vec_conversion(const uint32_t& a) +{ + union { + uint2 u32x2; + uint32_t u32[2]; + } tmp; + tmp.u32[0] = vec_conversion((uint16_t)a); + tmp.u32[1] = vec_conversion((uint16_t)(a >> 16U)); + return tmp.u32x2; +} + +// fp8x8 -> half2x4 +template<> +__inline__ __device__ uint4 vec_conversion(const uint2& a) +{ + union { + uint4 u64x2; + uint2 u64[2]; + } tmp; + tmp.u64[0] = vec_conversion(a.x); + tmp.u64[1] = vec_conversion(a.y); + return tmp.u64x2; +} + +// fp8 -> __nv_bfloat16 +template<> +__inline__ 
__device__ __nv_bfloat16 vec_conversion<__nv_bfloat16, uint8_t>(const uint8_t& a) +{ + // Note there is no direct convert function from fp8 to bf16. + // fp8 -> half + __half_raw res = __nv_cvt_fp8_to_halfraw(a, __NV_E5M2); + // half -> float -> bf16 + float tmp = half_to_float(res.x); + return __float2bfloat16(tmp); +} + +// fp8x2 -> __nv_bfloat162 +template<> +__inline__ __device__ __nv_bfloat162 vec_conversion<__nv_bfloat162, uint16_t>(const uint16_t& a) +{ + __nv_bfloat162 res; + res.x = vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)a); + res.y = vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)(a >> 8U)); + return res; +} + +// fp8x4 -> bf16_4_t +template<> +__inline__ __device__ bf16_4_t vec_conversion(const uint32_t& a) +{ + bf16_4_t res; + res.x = vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)a); + res.y = vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)(a >> 16U)); + return res; +} + +// fp8x8 -> bf16_8_t +template<> +__inline__ __device__ bf16_8_t vec_conversion(const uint2& a) +{ + bf16_4_t tmp1, tmp2; + tmp1 = vec_conversion(a.x); + tmp2 = vec_conversion(a.y); + bf16_8_t res; + res.x = tmp1.x; + res.y = tmp1.y; + res.z = tmp2.x; + res.w = tmp2.y; + return res; +} + +// fp8 -> float +template<> +__inline__ __device__ float vec_conversion(const uint8_t& a) +{ + // fp8 -> half + uint16_t tmp = vec_conversion(a); + // half -> float + return half_to_float(tmp); +} + +// fp8x2 -> float2 +template<> +__inline__ __device__ float2 vec_conversion(const uint16_t& a) +{ + // fp8x2 -> half2 + uint32_t tmp = vec_conversion(a); + // half2 -> float2 + return half2_to_float2(tmp); +} + +// fp8x4 -> float4 +template<> +__inline__ __device__ Float4_ vec_conversion(const uint32_t& a) +{ + Float4_ res; + res.x = vec_conversion((uint16_t)a); + res.y = vec_conversion((uint16_t)(a >> 16U)); + return res; +} + +// fp8x8 -> float8 +template<> +__inline__ __device__ Float8_ vec_conversion(const uint2& a) +{ + Float4_ tmp1, tmp2; + tmp1 = vec_conversion(a.x); + tmp2 = vec_conversion(a.y); + Float8_ res; + res.x = tmp1.x; + res.y = tmp1.y; + res.z = tmp2.x; + res.w = tmp2.y; + return res; +} + + +// half -> fp8 +template<> +__inline__ __device__ uint8_t vec_conversion(const uint16_t& a) +{ + __half_raw tmp; + tmp.x = a; + __nv_fp8_storage_t res = __nv_cvt_halfraw_to_fp8(tmp, __NV_SATFINITE, __NV_E5M2); + return (uint8_t)res; +} + +// bf16 -> fp8 +template<> +__inline__ __device__ uint8_t vec_conversion(const __nv_bfloat16& a) +{ +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + assert(false); +#else + __nv_fp8_storage_t res = __nv_cvt_bfloat16raw_to_fp8(__nv_bfloat16_raw(a), __NV_SATFINITE, __NV_E5M2); + return (uint8_t)res; +#endif +} + +// float -> fp8 +template<> +__inline__ __device__ uint8_t vec_conversion(const float& a) +{ + __nv_fp8_storage_t res = __nv_cvt_float_to_fp8(a, __NV_SATFINITE, __NV_E5M2); + return (uint8_t)res; +} + +// fp8x4 -> float4 +template<> +__inline__ __device__ float4 vec_conversion(const uint32_t& a) +{ + Float4_ tmp = vec_conversion(a); + float4 res = make_float4(tmp.x.x, tmp.x.y, tmp.y.x, tmp.y.y); + return res; +} + + +template<> +__inline__ __device__ uint32_t vec_conversion(const float2& a) +{ + union { + half2 float16; + uint32_t uint32; + }; + + float16 = __float22half2_rn(a); + return uint32; +} + +template<> +__inline__ __device__ uint2 vec_conversion(const Float4_& a) +{ + uint2 b; + float2 val; + val.x = a.x.x; + val.y = a.x.y; + b.x = vec_conversion(val); + + val.x = a.y.x; + val.y = a.y.y; + b.y = vec_conversion(val); + + return b; +} + +template<> 
+__inline__ __device__ float4 vec_conversion<float4, Float4_>(const Float4_& a)
+{
+    float4 b;
+    b.x = a.x.x;
+    b.y = a.x.y;
+    b.z = a.y.x;
+    b.w = a.y.y;
+    return b;
+}
+
+template<>
+__inline__ __device__ uint4 vec_conversion<uint4, Float8_>(const Float8_& a)
+{
+    uint4 b;
+    b.x = vec_conversion<uint32_t, float2>(a.x);
+    b.y = vec_conversion<uint32_t, float2>(a.y);
+    b.z = vec_conversion<uint32_t, float2>(a.z);
+    b.w = vec_conversion<uint32_t, float2>(a.w);
+    return b;
+}
+
+template<>
+__inline__ __device__ __nv_bfloat162 vec_conversion<__nv_bfloat162, float2>(const float2 &a) {
+    __nv_bfloat162 b;
+    from_float(b, a);
+    return b;
+}
+
+template<>
+__inline__ __device__ bf16_4_t vec_conversion<bf16_4_t, Float4_>(const Float4_ &a) {
+    bf16_4_t b;
+    from_float(b, a);
+    return b;
+}
+
+template<>
+__inline__ __device__ bf16_8_t vec_conversion<bf16_8_t, Float8_>(const Float8_ &a) {
+    bf16_8_t b;
+    from_float(b, a);
+    return b;
+}
+
+} // namespace fp8_e5m2_unscaled
+#endif // ENABLE_FP8_E5M2
+} // namespace vllm
diff --git a/docs/source/quantization/fp8_e5m2_kv_cache.rst b/docs/source/quantization/fp8_e5m2_kv_cache.rst
new file mode 100644
index 00000000..10437260
--- /dev/null
+++ b/docs/source/quantization/fp8_e5m2_kv_cache.rst
@@ -0,0 +1,32 @@
+.. _fp8_e5m2_kv_cache:
+
+FP8 E5M2 KV Cache
+==================
+
+Int8/int4 quantization schemes require additional GPU memory to store the scaling factors, which reduces the expected memory savings.
+The FP8 data format retains 2~3 mantissa bits and can be converted to and from float/fp16/bfloat16.
+
+Here is an example of how to enable this feature:
+
+.. code-block:: python
+    from vllm import LLM, SamplingParams
+    # Sample prompts.
+    prompts = [
+        "Hello, my name is",
+        "The president of the United States is",
+        "The capital of France is",
+        "The future of AI is",
+    ]
+    # Create a sampling params object.
+    sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
+    # Create an LLM.
+    llm = LLM(model="facebook/opt-125m", kv_cache_dtype="fp8_e5m2")
+    # Generate texts from the prompts. The output is a list of RequestOutput objects
+    # that contain the prompt, generated text, and other information.
+    outputs = llm.generate(prompts, sampling_params)
+    # Print the outputs.
+ for output in outputs: + prompt = output.prompt + generated_text = output.outputs[0].text + print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") + diff --git a/setup.py b/setup.py index 2f624269..b552cb67 100644 --- a/setup.py +++ b/setup.py @@ -253,6 +253,9 @@ if _is_cuda(): num_threads = min(os.cpu_count(), nvcc_threads) NVCC_FLAGS += ["--threads", str(num_threads)] + if nvcc_cuda_version >= Version("11.8"): + NVCC_FLAGS += ["-DENABLE_FP8_E5M2"] + # changes for punica kernels NVCC_FLAGS += torch_cpp_ext.COMMON_NVCC_FLAGS REMOVE_NVCC_FLAGS = [ diff --git a/tests/kernels/conftest.py b/tests/kernels/conftest.py index fca97ab7..8c51bfc1 100644 --- a/tests/kernels/conftest.py +++ b/tests/kernels/conftest.py @@ -1,44 +1,7 @@ -from typing import List, Tuple - import pytest -import torch - - -def create_kv_caches( - num_blocks: int, - block_size: int, - num_layers: int, - num_heads: int, - head_size: int, - dtype: torch.dtype, - seed: int, - device: str, -) -> Tuple[List[torch.Tensor], List[torch.Tensor]]: - torch.random.manual_seed(seed) - torch.cuda.manual_seed(seed) - - scale = head_size**-0.5 - x = 16 // torch.tensor([], dtype=dtype).element_size() - key_cache_shape = (num_blocks, num_heads, head_size // x, block_size, x) - key_caches = [] - for _ in range(num_layers): - key_cache = torch.empty(size=key_cache_shape, - dtype=dtype, - device=device) - key_cache.uniform_(-scale, scale) - key_caches.append(key_cache) - - value_cache_shape = (num_blocks, num_heads, head_size, block_size) - value_caches = [] - for _ in range(num_layers): - value_cache = torch.empty(size=value_cache_shape, - dtype=dtype, - device=device) - value_cache.uniform_(-scale, scale) - value_caches.append(value_cache) - return key_caches, value_caches +from vllm.utils import create_kv_caches_with_random @pytest.fixture() def kv_cache_factory(): - return create_kv_caches + return create_kv_caches_with_random diff --git a/tests/kernels/test_attention.py b/tests/kernels/test_attention.py index 3949948e..cbb1d406 100644 --- a/tests/kernels/test_attention.py +++ b/tests/kernels/test_attention.py @@ -6,14 +6,16 @@ import torch from xformers import ops as xops from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask -from vllm._C import ops +from vllm._C import ops, cache_ops from vllm.utils import get_max_shared_memory_bytes FLOAT32_BYTES = torch.finfo(torch.float).bits // 8 # This will change depending on the compute capability. # - 512 as a buffer MAX_SEQ_LEN = get_max_shared_memory_bytes() // FLOAT32_BYTES - 512 -NUM_BLOCKS = 12000 # Arbitrary values for testing +# There may not be enough gpu memory due to large NUM_BLOCKS. +# Reduce NUM_BLOCKS when it happens. 
+NUM_BLOCKS = 4321 # Arbitrary values for testing PARTITION_SIZE = 512 DTYPES = [torch.half, torch.bfloat16, torch.float] @@ -23,6 +25,7 @@ NUM_HEADS = [(40, 40), (64, 8)] # Arbitrary values for testing HEAD_SIZES = [64, 80, 96, 112, 128, 256] BLOCK_SIZES = [16, 32] USE_ALIBI = [False, True] +KV_CACHE_DTYPE = ["auto", "fp8_e5m2"] SEEDS = [0] DEVICES = [i for i in range(1 if torch.cuda.device_count() == 1 else 2)] @@ -105,6 +108,7 @@ def ref_single_query_cached_kv_attention( @pytest.mark.parametrize("use_alibi", USE_ALIBI) @pytest.mark.parametrize("block_size", BLOCK_SIZES) @pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE) @pytest.mark.parametrize("seed", SEEDS) @pytest.mark.parametrize("device", DEVICES) def test_paged_attention( @@ -116,6 +120,7 @@ def test_paged_attention( use_alibi: bool, block_size: int, dtype: torch.dtype, + kv_cache_dtype: str, seed: int, device: int, ) -> None: @@ -158,8 +163,9 @@ def test_paged_attention( # Create the KV caches. key_caches, value_caches = kv_cache_factory(NUM_BLOCKS, block_size, 1, - num_kv_heads, head_size, dtype, - seed, gpu_id) + num_kv_heads, head_size, + kv_cache_dtype, dtype, seed, + gpu_id) key_cache, value_cache = key_caches[0], value_caches[0] # Call the paged attention kernel. @@ -177,6 +183,7 @@ def test_paged_attention( block_size, max_context_len, alibi_slopes, + kv_cache_dtype, ) elif version == "v2": num_partitions = ((max_context_len + PARTITION_SIZE - 1) // @@ -209,11 +216,30 @@ def test_paged_attention( block_size, max_context_len, alibi_slopes, + kv_cache_dtype, ) else: raise AssertionError(f"Unknown version: {version}") # Run the reference implementation. + if kv_cache_dtype == "fp8_e5m2": + # Convert cache data back to dtype. + x = 16 // torch.tensor([], dtype=dtype).element_size() + key_cache_shape = (NUM_BLOCKS, num_kv_heads, head_size // x, + block_size, x) + dequantized_key_cache = torch.empty(size=key_cache_shape, + dtype=dtype, + device=gpu_id) + cache_ops.convert_fp8_e5m2(key_cache, dequantized_key_cache) + key_cache = dequantized_key_cache + + value_cache_shape = value_cache.shape + dequantized_value_cache = torch.empty(size=value_cache_shape, + dtype=dtype, + device=gpu_id) + cache_ops.convert_fp8_e5m2(value_cache, dequantized_value_cache) + value_cache = dequantized_value_cache + ref_output = torch.empty_like(query) ref_single_query_cached_kv_attention( ref_output, @@ -230,7 +256,12 @@ def test_paged_attention( # NOTE(woosuk): Due to the kernel-level differences in the two # implementations, there is a small numerical difference in the two # outputs. Thus, we use a relaxed tolerance for the test. - assert torch.allclose(output, ref_output, atol=1e-3, rtol=1e-5) + # NOTE(zhaoyang): FP8 KV Cache will introduce quantization error, + # so we use a relaxed tolerance for the test. 
+ atol, rtol = 1e-3, 1e-5 + if kv_cache_dtype == "fp8_e5m2": + atol, rtol = 1e-2, 1e-5 + assert torch.allclose(output, ref_output, atol=atol, rtol=rtol) def ref_multi_query_kv_attention( diff --git a/tests/kernels/test_cache.py b/tests/kernels/test_cache.py index 7b1cc058..193bc29b 100644 --- a/tests/kernels/test_cache.py +++ b/tests/kernels/test_cache.py @@ -15,6 +15,7 @@ NUM_BLOCKS = [1024, 3600] # Arbitrary values for testing NUM_MAPPINGS = [256] # Arbitrary values for testing SEEDS = [0] DEVICES = [i for i in range(1 if torch.cuda.device_count() == 1 else 2)] +KV_CACHE_DTYPE = ["auto", "fp8_e5m2"] @pytest.mark.parametrize("num_mappings", NUM_MAPPINGS) @@ -26,6 +27,7 @@ DEVICES = [i for i in range(1 if torch.cuda.device_count() == 1 else 2)] @pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("seed", SEEDS) @pytest.mark.parametrize("device", DEVICES) +@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE) @torch.inference_mode() def test_copy_blocks( kv_cache_factory, @@ -38,6 +40,7 @@ def test_copy_blocks( dtype: torch.dtype, seed: int, device: int, + kv_cache_dtype: str, ) -> None: random.seed(seed) torch.random.manual_seed(seed) @@ -59,7 +62,8 @@ def test_copy_blocks( # Create the KV caches. key_caches, value_caches = kv_cache_factory(num_blocks, block_size, num_layers, num_heads, - head_size, dtype, seed, gpu_id) + head_size, kv_cache_dtype, + dtype, seed, gpu_id) # Clone the KV caches. cloned_key_caches = [key_cache.clone() for key_cache in key_caches] @@ -124,7 +128,7 @@ def test_reshape_and_cache( # Create the KV caches. key_caches, value_caches = kv_cache_factory(num_blocks, block_size, 1, num_heads, head_size, dtype, - seed, gpu_id) + None, seed, gpu_id) key_cache, value_cache = key_caches[0], value_caches[0] # Clone the KV caches. @@ -133,7 +137,7 @@ def test_reshape_and_cache( # Call the reshape_and_cache kernel. cache_ops.reshape_and_cache(key, value, key_cache, value_cache, - slot_mapping) + slot_mapping, "auto") # Run the reference implementation. reshaped_key = key.reshape(num_tokens, *key_cache[0, :, :, 0, :].shape) diff --git a/vllm/config.py b/vllm/config.py index da97eaa7..197f20c1 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -1,13 +1,14 @@ from typing import Optional, Union, ClassVar from dataclasses import dataclass import os +from packaging.version import Version import torch from transformers import PretrainedConfig from vllm.logger import init_logger from vllm.transformers_utils.config import get_config -from vllm.utils import get_cpu_memory, is_hip +from vllm.utils import get_cpu_memory, is_hip, get_nvcc_cuda_version logger = init_logger(__name__) @@ -275,6 +276,7 @@ class CacheConfig: gpu_memory_utilization: Fraction of GPU memory to use for the vLLM execution. swap_space: Size of the CPU swap space per GPU (in GiB). + cache_dtype: Data type for kv cache storage. """ def __init__( @@ -282,13 +284,16 @@ class CacheConfig: block_size: int, gpu_memory_utilization: float, swap_space: int, + cache_dtype: str, sliding_window: Optional[int] = None, ) -> None: self.block_size = block_size self.gpu_memory_utilization = gpu_memory_utilization self.swap_space_bytes = swap_space * _GB + self.cache_dtype = cache_dtype self.sliding_window = sliding_window self._verify_args() + self._verify_cache_dtype() # Will be set after profiling. self.num_gpu_blocks = None @@ -300,6 +305,28 @@ class CacheConfig: "GPU memory utilization must be less than 1.0. 
Got " f"{self.gpu_memory_utilization}.") + def _verify_cache_dtype(self) -> None: + if self.cache_dtype == "auto": + pass + elif self.cache_dtype == "fp8_e5m2": + nvcc_cuda_version = get_nvcc_cuda_version() + if nvcc_cuda_version < Version("11.8"): + raise ValueError( + "FP8 is not supported when cuda version is lower than 11.8." + ) + device_name = torch.cuda.get_device_name() + if "AMD" in device_name: + raise NotImplementedError( + "FP8_E5M2 KV Cache on AMD GPU has not been supported yet.") + logger.info( + "Using fp8_e5m2 data type to store kv cache. It reduces " + "the GPU memory footprint and boosts the performance. " + "But it may cause slight accuracy drop. " + "Currently we only support fp8 without scaling factors and " + "make e5m2 as a default format.") + else: + raise ValueError(f"Unknown kv cache dtype: {self.cache_dtype}") + def verify_with_parallel_config( self, parallel_config: "ParallelConfig", diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 968362c4..231ce332 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -17,6 +17,7 @@ class EngineArgs: download_dir: Optional[str] = None load_format: str = 'auto' dtype: str = 'auto' + kv_cache_dtype: str = 'auto' seed: int = 0 max_model_len: Optional[int] = None worker_use_ray: bool = False @@ -122,6 +123,14 @@ class EngineArgs: 'The "auto" option will use FP16 precision ' 'for FP32 and FP16 models, and BF16 precision ' 'for BF16 models.') + parser.add_argument( + '--kv-cache-dtype', + type=str, + choices=['auto', 'fp8_e5m2'], + default='auto', + help='Data type for kv cache storage. If "auto", will use model ' + 'data type. Note FP8 is not supported when cuda version is ' + 'lower than 11.8.') parser.add_argument('--max-model-len', type=int, default=None, @@ -269,7 +278,7 @@ class EngineArgs: self.max_context_len_to_capture) cache_config = CacheConfig(self.block_size, self.gpu_memory_utilization, - self.swap_space, + self.swap_space, self.kv_cache_dtype, model_config.get_sliding_window()) parallel_config = ParallelConfig(self.pipeline_parallel_size, self.tensor_parallel_size, diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 87752eea..5b73ef08 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -85,6 +85,7 @@ class LLMEngine: f"disable_custom_all_reduce={parallel_config.disable_custom_all_reduce}, " f"quantization={model_config.quantization}, " f"enforce_eager={model_config.enforce_eager}, " + f"kv_cache_dtype={cache_config.cache_dtype}, " f"seed={model_config.seed})") # TODO(woosuk): Print more configs in debug mode. 
@@ -144,6 +145,7 @@ class LLMEngine: rank=0, distributed_init_method=distributed_init_method, lora_config=self.lora_config, + kv_cache_dtype=self.cache_config.cache_dtype, is_driver_worker=True, ) self._run_workers("init_model") @@ -234,6 +236,7 @@ class LLMEngine: model_config = copy.deepcopy(self.model_config) parallel_config = copy.deepcopy(self.parallel_config) scheduler_config = copy.deepcopy(self.scheduler_config) + cache_config = copy.deepcopy(self.cache_config) for rank, (worker, (node_id, _)) in enumerate(zip(self.workers, @@ -249,6 +252,7 @@ class LLMEngine: rank, distributed_init_method, lora_config=self.lora_config, + cache_config=cache_config, )) driver_rank = 0 @@ -261,6 +265,7 @@ class LLMEngine: driver_rank, distributed_init_method, lora_config=self.lora_config, + cache_config=cache_config, is_driver_worker=True, ) @@ -306,6 +311,7 @@ class LLMEngine: block_size=self.cache_config.block_size, gpu_memory_utilization=self.cache_config.gpu_memory_utilization, cpu_swap_space=self.cache_config.swap_space_bytes, + cache_dtype=self.cache_config.cache_dtype, ) # Since we use a shared centralized controller, we take the minimum diff --git a/vllm/model_executor/input_metadata.py b/vllm/model_executor/input_metadata.py index ef49cc59..f0a88ac8 100644 --- a/vllm/model_executor/input_metadata.py +++ b/vllm/model_executor/input_metadata.py @@ -12,6 +12,7 @@ class InputMetadata: max_context_len: The maximum context length. context_lens: the length of attention context for each sequence. block_tables: The block tables. (Seq id -> list of physical block) + kv_cache_dtype: Data type to store kv cache. """ def __init__( @@ -25,6 +26,7 @@ class InputMetadata: context_lens: Optional[torch.Tensor], block_tables: Optional[torch.Tensor], use_cuda_graph: bool, + kv_cache_dtype: str, ) -> None: self.is_prompt = is_prompt self.prompt_lens = prompt_lens @@ -35,6 +37,7 @@ class InputMetadata: self.context_lens = context_lens self.block_tables = block_tables self.use_cuda_graph = use_cuda_graph + self.kv_cache_dtype = kv_cache_dtype # Set during the execution of the first attention op. # FIXME(woosuk): This is a hack. @@ -47,4 +50,5 @@ class InputMetadata: f"slot_mapping={self.slot_mapping}, " f"context_lens={self.context_lens}, " f"block_tables={self.block_tables}, " - f"use_cuda_graph={self.use_cuda_graph})") + f"use_cuda_graph={self.use_cuda_graph}, " + f"kv_cache_dtype={self.kv_cache_dtype})") diff --git a/vllm/model_executor/layers/attention.py b/vllm/model_executor/layers/attention.py index 8b5c6ab3..91ed43f0 100644 --- a/vllm/model_executor/layers/attention.py +++ b/vllm/model_executor/layers/attention.py @@ -98,6 +98,7 @@ class PagedAttention(nn.Module): key_cache, value_cache, input_metadata.slot_mapping.flatten(), + input_metadata.kv_cache_dtype, ) if input_metadata.is_prompt: @@ -265,6 +266,7 @@ def _paged_attention( block_size, input_metadata.max_context_len, alibi_slopes, + input_metadata.kv_cache_dtype, ) else: # Run PagedAttention V2. 
@@ -295,5 +297,6 @@ def _paged_attention( block_size, input_metadata.max_context_len, alibi_slopes, + input_metadata.kv_cache_dtype, ) return output diff --git a/vllm/utils.py b/vllm/utils.py index 6a9508f6..dc817414 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -1,9 +1,11 @@ import enum import os import socket +import subprocess import uuid from platform import uname -from typing import List +from typing import List, Tuple, Union +from packaging.version import parse, Version import psutil import torch @@ -17,7 +19,17 @@ from typing import ( from collections import OrderedDict from typing import Any, Hashable, Optional +from vllm.logger import init_logger + T = TypeVar("T") +logger = init_logger(__name__) + +STR_DTYPE_TO_TORCH_DTYPE = { + "half": torch.half, + "bfloat16": torch.bfloat16, + "float": torch.float, + "fp8_e5m2": torch.uint8, +} class Device(enum.Enum): @@ -167,3 +179,99 @@ def get_open_port() -> int: def set_cuda_visible_devices(device_ids: List[int]) -> None: os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(map(str, device_ids)) + + +def get_nvcc_cuda_version() -> Version: + cuda_home = os.environ.get('CUDA_HOME') + if not cuda_home: + cuda_home = '/usr/local/cuda' + logger.info( + f'CUDA_HOME is not found in the environment. Using {cuda_home} as CUDA_HOME.' + ) + nvcc_output = subprocess.check_output([cuda_home + "/bin/nvcc", "-V"], + universal_newlines=True) + output = nvcc_output.split() + release_idx = output.index("release") + 1 + nvcc_cuda_version = parse(output[release_idx].split(",")[0]) + return nvcc_cuda_version + + +def _generate_random_fp8_e5m2( + tensor: torch.tensor, + low: float, + high: float, +) -> None: + # NOTE(zhaoyang): Due to NaN and Inf representation for fp8 data type, + # it may occur Inf or NaN if we directly use torch.randint + # to generate random data for fp8 data. + # For example, s.11111.00 in fp8e5m2 format repesents Inf. 
+ # | E4M3 | E5M2 + #-----|-------------|------------------- + # Inf | N/A | s.11111.00 + # NaN | s.1111.111 | s.11111.{01,10,11} + from vllm._C import cache_ops + tensor_tmp = torch.empty_like(tensor, dtype=torch.float16) + tensor_tmp.uniform_(low, high) + cache_ops.convert_fp8_e5m2(tensor_tmp, tensor) + del tensor_tmp + + +def create_kv_caches_with_random( + num_blocks: int, + block_size: int, + num_layers: int, + num_heads: int, + head_size: int, + cache_dtype: Optional[Union[str, torch.dtype]], + model_dtype: Optional[Union[str, torch.dtype]] = None, + seed: Optional[int] = 0, + device: Optional[str] = "cuda", +) -> Tuple[List[torch.Tensor], List[torch.Tensor]]: + torch.random.manual_seed(seed) + torch.cuda.manual_seed(seed) + + if isinstance(cache_dtype, str): + if cache_dtype == "auto": + if isinstance(model_dtype, str): + torch_dtype = STR_DTYPE_TO_TORCH_DTYPE[model_dtype] + elif isinstance(model_dtype, torch.dtype): + torch_dtype = model_dtype + else: + raise ValueError(f"Invalid model dtype: {model_dtype}") + elif cache_dtype in ["half", "bfloat16", "float"]: + torch_dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_dtype] + elif cache_dtype == "fp8_e5m2": + torch_dtype = torch.uint8 + else: + raise ValueError(f"Invalid kv cache dtype: {cache_dtype}") + elif isinstance(cache_dtype, torch.dtype): + torch_dtype = cache_dtype + else: + raise ValueError(f"Invalid kv cache dtype: {cache_dtype}") + + scale = head_size**-0.5 + x = 16 // torch.tensor([], dtype=torch_dtype).element_size() + key_cache_shape = (num_blocks, num_heads, head_size // x, block_size, x) + key_caches = [] + for _ in range(num_layers): + key_cache = torch.empty(size=key_cache_shape, + dtype=torch_dtype, + device=device) + if cache_dtype in ["auto", "half", "bfloat16", "float"]: + key_cache.uniform_(-scale, scale) + elif cache_dtype == 'fp8_e5m2': + _generate_random_fp8_e5m2(key_cache, -scale, scale) + key_caches.append(key_cache) + + value_cache_shape = (num_blocks, num_heads, head_size, block_size) + value_caches = [] + for _ in range(num_layers): + value_cache = torch.empty(size=value_cache_shape, + dtype=torch_dtype, + device=device) + if cache_dtype in ["auto", "half", "bfloat16", "float"]: + value_cache.uniform_(-scale, scale) + elif cache_dtype == 'fp8_e5m2': + _generate_random_fp8_e5m2(value_cache, -scale, scale) + value_caches.append(value_cache) + return key_caches, value_caches diff --git a/vllm/worker/cache_engine.py b/vllm/worker/cache_engine.py index 1dd0243f..f57e1ed7 100644 --- a/vllm/worker/cache_engine.py +++ b/vllm/worker/cache_engine.py @@ -6,7 +6,7 @@ import torch from vllm._C import cache_ops from vllm.config import CacheConfig, ModelConfig, ParallelConfig from vllm.logger import init_logger -from vllm.utils import in_wsl +from vllm.utils import in_wsl, STR_DTYPE_TO_TORCH_DTYPE logger = init_logger(__name__) @@ -34,12 +34,16 @@ class CacheEngine: self.head_size = model_config.get_head_size() self.num_layers = model_config.get_num_layers(parallel_config) self.num_heads = model_config.get_num_kv_heads(parallel_config) - self.dtype = model_config.dtype self.block_size = cache_config.block_size self.num_gpu_blocks = cache_config.num_gpu_blocks self.num_cpu_blocks = cache_config.num_cpu_blocks + if cache_config.cache_dtype == "auto": + self.dtype = model_config.dtype + else: + self.dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_config.cache_dtype] + # Initialize the cache. 
         self.gpu_cache = self.allocate_gpu_cache()
         self.cpu_cache = self.allocate_cpu_cache()
 
@@ -142,6 +146,7 @@ class CacheEngine:
     @staticmethod
     def get_cache_block_size(
         block_size: int,
+        cache_dtype: str,
         model_config: ModelConfig,
         parallel_config: ParallelConfig,
     ) -> int:
@@ -152,7 +157,11 @@ class CacheEngine:
         key_cache_block = block_size * num_heads * head_size
         value_cache_block = key_cache_block
         total = num_layers * (key_cache_block + value_cache_block)
-        dtype_size = _get_dtype_size(model_config.dtype)
+        if cache_dtype == "auto":
+            dtype = model_config.dtype
+        else:
+            dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_dtype]
+        dtype_size = _get_dtype_size(dtype)
         return dtype_size * total
 
 
diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py
index 60f5b71d..2a12152a 100644
--- a/vllm/worker/model_runner.py
+++ b/vllm/worker/model_runner.py
@@ -36,6 +36,7 @@ class ModelRunner:
         parallel_config: ParallelConfig,
         scheduler_config: SchedulerConfig,
         lora_config: Optional[LoRAConfig],
+        kv_cache_dtype: Optional[str] = "auto",
         is_driver_worker: bool = False,
     ):
         self.model_config = model_config
@@ -68,6 +69,7 @@ class ModelRunner:
         self.graph_block_tables = None  # Set after initial profiling.
         # cache in_wsl result
         self.in_wsl = in_wsl()
+        self.kv_cache_dtype = kv_cache_dtype
 
     def load_model(self) -> None:
         self.model = get_model(self.model_config, self.lora_config)
@@ -223,6 +225,7 @@ class ModelRunner:
             context_lens=context_lens_tensor,
             block_tables=block_tables,
             use_cuda_graph=False,
+            kv_cache_dtype=self.kv_cache_dtype,
         )
         return (input_tokens, input_positions, input_metadata, prompt_lens,
                 subquery_lens, lora_index_mapping, lora_prompt_mapping,
@@ -350,6 +353,7 @@ class ModelRunner:
             context_lens=context_lens,
             block_tables=block_tables,
             use_cuda_graph=use_captured_graph,
+            kv_cache_dtype=self.kv_cache_dtype,
         )
         return input_tokens, input_positions, input_metadata, lora_index_mapping, lora_prompt_mapping, lora_requests
 
@@ -473,6 +477,7 @@ class ModelRunner:
             "context_lens": input_metadata.context_lens,
             "block_tables": input_metadata.block_tables,
             "use_cuda_graph": input_metadata.use_cuda_graph,
+            "kv_cache_dtype": input_metadata.kv_cache_dtype,
             "selected_token_indices":
             sampling_metadata.selected_token_indices,
             "lora_requests": lora_requests,
@@ -495,6 +500,7 @@ class ModelRunner:
                 context_lens=metadata_dict["context_lens"],
                 block_tables=metadata_dict["block_tables"],
                 use_cuda_graph=metadata_dict["use_cuda_graph"],
+                kv_cache_dtype=metadata_dict["kv_cache_dtype"],
             )
             sampling_metadata = SamplingMetadata(
                 seq_groups=None,
@@ -665,6 +671,7 @@ class ModelRunner:
             context_lens=context_lens[:batch_size],
             block_tables=block_tables[:batch_size],
             use_cuda_graph=True,
+            kv_cache_dtype=self.kv_cache_dtype,
         )
 
         if self.lora_config:
diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py
index f1dad64b..a74adfa5 100644
--- a/vllm/worker/worker.py
+++ b/vllm/worker/worker.py
@@ -37,6 +37,7 @@ class Worker:
         rank: int,
         distributed_init_method: str,
         lora_config: Optional[LoRAConfig] = None,
+        kv_cache_dtype: Optional[str] = "auto",
        is_driver_worker: bool = False,
     ) -> None:
         self.model_config = model_config
@@ -54,6 +55,7 @@ class Worker:
                                         parallel_config,
                                         scheduler_config,
                                         lora_config=self.lora_config,
+                                        kv_cache_dtype=kv_cache_dtype,
                                         is_driver_worker=is_driver_worker)
         # Uninitialized cache engine. Will be initialized by
         # self.init_cache_engine().
@@ -95,6 +97,7 @@ class Worker:
         block_size: int,
         gpu_memory_utilization: float,
         cpu_swap_space: int,
+        cache_dtype: str,
     ) -> Tuple[int, int]:
         """Profiles the peak memory usage of the model and returns the
         maximum number of GPU and CPU cache blocks that can be allocated.
@@ -119,7 +122,7 @@ class Worker:
         peak_memory = total_gpu_memory - free_gpu_memory
 
         cache_block_size = CacheEngine.get_cache_block_size(
-            block_size, self.model_config, self.parallel_config)
+            block_size, cache_dtype, self.model_config, self.parallel_config)
         num_gpu_blocks = int(
             (total_gpu_memory * gpu_memory_utilization - peak_memory) //
             cache_block_size)
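
Note for reviewers (not part of the patch): the sketch below shows how the new create_kv_caches_with_random helper in vllm/utils.py might be exercised to allocate FP8-E5M2 caches for a kernel test. The sizes (1024 blocks, 8 heads, head_size 128, block_size 16) are arbitrary example values; a CUDA device and the compiled vllm._C extension are assumed, since fp8_e5m2 initialization goes through cache_ops.convert_fp8_e5m2.

    import torch
    from vllm.utils import create_kv_caches_with_random

    # fp8_e5m2 entries are stored as torch.uint8, so the key-cache packing
    # factor x = 16 // element_size() evaluates to 16 here.
    key_caches, value_caches = create_kv_caches_with_random(
        num_blocks=1024,
        block_size=16,
        num_layers=1,
        num_heads=8,
        head_size=128,
        cache_dtype="fp8_e5m2",
        model_dtype="half",  # only consulted when cache_dtype == "auto"
        seed=0,
        device="cuda",
    )
    key_cache, value_cache = key_caches[0], value_caches[0]
    assert key_cache.dtype == torch.uint8
    assert key_cache.shape == (1024, 8, 128 // 16, 16, 16)
    assert value_cache.shape == (1024, 8, 128, 16)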
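The cache_dtype argument threaded into CacheEngine.get_cache_block_size is what lets profile_num_available_blocks account for the smaller element size. A back-of-the-envelope check, mirroring the updated formula with hypothetical model dimensions (not taken from this patch):

    # Mirrors the logic of get_cache_block_size after this patch; the
    # dimensions below are made-up example values.
    def cache_block_size_bytes(block_size: int, num_heads: int, head_size: int,
                               num_layers: int, dtype_size: int) -> int:
        key_cache_block = block_size * num_heads * head_size
        value_cache_block = key_cache_block
        total = num_layers * (key_cache_block + value_cache_block)
        return dtype_size * total

    fp16_block = cache_block_size_bytes(16, 32, 128, 32, dtype_size=2)  # "auto" with a half-precision model
    fp8_block = cache_block_size_bytes(16, 32, 128, 32, dtype_size=1)   # "fp8_e5m2" stored as torch.uint8
    print(fp16_block, fp8_block)  # 8388608 vs. 4194304 bytes per cache block

Halving the per-block footprint means the same GPU memory budget in profile_num_available_blocks yields roughly twice as many KV cache blocks.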