diff --git a/benchmarks/kernels/benchmark_paged_attention.py b/benchmarks/kernels/benchmark_paged_attention.py index 14eef00b..219013a3 100644 --- a/benchmarks/kernels/benchmark_paged_attention.py +++ b/benchmarks/kernels/benchmark_paged_attention.py @@ -98,7 +98,9 @@ def main( start_time = time.perf_counter() # Using default kv_scale - k_scale = v_scale = 1.0 + k_scale = v_scale = torch.tensor(1.0, + dtype=torch.float32, + device=device) for _ in range(num_iters): if version == "v1": diff --git a/csrc/attention/attention_kernels.cuh b/csrc/attention/attention_kernels.cuh index 563e1438..eb216dc8 100644 --- a/csrc/attention/attention_kernels.cuh +++ b/csrc/attention/attention_kernels.cuh @@ -105,7 +105,7 @@ __device__ void paged_attention_kernel( const int max_num_blocks_per_seq, const float* __restrict__ alibi_slopes, // [num_heads] const int q_stride, const int kv_block_stride, const int kv_head_stride, - const float k_scale, const float v_scale, const int tp_rank, + const float* k_scale, const float* v_scale, const int tp_rank, const int blocksparse_local_blocks, const int blocksparse_vert_stride, const int blocksparse_block_size, const int blocksparse_head_sliding_step) { const int seq_idx = blockIdx.y; @@ -285,7 +285,7 @@ __device__ void paged_attention_kernel( Quant_vec k_vec_quant = *reinterpret_cast( k_ptr + offset1 * BLOCK_SIZE * x + offset2); k_vecs[j] = fp8::scaled_convert( - k_vec_quant, k_scale); + k_vec_quant, *k_scale); } } @@ -415,7 +415,7 @@ __device__ void paged_attention_kernel( *reinterpret_cast(v_ptr + offset); // Vector conversion from V_quant_vec to V_vec. v_vec = fp8::scaled_convert(v_quant_vec, - v_scale); + *v_scale); } if (block_idx == num_seq_blocks - 1) { // NOTE(woosuk): When v_vec contains the tokens that are out of the @@ -513,7 +513,7 @@ __global__ void paged_attention_v1_kernel( const int max_num_blocks_per_seq, const float* __restrict__ alibi_slopes, // [num_heads] const int q_stride, const int kv_block_stride, const int kv_head_stride, - const float k_scale, const float v_scale, const int tp_rank, + const float* k_scale, const float* v_scale, const int tp_rank, const int blocksparse_local_blocks, const int blocksparse_vert_stride, const int blocksparse_block_size, const int blocksparse_head_sliding_step) { paged_attention_kernel& alibi_slopes, float k_scale, - float v_scale, const int tp_rank, const int blocksparse_local_blocks, - const int blocksparse_vert_stride, const int blocksparse_block_size, - const int blocksparse_head_sliding_step) { + const std::optional& alibi_slopes, torch::Tensor& k_scale, + torch::Tensor& v_scale, const int tp_rank, + const int blocksparse_local_blocks, const int blocksparse_vert_stride, + const int blocksparse_block_size, const int blocksparse_head_sliding_step) { int num_seqs = query.size(0); int num_heads = query.size(1); int head_size = query.size(2); @@ -80,6 +80,8 @@ void paged_attention_v1_launcher( CACHE_T* value_cache_ptr = reinterpret_cast(value_cache.data_ptr()); int* block_tables_ptr = block_tables.data_ptr(); int* seq_lens_ptr = seq_lens.data_ptr(); + const float* k_scale_ptr = reinterpret_cast(k_scale.data_ptr()); + const float* v_scale_ptr = reinterpret_cast(v_scale.data_ptr()); constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; int padded_max_seq_len = @@ -177,8 +179,9 @@ void paged_attention_v1( torch::Tensor& seq_lens, // [num_seqs] int64_t block_size, int64_t max_seq_len, const std::optional& alibi_slopes, - const std::string& kv_cache_dtype, double k_scale, double v_scale, - const int64_t tp_rank, const int64_t blocksparse_local_blocks, + const std::string& kv_cache_dtype, torch::Tensor& k_scale, + torch::Tensor& v_scale, const int64_t tp_rank, + const int64_t blocksparse_local_blocks, const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size, const int64_t blocksparse_head_sliding_step) { const bool is_block_sparse = (blocksparse_vert_stride > 1); diff --git a/csrc/attention/paged_attention_v2.cu b/csrc/attention/paged_attention_v2.cu index a453b224..9935359e 100644 --- a/csrc/attention/paged_attention_v2.cu +++ b/csrc/attention/paged_attention_v2.cu @@ -37,7 +37,7 @@ exp_sums_ptr, max_logits_ptr, tmp_out_ptr, query_ptr, key_cache_ptr, \ value_cache_ptr, num_kv_heads, scale, block_tables_ptr, \ seq_lens_ptr, max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, \ - kv_block_stride, kv_head_stride, k_scale, v_scale, tp_rank, \ + kv_block_stride, kv_head_stride, k_scale_ptr, v_scale_ptr, tp_rank, \ blocksparse_local_blocks, blocksparse_vert_stride, \ blocksparse_block_size, blocksparse_head_sliding_step); \ vllm::paged_attention_v2_reduce_kernel& alibi_slopes, float k_scale, - float v_scale, const int tp_rank, const int blocksparse_local_blocks, - const int blocksparse_vert_stride, const int blocksparse_block_size, - const int blocksparse_head_sliding_step) { + const std::optional& alibi_slopes, torch::Tensor& k_scale, + torch::Tensor& v_scale, const int tp_rank, + const int blocksparse_local_blocks, const int blocksparse_vert_stride, + const int blocksparse_block_size, const int blocksparse_head_sliding_step) { int num_seqs = query.size(0); int num_heads = query.size(1); int head_size = query.size(2); @@ -84,6 +84,8 @@ void paged_attention_v2_launcher( CACHE_T* value_cache_ptr = reinterpret_cast(value_cache.data_ptr()); int* block_tables_ptr = block_tables.data_ptr(); int* seq_lens_ptr = seq_lens.data_ptr(); + const float* k_scale_ptr = reinterpret_cast(k_scale.data_ptr()); + const float* v_scale_ptr = reinterpret_cast(v_scale.data_ptr()); constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; int max_num_partitions = DIVIDE_ROUND_UP(max_seq_len, PARTITION_SIZE); @@ -188,8 +190,9 @@ void paged_attention_v2( torch::Tensor& seq_lens, // [num_seqs] int64_t block_size, int64_t max_seq_len, const std::optional& alibi_slopes, - const std::string& kv_cache_dtype, double k_scale, double v_scale, - const int64_t tp_rank, const int64_t blocksparse_local_blocks, + const std::string& kv_cache_dtype, torch::Tensor& k_scale, + torch::Tensor& v_scale, const int64_t tp_rank, + const int64_t blocksparse_local_blocks, const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size, const int64_t blocksparse_head_sliding_step) { const bool is_block_sparse = (blocksparse_vert_stride > 1); diff --git a/csrc/cache.h b/csrc/cache.h index 11c4c500..eedad9fa 100644 --- a/csrc/cache.h +++ b/csrc/cache.h @@ -18,15 +18,15 @@ void copy_blocks(std::vector const& key_caches, void reshape_and_cache(torch::Tensor& key, torch::Tensor& value, torch::Tensor& key_cache, torch::Tensor& value_cache, torch::Tensor& slot_mapping, - const std::string& kv_cache_dtype, const double k_scale, - const double v_scale); + const std::string& kv_cache_dtype, + torch::Tensor& k_scale, torch::Tensor& v_scale); void reshape_and_cache_flash(torch::Tensor& key, torch::Tensor& value, torch::Tensor& key_cache, torch::Tensor& value_cache, torch::Tensor& slot_mapping, const std::string& kv_cache_dtype, - const double k_scale, const double v_scale); + torch::Tensor& k_scale, torch::Tensor& v_scale); // Just for unittest void convert_fp8(torch::Tensor& dst_cache, torch::Tensor& src_cache, diff --git a/csrc/cache_kernels.cu b/csrc/cache_kernels.cu index 8a95279f..21a0aec0 100644 --- a/csrc/cache_kernels.cu +++ b/csrc/cache_kernels.cu @@ -159,8 +159,8 @@ __global__ void reshape_and_cache_kernel( // block_size] const int64_t* __restrict__ slot_mapping, // [num_tokens] const int key_stride, const int value_stride, const int num_heads, - const int head_size, const int block_size, const int x, const float k_scale, - const float v_scale) { + const int head_size, const int block_size, const int x, + const float* k_scale, const float* v_scale) { const int64_t token_idx = blockIdx.x; const int64_t slot_idx = slot_mapping[token_idx]; if (slot_idx < 0) { @@ -196,9 +196,9 @@ __global__ void reshape_and_cache_kernel( value_cache[tgt_value_idx] = tgt_value; } else { key_cache[tgt_key_idx] = - fp8::scaled_convert(tgt_key, k_scale); + fp8::scaled_convert(tgt_key, *k_scale); value_cache[tgt_value_idx] = - fp8::scaled_convert(tgt_value, v_scale); + fp8::scaled_convert(tgt_value, *v_scale); } } } @@ -214,7 +214,7 @@ __global__ void reshape_and_cache_flash_kernel( const int64_t* __restrict__ slot_mapping, // [num_tokens] const int block_stride, const int key_stride, const int value_stride, const int num_heads, const int head_size, const int block_size, - const float k_scale, const float v_scale) { + const float* k_scale, const float* v_scale) { const int64_t token_idx = blockIdx.x; const int64_t slot_idx = slot_mapping[token_idx]; // NOTE: slot_idx can be -1 if the token is padded @@ -239,9 +239,9 @@ __global__ void reshape_and_cache_flash_kernel( value_cache[tgt_key_value_idx] = tgt_value; } else { key_cache[tgt_key_value_idx] = - fp8::scaled_convert(tgt_key, k_scale); + fp8::scaled_convert(tgt_key, *k_scale); value_cache[tgt_key_value_idx] = - fp8::scaled_convert(tgt_value, v_scale); + fp8::scaled_convert(tgt_value, *v_scale); } } } @@ -258,7 +258,9 @@ __global__ void reshape_and_cache_flash_kernel( reinterpret_cast(key_cache.data_ptr()), \ reinterpret_cast(value_cache.data_ptr()), \ slot_mapping.data_ptr(), key_stride, value_stride, \ - num_heads, head_size, block_size, x, k_scale, v_scale); + num_heads, head_size, block_size, x, \ + reinterpret_cast(k_scale.data_ptr()), \ + reinterpret_cast(v_scale.data_ptr())); void reshape_and_cache( torch::Tensor& key, // [num_tokens, num_heads, head_size] @@ -268,8 +270,8 @@ void reshape_and_cache( torch::Tensor& value_cache, // [num_blocks, num_heads, head_size, block_size] torch::Tensor& slot_mapping, // [num_tokens] - const std::string& kv_cache_dtype, const double k_scale, - const double v_scale) { + const std::string& kv_cache_dtype, torch::Tensor& k_scale, + torch::Tensor& v_scale) { int num_tokens = key.size(0); int num_heads = key.size(1); int head_size = key.size(2); @@ -299,7 +301,9 @@ void reshape_and_cache( reinterpret_cast(key_cache.data_ptr()), \ reinterpret_cast(value_cache.data_ptr()), \ slot_mapping.data_ptr(), block_stride, key_stride, \ - value_stride, num_heads, head_size, block_size, k_scale, v_scale); + value_stride, num_heads, head_size, block_size, \ + reinterpret_cast(k_scale.data_ptr()), \ + reinterpret_cast(v_scale.data_ptr())); void reshape_and_cache_flash( torch::Tensor& key, // [num_tokens, num_heads, head_size] @@ -308,8 +312,8 @@ void reshape_and_cache_flash( torch::Tensor& value_cache, // [num_blocks, block_size, num_heads, head_size] torch::Tensor& slot_mapping, // [num_tokens] or [num_actual_tokens] - const std::string& kv_cache_dtype, const double k_scale, - const double v_scale) { + const std::string& kv_cache_dtype, torch::Tensor& k_scale, + torch::Tensor& v_scale) { // NOTE(woosuk): In vLLM V1, key.size(0) can be different from // slot_mapping.size(0) because of padding for CUDA graphs. // In vLLM V0, key.size(0) is always equal to slot_mapping.size(0) because diff --git a/csrc/cpu/attention.cpp b/csrc/cpu/attention.cpp index ef5b1408..b9764056 100644 --- a/csrc/cpu/attention.cpp +++ b/csrc/cpu/attention.cpp @@ -460,11 +460,11 @@ void paged_attention_v1( torch::Tensor& value_cache, int64_t num_kv_heads, double scale, torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size, int64_t max_seq_len, const std::optional& alibi_slopes, - const std::string& kv_cache_dtype, double k_scale, double v_scale, - const int64_t tp_rank, const int64_t blocksparse_local_blocks, + const std::string& kv_cache_dtype, torch::Tensor& k_scale, + torch::Tensor& v_scale, const int64_t tp_rank, + const int64_t blocksparse_local_blocks, const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size, const int64_t blocksparse_head_sliding_step) { - TORCH_CHECK(k_scale == 1.0f && v_scale == 1.0f); TORCH_CHECK(blocksparse_vert_stride <= 1, "CPU backend does not support blocksparse attention yet."); VLLM_DISPATCH_FLOATING_TYPES(query.scalar_type(), "paged_attention_v1_impl", @@ -782,11 +782,11 @@ void paged_attention_v2( torch::Tensor& value_cache, int64_t num_kv_heads, double scale, torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size, int64_t max_seq_len, const std::optional& alibi_slopes, - const std::string& kv_cache_dtype, double k_scale, double v_scale, - const int64_t tp_rank, const int64_t blocksparse_local_blocks, + const std::string& kv_cache_dtype, torch::Tensor& k_scale, + torch::Tensor& v_scale, const int64_t tp_rank, + const int64_t blocksparse_local_blocks, const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size, const int64_t blocksparse_head_sliding_step) { - TORCH_CHECK(k_scale == 1.0f && v_scale == 1.0f); TORCH_CHECK(blocksparse_vert_stride <= 1, "CPU backend does not support blocksparse attention yet."); VLLM_DISPATCH_FLOATING_TYPES(query.scalar_type(), "paged_attention_v2_impl", diff --git a/csrc/cpu/cache.cpp b/csrc/cpu/cache.cpp index 31d45432..e3809aca 100644 --- a/csrc/cpu/cache.cpp +++ b/csrc/cpu/cache.cpp @@ -107,10 +107,8 @@ void copy_blocks(std::vector const& key_caches, void reshape_and_cache(torch::Tensor& key, torch::Tensor& value, torch::Tensor& key_cache, torch::Tensor& value_cache, torch::Tensor& slot_mapping, - const std::string& kv_cache_dtype, double k_scale, - double v_scale) { - TORCH_CHECK(k_scale == 1.0f && v_scale == 1.0f); - + const std::string& kv_cache_dtype, + torch::Tensor& k_scale, torch::Tensor& v_scale) { int num_tokens = key.size(0); int num_heads = key.size(1); int head_size = key.size(2); diff --git a/csrc/cpu/torch_bindings.cpp b/csrc/cpu/torch_bindings.cpp index 74e4d818..5d1c5f4c 100644 --- a/csrc/cpu/torch_bindings.cpp +++ b/csrc/cpu/torch_bindings.cpp @@ -30,7 +30,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { " Tensor value_cache, int num_kv_heads, float scale," " Tensor block_tables, Tensor seq_lens, int block_size," " int max_seq_len, Tensor? alibi_slopes," - " str kv_cache_dtype, float k_scale, float v_scale," + " str kv_cache_dtype, Tensor k_scale, Tensor v_scale," " int tp_rank, int blocksparse_local_blocks," " int blocksparse_vert_stride, int blocksparse_block_size," " int blocksparse_head_sliding_step) -> ()"); @@ -44,7 +44,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { " Tensor value_cache, int num_kv_heads, float scale," " Tensor block_tables, Tensor seq_lens, int block_size," " int max_seq_len, Tensor? alibi_slopes," - " str kv_cache_dtype, float k_scale, float v_scale," + " str kv_cache_dtype, Tensor k_scale, Tensor v_scale," " int tp_rank, int blocksparse_local_blocks," " int blocksparse_vert_stride, int blocksparse_block_size," " int blocksparse_head_sliding_step) -> ()"); @@ -148,7 +148,7 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) { " Tensor! key_cache, Tensor! value_cache," " Tensor slot_mapping," " str kv_cache_dtype," - " float k_scale, float v_scale) -> ()"); + " Tensor k_scale, Tensor v_scale) -> ()"); cache_ops.impl("reshape_and_cache", torch::kCPU, &reshape_and_cache); } diff --git a/csrc/ops.h b/csrc/ops.h index 5a194a0d..34689896 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -34,8 +34,9 @@ void paged_attention_v1( torch::Tensor& value_cache, int64_t num_kv_heads, double scale, torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size, int64_t max_seq_len, const std::optional& alibi_slopes, - const std::string& kv_cache_dtype, double k_scale, double v_scale, - const int64_t tp_rank, const int64_t blocksparse_local_blocks, + const std::string& kv_cache_dtype, torch::Tensor& k_scale, + torch::Tensor& v_scale, const int64_t tp_rank, + const int64_t blocksparse_local_blocks, const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size, const int64_t blocksparse_head_sliding_step); @@ -45,8 +46,9 @@ void paged_attention_v2( torch::Tensor& value_cache, int64_t num_kv_heads, double scale, torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size, int64_t max_seq_len, const std::optional& alibi_slopes, - const std::string& kv_cache_dtype, double k_scale, double v_scale, - const int64_t tp_rank, const int64_t blocksparse_local_blocks, + const std::string& kv_cache_dtype, torch::Tensor& k_scale, + torch::Tensor& v_scale, const int64_t tp_rank, + const int64_t blocksparse_local_blocks, const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size, const int64_t blocksparse_head_sliding_step); diff --git a/csrc/rocm/attention.cu b/csrc/rocm/attention.cu index 0fec9624..94777906 100644 --- a/csrc/rocm/attention.cu +++ b/csrc/rocm/attention.cu @@ -218,7 +218,7 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel( scalar_t* __restrict__ out, // [num_seqs, num_heads, max_num_partitions, // head_size] scalar_t* __restrict__ final_out, // [num_seqs, num_heads, head_size] - int max_ctx_blocks, float k_scale, float v_scale) { + int max_ctx_blocks, const float* k_scale_ptr, const float* v_scale_ptr) { constexpr int NWARPS = NUM_THREADS / WARP_SIZE; const int warpid = threadIdx.x / WARP_SIZE; const int laneid = threadIdx.x % WARP_SIZE; @@ -406,7 +406,7 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel( // Vlocalb8[h][b * BLOCK_SIZE / 8 + d] = v_ptrh8be[d]; const _B8x8 Vlocalb8 = v_ptrh8be[d]; Vlocal[h][b * BLOCK_SIZE / 8 + d] = - scaled_convert_b8x8(Vlocalb8, v_scale); + scaled_convert_b8x8(Vlocalb8, *v_scale_ptr); } } } @@ -416,7 +416,7 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel( #pragma unroll for (int d = 0; d < KHELOOP; d++) { Klocal[d] = - scaled_convert_b8x8(Klocalb8[d], k_scale); + scaled_convert_b8x8(Klocalb8[d], *k_scale_ptr); } } @@ -890,7 +890,7 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel( scalar_t* __restrict__ out, // [num_seqs, num_heads, max_num_partitions, // head_size] scalar_t* __restrict__ final_out, // [num_seqs, num_heads, head_size] - int max_ctx_blocks, float k_scale, float v_scale) { + int max_ctx_blocks, const float* k_scale, const float* v_scale) { UNREACHABLE_CODE } @@ -919,7 +919,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel( block_tables_ptr, context_lens_ptr, max_num_blocks_per_seq, \ alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride, \ exp_sums_ptr, max_logits_ptr, tmp_out_ptr, out_ptr, max_ctx_blocks, \ - k_scale, v_scale); + k_scale_ptr, v_scale_ptr); template @@ -929,7 +929,7 @@ void paged_attention_custom_launcher( torch::Tensor& value_cache, const int num_kv_heads, float scale, torch::Tensor& block_tables, torch::Tensor& context_lens, int max_context_len, const std::optional& alibi_slopes, - float k_scale, float v_scale) { + torch::Tensor& k_scale, torch::Tensor& v_scale) { int num_seqs = query.size(0); int num_heads = query.size(1); int head_size = query.size(2); @@ -953,6 +953,8 @@ void paged_attention_custom_launcher( KVT* value_cache_ptr = reinterpret_cast(value_cache.data_ptr()); int* block_tables_ptr = block_tables.data_ptr(); int* context_lens_ptr = context_lens.data_ptr(); + const float* k_scale_ptr = reinterpret_cast(k_scale.data_ptr()); + const float* v_scale_ptr = reinterpret_cast(v_scale.data_ptr()); const int max_ctx_blocks = DIVIDE_ROUND_UP(max_context_len, BLOCK_SIZE); const int max_num_partitions = @@ -1087,7 +1089,8 @@ void paged_attention( torch::Tensor& context_lens, // [num_seqs] int64_t block_size, int64_t max_context_len, const std::optional& alibi_slopes, - const std::string& kv_cache_dtype, double k_scale, double v_scale) { + const std::string& kv_cache_dtype, torch::Tensor& k_scale, + torch::Tensor& v_scale) { const int head_size = query.size(2); if (kv_cache_dtype == "auto") { if (query.dtype() == at::ScalarType::Half) { diff --git a/csrc/rocm/ops.h b/csrc/rocm/ops.h index 34b2f9ce..ba161951 100644 --- a/csrc/rocm/ops.h +++ b/csrc/rocm/ops.h @@ -10,5 +10,5 @@ void paged_attention(torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& context_lens, int64_t block_size, int64_t max_context_len, const std::optional& alibi_slopes, - const std::string& kv_cache_dtype, double k_scale, - double v_scale); + const std::string& kv_cache_dtype, torch::Tensor& k_scale, + torch::Tensor& v_scale); diff --git a/csrc/rocm/torch_bindings.cpp b/csrc/rocm/torch_bindings.cpp index a283d426..a5d2e2f9 100644 --- a/csrc/rocm/torch_bindings.cpp +++ b/csrc/rocm/torch_bindings.cpp @@ -27,7 +27,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, rocm_ops) { " int max_context_len," " Tensor? alibi_slopes," " str kv_cache_dtype," - " float k_scale, float v_scale) -> ()"); + " Tensor k_scale, Tensor v_scale) -> ()"); rocm_ops.impl("paged_attention", torch::kCUDA, &paged_attention); } diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index fb53d122..ec63170d 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -30,7 +30,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { " Tensor value_cache, int num_kv_heads, float scale," " Tensor block_tables, Tensor seq_lens, int block_size," " int max_seq_len, Tensor? alibi_slopes," - " str kv_cache_dtype, float k_scale, float v_scale," + " str kv_cache_dtype, Tensor k_scale, Tensor v_scale," " int tp_rank, int blocksparse_local_blocks," " int blocksparse_vert_stride, int blocksparse_block_size," " int blocksparse_head_sliding_step) -> ()"); @@ -44,7 +44,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { " Tensor value_cache, int num_kv_heads, float scale," " Tensor block_tables, Tensor seq_lens, int block_size," " int max_seq_len, Tensor? alibi_slopes," - " str kv_cache_dtype, float k_scale, float v_scale," + " str kv_cache_dtype, Tensor k_scale, Tensor v_scale," " int tp_rank, int blocksparse_local_blocks," " int blocksparse_vert_stride, int blocksparse_block_size," " int blocksparse_head_sliding_step) -> ()"); @@ -449,7 +449,7 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) { " Tensor! key_cache, Tensor! value_cache," " Tensor slot_mapping," " str kv_cache_dtype," - " float k_scale, float v_scale) -> ()"); + " Tensor k_scale, Tensor v_scale) -> ()"); cache_ops.impl("reshape_and_cache", torch::kCUDA, &reshape_and_cache); // Reshape the key and value tensors and cache them. @@ -459,7 +459,7 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) { " Tensor! value_cache," " Tensor slot_mapping," " str kv_cache_dtype," - " float k_scale, float v_scale) -> ()"); + " Tensor k_scale, Tensor v_scale) -> ()"); cache_ops.impl("reshape_and_cache_flash", torch::kCUDA, &reshape_and_cache_flash); diff --git a/docs/source/features/quantization/quantized_kvcache.md b/docs/source/features/quantization/quantized_kvcache.md index 95fa5e81..9f36c294 100644 --- a/docs/source/features/quantization/quantized_kvcache.md +++ b/docs/source/features/quantization/quantized_kvcache.md @@ -35,16 +35,18 @@ Studies have shown that FP8 E4M3 quantization typically only minimally degrades Here is an example of how to enable FP8 quantization: ```python +# To calculate kv cache scales on the fly enable the calculate_kv_scales +# parameter + from vllm import LLM, SamplingParams sampling_params = SamplingParams(temperature=0.7, top_p=0.8) -llm = LLM(model="meta-llama/Llama-3.1-8B-Instruct", kv_cache_dtype="fp8") +llm = LLM(model="meta-llama/Llama-2-7b-chat-hf", + kv_cache_dtype="fp8", + calculate_kv_scales=True) prompt = "London is the capital of" out = llm.generate(prompt, sampling_params)[0].outputs[0].text print(out) - -# output w/ scaling factors: England, the United Kingdom, and one of the world's leading financial, -# output w/o scaling factors: England, located in the southeastern part of the country. It is known ``` The `kv_cache_dtype` argument specifies the data type for KV cache storage: diff --git a/examples/other/fp8/README.md b/examples/other/fp8/README.md deleted file mode 100644 index 4e8031d9..00000000 --- a/examples/other/fp8/README.md +++ /dev/null @@ -1,96 +0,0 @@ -# FP8 KV Cache - -This utility extracts the KV cache scaling factors from a quantized HF (Hugging Face) model. The extracted scaling factors are saved to a JSON file, which can later be used by vLLM (variable-length language model) during runtime. This tool is particularly useful when the KV cache data type is FP8 and is intended for use on ROCm (AMD GPU) platforms. - -## Prerequisites - -- Python 3.x -- PyTorch -- NumPy -- Hugging Face Transformers -- Hugging Face Hub -- AMMO - -Before incorporating the FP8 datatype for inference workloads, you must adhere to the following steps: -1. Install all necessary prerequisites and dependencies. -2. Convert HF model into a quantized HF model. -3. Extract KV Cache Scaling Factors from quantized HF model. -4. Load KV Cache Scaling Factors into VLLM. - -### 2. Convert HF model into a quantized HF model. -Note: The following steps are adapted from the [TensorRT-LLM repository](https://github.com/NVIDIA/TensorRT-LLM/blob/main/examples/quantization/README.md). - -`quantize.py` (examples/other/fp8/quantizer/quantize.py) uses the quantization toolkit (AMMO) to calibrate the PyTorch models and export TensorRT-LLM checkpoints. Each TensorRT-LLM checkpoint contains a config file (in .json format) and one or several rank weight files (in .safetensors format). - -The detailed quantization toolkit (AMMO) conversion guide for FP8 can be found at `examples/other/fp8/quantizer/README.md`. - -### 3. Extract KV Cache Scaling Factors from quantized HF model. -`extract_scales.py` (examples/other/fp8/extract_scales.py) can be utilized to extract the KV cache scaling factors from your quantized HF model, however at the moment, this tool exclusively supports Llama 2 models. It is also important to note the following: -1. **File Structure**: The utility operates under the assumption that all parameters, including KV cache scaling factors, corresponding to a particular Tensor Parallelism (TP) rank are stored in a single file. These files must adhere to a specific naming convention where the TP rank is immediately identified after a specific keyword (e.g., "rank") in the filename. - -2. **TP Decomposition**: The utility assumes consistency between the TP decomposition employed by the quantizer tool and that used by vLLM. - -3. **AMMO Compatibility**: Currently, the generated KV cache scaling factors for AMMO remain uniform across all TP ranks. - -```python -# prerequisites: -# - Quantized HF LLaMa 2 model -python3 examples/other/fp8/extract_scales.py --help -Usage: extract_scales.py [-h] --quantized_model QUANTIZED_MODEL [--load_format {auto,safetensors,npz,pt}] [--output_dir OUTPUT_DIR] [--output_name OUTPUT_NAME] [--tp_size TP_SIZE] - -KV Scale Extraction Example - -optional arguments: ---quantized_model: Specify either the local path to, or name of, a quantized HF model. It is expected that the quantization format is FP8_E4M3, for use on ROCm (AMD GPU). -Optional arguments: ---cache_dir: Specify a cache directory to use in the event of a HF model download. (Default: None) ---load_format: Specify the format of the model's tensor files containing the KV cache scaling factors. (Choices: auto, safetensors, npz, pt; Default: auto) ---revision: Specify the model's revision number. (Default: None) ---output_dir: Specify the output directory. By default the KV cache scaling factors will be saved in the model directory. (Default: None) ---output_name: Specify the output filename. (Default: kv_cache_scales.json) ---tp_size: Specify the tensor-parallel (TP) size that the quantized model should correspond to. If specified, during KV cache scaling factor extraction the observed TP size will be checked against this and an error will be raised if there is a mismatch. (Default: None) -``` -```python -Example: -python3 examples/other/fp8/extract_scales.py --quantized_model --tp_size --output_dir -``` -### 4. Load KV Cache Scaling Factors into VLLM. -This script evaluates the inference throughput of language models using various backends such as vLLM. It measures the time taken to process a given number of prompts and generate sequences for each prompt. The recently generated KV cache scaling factors are now integrated into the benchmarking process and allow for KV cache scaling factors to be utilized for FP8. -``` -# prerequisites: -# - LLaMa 2 kv_cache_scales.json file - -python3 benchmarks/benchmark_throughput.py --help -usage: benchmark_throughput.py [-h] [--backend {vllm,hf,mii}] [--dataset DATASET] [--input-len INPUT_LEN] [--output-len OUTPUT_LEN] [--model MODEL] - [--tokenizer TOKENIZER] [--quantization {awq,gptq,None}] [--tensor-parallel-size TENSOR_PARALLEL_SIZE] [--n N] - [--use-beam-search] [--num-prompts NUM_PROMPTS] [--seed SEED] [--hf-max-batch-size HF_MAX_BATCH_SIZE] [--trust-remote-code] - [--max-model-len MAX_MODEL_LEN] [--dtype {auto,half,float16,bfloat16,float,float32}] [--enforce-eager] [--kv-cache-dtype {auto,fp8}] - [--quantization-param-path KV_CACHE_quantization_param_path] - -Benchmark Throughput Example -optional arguments: - -h, --help show this help message and exit - --backend {vllm,hf,mii} - --dataset DATASET Path to the dataset. - --input-len INPUT_LEN Input prompt length for each request - --output-len OUTPUT_LEN Output length for each request. Overrides the output length from the dataset. - --model MODEL - --tokenizer TOKENIZER - --quantization {awq,gptq,None}, -q {awq,gptq,None} - --tensor-parallel-size TENSOR_PARALLEL_SIZE, -tp TENSOR_PARALLEL_SIZE - --n N Number of generated sequences per prompt. - --use-beam-search - --num-prompts NUM_PROMPTS Number of prompts to process. - --seed SEED - --hf-max-batch-size HF_MAX_BATCH_SIZE Maximum batch size for HF backend. - --trust-remote-code trust remote code from huggingface - --max-model-len MAX_MODEL_LEN Maximum length of a sequence (including prompt and output). If None, will be derived from the model. - --dtype {auto,half,float16,bfloat16,float,float32} data type for model weights and activations. The "auto" option will use FP16 precision for FP32 and FP16 models, and BF16 precision for BF16 models. - --enforce-eager enforce eager execution - --kv-cache-dtype {auto,fp8} Data type for kv cache storage. If "auto", will use model data type. FP8_E5M2 (without scaling) is only supported on cuda version greater than 11.8. On ROCm (AMD GPU), FP8_E4M3 is instead supported ```for common inference criteria. - --quantization-param-path QUANT_PARAM_JSON Path to the JSON file containing the KV cache scaling factors. This should generally be supplied, when KV cache dtype is FP8. Otherwise, KV cache scaling factors default to 1.0, which may cause accuracy issues. FP8_E5M2 (without scaling) is only supported on cuda version greater than 11.8. On ROCm (AMD GPU), FP8_E4M3 is instead supported for common inference criteria. -``` -Example: -```console -python3 benchmarks/benchmark_throughput.py --input-len --output-len -tp --kv-cache-dtype fp8 --quantization-param-path --model -``` diff --git a/examples/other/fp8/extract_scales.py b/examples/other/fp8/extract_scales.py deleted file mode 100644 index 1dce9d7e..00000000 --- a/examples/other/fp8/extract_scales.py +++ /dev/null @@ -1,367 +0,0 @@ -import argparse -import glob -import json -import os -from typing import Any, Callable, Dict, List, Optional, Tuple - -import numpy as np -import torch -from safetensors.torch import safe_open - -from vllm.model_executor.layers.quantization.schema import QuantParamSchema - - -# Adapted from vllm/model_executor/model_loader/weight_utils.py -# The main differences are that we add the NPZ format and simplify -# its functionality drastically for our purposes (e.g. we assume that -# the quantized model exists locally and there is no need to download it) -def _prepare_hf_weights( - quantized_model_dir: str, - load_format: str = "auto", - fall_back_to_pt: bool = True, -) -> Tuple[List[str], bool]: - if not os.path.isdir(quantized_model_dir): - raise FileNotFoundError( - f"The quantized model directory `{quantized_model_dir}` " - "does not exist.") - use_safetensors = False - # Some quantized models use .pt files for storing the weights. - if load_format == "auto": - allow_patterns = ["*.safetensors", "*.bin"] - elif load_format == "safetensors": - use_safetensors = True - allow_patterns = ["*.safetensors"] - elif load_format == "pt": - allow_patterns = ["*.pt"] - elif load_format == "npz": - allow_patterns = ["*.npz"] - else: - raise ValueError(f"Unknown load_format: {load_format}") - if fall_back_to_pt: - allow_patterns += ["*.pt"] - - hf_weights_files: List[str] = [] - for pattern in allow_patterns: - hf_weights_files += glob.glob( - os.path.join(quantized_model_dir, pattern)) - if len(hf_weights_files) > 0: - if pattern == "*.safetensors": - use_safetensors = True - break - - if not use_safetensors: - # Exclude files that are not needed for inference. - # https://github.com/huggingface/transformers/blob/v4.34.0/src/transformers/trainer.py#L227-L233 - blacklist = [ - "training_args.bin", - "optimizer.bin", - "optimizer.pt", - "scheduler.pt", - "scaler.pt", - ] - hf_weights_files = [ - f for f in hf_weights_files - if not any(f.endswith(x) for x in blacklist) - ] - - if len(hf_weights_files) == 0: - raise RuntimeError( - f"Cannot find any model weights with `{quantized_model_dir}`") - - return hf_weights_files, use_safetensors - - -# Adapted from vllm/model_executor/model_loader/weight_utils.py -def _hf_tensorfile_iterator(filename: str, load_format: str, - use_safetensors: bool): - if load_format == "npz": - assert not use_safetensors - with np.load(filename) as data: - for name in data.files: - param = torch.from_numpy(data[name]) - yield name, param - elif use_safetensors: - with safe_open(filename, framework="pt") as f: - for name in f.keys(): # NOQA: SIM118 - param = f.get_tensor(name) - yield name, param - else: - state = torch.load(filename, map_location="cpu") - for name, param in state.items(): - yield name, param - del state - torch.cuda.empty_cache() - - -def _kv_scales_extractor( - hf_tensor_files: List[str], - use_safetensors: bool, - rank_keyword: str = "rank", - expected_tp_size: Optional[int] = None) -> Dict[int, Dict[int, float]]: - """ - Given a list of files containing tensor data, attempt to extract KV cache - scales from these files. Intended as a helper function taking in the output - from _prepare_hf_weights. - Args: - rank_keyword Matches the number immediately after this keyword in the - tensor filename to determine the TP rank corresponding - to said tensor file - expected_tp_size If specified, the TP size of the tensor files is checked - against this and an error is raised if they don't match. - Returns a dictionary mapping TP ranks to their relevant KV cache scales. - The per-rank scales are themselves represented as a dictionary of layer - indices to the respective per-layer scale. - """ - for char in rank_keyword: - assert not char.isdecimal( - ), f"Rank keyword {rank_keyword} contains a numeric character!" - rank_scales_map: Dict[int, Dict[int, float]] = {} - for tensor_file in hf_tensor_files: - try: - rank_idx = tensor_file.find(rank_keyword) - if rank_idx != -1: - start_idx = rank_idx + len(rank_keyword) - stop_idx = start_idx - while stop_idx < len( - tensor_file) and tensor_file[stop_idx].isdecimal(): - stop_idx += 1 - if stop_idx == start_idx: - raise RuntimeError("Did not find rank # in filename.") - rank = int(tensor_file[start_idx:stop_idx]) - elif len(hf_tensor_files) == 1: - # Since there is only one tensor file, we can assume - # that it's intended for TP rank 0 - rank = 0 - else: - raise RuntimeError( - f"Filename does not contain '{rank_keyword}'.") - except RuntimeError: - print("Unable to determine TP rank " - f"corresponding to file '{tensor_file}'") - raise - - if rank not in rank_scales_map: - layer_scales_map: Dict[int, float] = {} - rank_scales_map[rank] = layer_scales_map - else: - raise RuntimeError( - f"Tensor file '{tensor_file}' shares TP rank {rank} " - "with another tensor file.") - - module_delimiter = ":" if args.load_format == "npz" else "." - for name, param in _hf_tensorfile_iterator(tensor_file, - args.load_format, - use_safetensors): - if "kv_cache_scaling_factor" in name: - nums = [ - int(s) for s in name.split(module_delimiter) - if s.isdecimal() - ] - assert len( - nums) == 1, f"Could not determine layer idx for {name}" - layer_idx = nums[0] - assert layer_idx not in layer_scales_map, f"Duplicate scaling"\ - f" factor corresponding to layer {layer_idx}" - try: - layer_scales_map[layer_idx] = param.item() - except RuntimeError: - print( - "This utility supports only per-tensor scalar scales " - f"for now. The tensor\n {name} = {param} \nis an " - "invalid scale factor.") - raise - - if all( - len(layer_scales_map) == 0 - for layer_scales_map in rank_scales_map.values()): - # Note: this is true even if the rank_scales_map is empty - print("WARNING: No KV cache scale factors found. No output saved.") - return None - empirical_tp_world_size = max(rank_scales_map.keys()) + 1 - if expected_tp_size is not None: - assert expected_tp_size == empirical_tp_world_size, \ - f"User expected TP world size = {expected_tp_size} " \ - "from model but tool is expecting TP world size = " \ - f"{empirical_tp_world_size} from model instead." - for i in range(empirical_tp_world_size): - assert i in rank_scales_map, "Expected TP world size = "\ - f"{empirical_tp_world_size} but did not find KV " \ - f"cache scaling factors for TP rank {i}" - print(f"Found TP world size = {empirical_tp_world_size} " - "when extracting KV cache scales!") - return rank_scales_map - - -def _metadata_extractor(quantized_model_dir: str, - metadata_extract_fns: \ - Dict[str, Callable[[Dict[str, Any]], Any]]) \ - -> Dict[str, Any]: - """ - Given a directory containing quantized model files, this function - aims to extract metadata from the JSON files within this directory. - Each JSON file is expected to represent a dictionary in JSON - format (referred to as a "JSON-dictionary"). Metadata extraction is - defined by a dictionary called metadata_extract_fns, where each - metadata field name is mapped to an extraction function. - - These extraction functions are designed to take a JSON-dictionary - as their only argument and return the corresponding metadata. - While extraction functions are permitted to raise exceptions, they - should only raise a KeyError or ValueError if the metadata field - cannot be extracted from the current JSON-dictionary, yet there's - a possibility of finding it in another JSON-dictionary. - - The function returns a dictionary that maps metadata fields to - their extracted data. The keys of this dictionary correspond exactly - to those in metadata_extract_fns. If any fields fail to be extracted, - their corresponding values are set to None, and a warning is printed. - """ - if not os.path.isdir(quantized_model_dir): - raise FileNotFoundError( - f"The quantized model directory `{quantized_model_dir}` " - "does not exist.") - metadata_files = glob.glob(os.path.join(quantized_model_dir, "*.json")) - - result: Dict[str, Any] = {} - for file in metadata_files: - with open(file) as f: - try: - metadata = json.load(f) - except json.JSONDecodeError: - print(f"Could not parse `{file}` as a valid metadata file," - " skipping it.") - continue - if not isinstance(metadata, dict): - print(f"The file `{file}` does not correspond to a " - "JSON-serialized dictionary, skipping it.") - continue - for metadata_name, extract_fn in metadata_extract_fns.items(): - try: - metadata_info = extract_fn(metadata) - if metadata_name not in result: - result[metadata_name] = metadata_info - elif metadata_info != result[metadata_name]: - raise RuntimeError( - "Metadata mismatch! Originally found " - f"{metadata_name} = {result[metadata_name]} but " - f"now found {metadata_name} = {metadata_info} in " - f"`{file}`") - except KeyError: - # It is possible that a given file does not contain some - # of our selected metadata as it could be located in some - # other metadata file. - # 'EFINAE': extract_fn failure is not an error. - pass - except ValueError: - # See above. - pass - - # Warn if we cannot find any of the requested metadata - for metadata_name in metadata_extract_fns: - if metadata_name not in result: - print("WARNING: Unable to find requested metadata field " - f"`{metadata_name}`, setting it to None.") - result[metadata_name] = None - - return result - - -def main(args): - metadata_extract_fns = { - "model_type": lambda json_dict: json_dict["layers"][0]["decoder_type"], - "tp_size": lambda json_dict: int(json_dict["tensor_parallel"]), - "model_dtype": lambda json_dict: json_dict["dtype"] - } - recovered_metadata = _metadata_extractor(args.quantized_model, - metadata_extract_fns) - if args.tp_size is not None: - metadata_tp_size = recovered_metadata["tp_size"] - if metadata_tp_size is not None: - assert args.tp_size == metadata_tp_size, \ - f"User expected TP world size = {args.tp_size} " \ - f"but found TP world size = {metadata_tp_size} from metadata!" - expected_tp_size = args.tp_size or recovered_metadata["tp_size"] - rank_keyword = "rank" - hf_tensor_files, use_safetensors = _prepare_hf_weights( - args.quantized_model, args.load_format) - rank_scales_map = _kv_scales_extractor(hf_tensor_files, use_safetensors, - rank_keyword, expected_tp_size) - # Postprocess: formatting to the current schema. Consider pulling it - # out into a dedicated function should it ever become more complicated. - rank_scales_map = { - rank: {k: scale[k] - for k in sorted(scale.keys())} - for rank, scale in rank_scales_map.items() - } - # TODO: Expand this with activation and weights scaling factors when - # they are used in the future - schema = QuantParamSchema( - model_type=recovered_metadata["model_type"], - kv_cache={ - "dtype": ("float8_e4m3fn" if len(rank_scales_map) > 0 else - recovered_metadata["model_dtype"]), - "scaling_factor": - rank_scales_map - }, - ) - - if args.output_dir is None: - output_file = os.path.join(args.quantized_model, args.output_name) - else: - if not os.path.isdir(args.output_dir): - os.makedirs(args.output_dir, exist_ok=True) - output_file = os.path.join(args.output_dir, args.output_name) - - with open(output_file, 'w') as f: - f.write(schema.model_dump_json(indent=4)) - print(f"Completed! KV cache scaling factors saved to {output_file}") - - -if __name__ == "__main__": - parser = argparse.ArgumentParser( - description="This simple utility extracts the " - "KV cache scaling factors from a quantized HF model " - "and saves them to a JSON file compatible with later " - "use by vLLM (pass this file to the appropriate " - "runtime typically using the argument " - "--quantization-param-path ). This is only used " - "if the KV cache dtype is FP8 and on ROCm (AMD GPU).") - parser.add_argument( - "--quantized-model", - help="Specify the directory containing a single quantized HF model. " - "It is expected that the quantization format is FP8_E4M3, for use " - "on ROCm (AMD GPU).", - required=True) - parser.add_argument( - "--load_format", - help="Optionally specify the format of the model's tensor files " - "containing the KV cache scaling factors.", - choices=["auto", "safetensors", "npz", "pt"], - default="auto") - parser.add_argument( - "--output-dir", - help="Optionally specify the output directory. By default the " - "KV cache scaling factors will be saved in the model directory, " - "however you can override this behavior here.", - default=None) - parser.add_argument( - "--output-name", - help="Optionally specify the output filename.", - # TODO: Change this once additional scaling factors are enabled - default="kv_cache_scales.json") - parser.add_argument( - "--tp-size", - help="Optionally specify the tensor-parallel (TP) size that the " - "quantized model should correspond to. If specified, during KV " - "cache scaling factor extraction the observed TP size will be " - "checked against this and an error will be raised if there is " - "a mismatch. If not specified, the quantized model's expected " - "TP size is instead inferred from the largest TP rank observed. " - "The expected TP size is cross-checked against the TP ranks " - "observed in the quantized model and an error is raised if any " - "discrepancies are found.", - default=None, - type=int) - args = parser.parse_args() - - main(args) diff --git a/examples/other/fp8/quantizer/README.md b/examples/other/fp8/quantizer/README.md deleted file mode 100644 index d0895e97..00000000 --- a/examples/other/fp8/quantizer/README.md +++ /dev/null @@ -1,32 +0,0 @@ -### Quantizer Utilities -`quantize.py`: NVIDIA Quantization utilities using TensorRT-Model-Optimizer, ported -from TensorRT-LLM: [`examples/quantization/quantize.py`](https://github.com/NVIDIA/TensorRT-LLM/blob/main/examples/quantization/quantize.py) - -### Prerequisite - -#### AMMO (AlgorithMic Model Optimization) Installation: nvidia-ammo 0.7.1 or later -`pip install --no-cache-dir --extra-index-url https://pypi.nvidia.com nvidia-ammo` - -#### AMMO Download (code and docs) -`https://developer.nvidia.com/downloads/assets/cuda/files/nvidia-ammo/nvidia_ammo-0.5.0.tar.gz` -`https://developer.nvidia.com/downloads/assets/cuda/files/nvidia-ammo/nvidia_ammo-0.7.1.tar.gz` - -### Usage - -#### Run on H100 system for speed if FP8; number of GPUs depends on the model size - -#### Example: quantize Llama2-7b model from HF to FP8 with FP8 KV Cache: -`python quantize.py --model-dir ./ll2-7b --dtype float16 --qformat fp8 --kv-cache-dtype fp8 --output-dir ./ll2_7b_fp8 --calib-size 512 --tp-size 1` - -Outputs: model structure, quantized model & parameters (with scaling factors) are in JSON and Safetensors (npz is generated only for the reference) -``` -# ll ./ll2_7b_fp8/ -total 19998244 -drwxr-xr-x 2 root root 4096 Feb 7 01:08 ./ -drwxrwxr-x 8 1060 1061 4096 Feb 7 01:08 ../ --rw-r--r-- 1 root root 176411 Feb 7 01:08 llama_tp1.json --rw-r--r-- 1 root root 13477087480 Feb 7 01:09 llama_tp1_rank0.npz --rw-r--r-- 1 root root 7000893272 Feb 7 01:08 rank0.safetensors -# -``` - diff --git a/examples/other/fp8/quantizer/quantize.py b/examples/other/fp8/quantizer/quantize.py deleted file mode 100644 index d75cc8b3..00000000 --- a/examples/other/fp8/quantizer/quantize.py +++ /dev/null @@ -1,367 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # noqa: E501 -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -Adapted from examples/quantization/hf_ptq.py -""" - -import argparse -import copy -import json -import random -import time - -import ammo.torch.quantization as atq -import numpy as np -import torch -from ammo.torch.export import export_model_config -from datasets import load_dataset -from torch.utils.data import DataLoader -from transformers import AutoModelForCausalLM, AutoTokenizer - -RAND_SEED = 1234 -MAX_SEQ_LEN = 2048 - -EMPTY_CFG = { - "quant_cfg": { - "*weight_quantizer": { - "enable": False, - }, - "*input_quantizer": { - "enable": False - }, - "*lm_head*": { - "enable": False - }, - "*output_layer*": { - "enable": False - }, - "default": { - "enable": False - }, - }, - "algorithm": "max", -} - -KV_CACHE_CFG = { - "*.query_key_value.output_quantizer": { - "num_bits": 8, - "axis": None, - "enable": True - }, - "*.Wqkv.output_quantizer": { - "num_bits": 8, - "axis": None, - "enable": True - }, - "*.W_pack.output_quantizer": { - "num_bits": 8, - "axis": None, - "enable": True - }, - "*.c_attn.output_quantizer": { - "num_bits": 8, - "axis": None, - "enable": True - }, - "*.k_proj.output_quantizer": { - "num_bits": 8, - "axis": None, - "enable": True - }, - "*.v_proj.output_quantizer": { - "num_bits": 8, - "axis": None, - "enable": True - }, -} - -QUANT_CFG_CHOICES = { - "int8_sq": atq.INT8_SMOOTHQUANT_CFG, - "fp8": atq.FP8_DEFAULT_CFG, - "int4_awq": atq.INT4_AWQ_CFG, - "w4a8_awq": atq.W4A8_AWQ_BETA_CFG, - "int8_wo": EMPTY_CFG, - "int4_wo": EMPTY_CFG, - "full_prec": EMPTY_CFG, -} - -MODEL_NAME_PATTERN_MAP = { - "GPT2": "gpt2", - "Xverse": "llama", - "Llama": "llama", - "Mistral": "llama", - "GPTJ": "gptj", - "FalconForCausalLM": "falcon", - "RWForCausalLM": "falcon", - "baichuan": "baichuan", - "MPT": "mpt", - "Bloom": "bloom", - "ChatGLM": "chatglm", - "QWen": "qwen", -} - - -def get_tokenizer(ckpt_path, max_seq_len=MAX_SEQ_LEN, model_type=None): - print(f"Initializing tokenizer from {ckpt_path}") - tokenizer = AutoTokenizer.from_pretrained( - ckpt_path, - model_max_length=max_seq_len, - padding_side="left", - trust_remote_code=True, - ) - if model_type and model_type == "qwen": - # qwen use token id 151643 as pad and eos tokens - tokenizer.pad_token = tokenizer.convert_ids_to_tokens(151643) - tokenizer.eos_token = tokenizer.convert_ids_to_tokens(151643) - - # can't set attribute 'pad_token' for "" - if tokenizer.pad_token != "": - tokenizer.pad_token = tokenizer.eos_token - if tokenizer.pad_token is None: - tokenizer.pad_token = tokenizer.eos_token - assert (tokenizer.pad_token - is not None), f"Pad token for {model_type} cannot be set!" - - return tokenizer - - -def get_model(ckpt_path, dtype="fp16", device="cuda"): - print(f"Initializing model from {ckpt_path}") - if dtype == "bf16" or dtype == "bfloat16": - dtype = torch.bfloat16 - elif dtype == "fp16" or dtype == "float16": - dtype = torch.float16 - elif dtype == "fp32" or dtype == "float32": - dtype = torch.float32 - else: - raise NotImplementedError(f"Unknown dtype {dtype}") - - # model_kwargs = {"torch_dtype": dtype} - model_kwargs = {"torch_dtype": "auto"} - - model = AutoModelForCausalLM.from_pretrained(ckpt_path, - device_map="auto", - **model_kwargs, - trust_remote_code=True) - model.eval() - - model_dtype = next(model.parameters()).dtype - if dtype != model_dtype: - print("[TensorRT-LLM][WARNING] The manually set model data type is " - f"{dtype}, but the data type of the HuggingFace model is " - f"{model_dtype}.") - - return model - - -def get_model_type(model): - for k, v in MODEL_NAME_PATTERN_MAP.items(): - if k.lower() in type(model).__name__.lower(): - return v - return None - - -def get_calib_dataloader(data="cnn_dailymail", - tokenizer=None, - batch_size=1, - calib_size=512, - block_size=512, - device=None): - print("Loading calibration dataset") - if data == "pileval": - dataset = load_dataset( - "json", - data_files="https://the-eye.eu/public/AI/pile/val.jsonl.zst", - split="train") - dataset = dataset["text"][:calib_size] - elif data == "cnn_dailymail": - dataset = load_dataset("cnn_dailymail", name="3.0.0", split="train") - dataset = dataset["article"][:calib_size] - else: - raise NotImplementedError - - batch_encoded = tokenizer.batch_encode_plus(dataset, - return_tensors="pt", - padding="max_length", - truncation=True, - max_length=block_size) - if device: - batch_encoded = batch_encoded.to(device) - batch_encoded = batch_encoded["input_ids"] - - calib_dataloader = DataLoader(batch_encoded, - batch_size=batch_size, - shuffle=False) - - return calib_dataloader - - -def quantize_model(model, quant_cfg, calib_dataloader=None): - - def calibrate_loop(): - if calib_dataloader is None: - return - """Adjusts weights and scaling factors based on selected algorithms.""" - for idx, data in enumerate(calib_dataloader): - print(f"Calibrating batch {idx}") - model(data) - - print("Starting quantization...") - start_time = time.time() - atq.quantize(model, quant_cfg, forward_loop=calibrate_loop) - end_time = time.time() - print("Quantization done. Total time used: {:.2f} s.".format(end_time - - start_time)) - - return model - - -def main(args): - if not torch.cuda.is_available(): - raise OSError("GPU is required for inference.") - - random.seed(RAND_SEED) - np.random.seed(RAND_SEED) - - model = get_model(args.model_dir, args.dtype, args.device) - model_type = get_model_type(model) - tokenizer = get_tokenizer(args.model_dir, model_type=model_type) - - if args.qformat in ["full_prec", "int8_wo", "int4_wo" - ] and args.kv_cache_dtype is None: - print(f"No quantization applied, export {args.dtype} model") - else: - if "awq" in args.qformat: - if args.calib_size > 32: - print("AWQ calibration could take longer with calib_size = " - f"{args.calib_size}, Using calib_size=32 instead") - args.calib_size = 32 - print("\nAWQ calibration could take longer than other calibration " - "methods. Please increase the batch size to speed up the " - "calibration process. Batch size can be set by adding the " - "argument --batch_size to the command line.\n") - - calib_dataloader = get_calib_dataloader( - tokenizer=tokenizer, - batch_size=args.batch_size, - calib_size=args.calib_size, - device=args.device, - ) - - if args.qformat in QUANT_CFG_CHOICES: - quant_cfg = QUANT_CFG_CHOICES[args.qformat] - else: - raise ValueError( - f"Unsupported quantization format: {args.qformat}") - - if "awq" in args.qformat: - quant_cfg = copy.deepcopy(QUANT_CFG_CHOICES[args.qformat]) - weight_quantizer = quant_cfg["quant_cfg"][ - "*weight_quantizer"] # type: ignore - if isinstance(weight_quantizer, list): - weight_quantizer = weight_quantizer[0] - weight_quantizer["block_sizes"][-1] = args.awq_block_size - - if args.kv_cache_dtype is not None: - if args.kv_cache_dtype == "fp8": - for value in KV_CACHE_CFG.values(): - value.update({"num_bits": (4, 3)}) # type: ignore - quant_cfg["quant_cfg"].update(KV_CACHE_CFG) # type: ignore - - print(quant_cfg) - - model = quantize_model(model, quant_cfg, calib_dataloader) - - with torch.inference_mode(): - if model_type is None: - print(f"Unknown model type {type(model).__name__}. Continue " - "exporting...") - model_type = f"unknown:{type(model).__name__}" - - export_path = args.output_dir - start_time = time.time() - - if args.qformat == "int4_awq" and model_type == "qwen": - torch.save(model.state_dict(), export_path) - else: - export_npz = (model_type not in [ - 'gptj', 'falcon', 'chatglm', 'mpt', 'llama', 'baichuan' - ]) - - # export safetensors - export_model_config( - model, - model_type, - getattr(torch, args.dtype), - export_dir=export_path, - inference_tensor_parallel=args.tp_size, - inference_pipeline_parallel=args.pp_size, - # export_tensorrt_llm_config=(not export_npz), - export_tensorrt_llm_config=False, - export_npz=export_npz) - - # Workaround for wo quantization - if args.qformat in ["int8_wo", "int4_wo", "full_prec"]: - with open(f"{export_path}/config.json") as f: - tensorrt_llm_config = json.load(f) - if args.qformat == "int8_wo": - tensorrt_llm_config["quantization"]["quant_algo"] = 'W8A16' - elif args.qformat == "int4_wo": - tensorrt_llm_config["quantization"]["quant_algo"] = 'W4A16' - else: - tensorrt_llm_config["quantization"]["quant_algo"] = None - with open(f"{export_path}/config.json", "w") as f: - json.dump(tensorrt_llm_config, f, indent=4) - - end_time = time.time() - print("Quantized model exported to {} \nTotal time used {:.2f} s.". - format(export_path, end_time - start_time)) - - -if __name__ == "__main__": - parser = argparse.ArgumentParser(description=__doc__) - parser.add_argument("--model-dir", - help="Specify where the HuggingFace model is", - required=True) - parser.add_argument("--device", default="cuda") - parser.add_argument("--dtype", help="Model data type.", default="float16") - parser.add_argument( - "--qformat", - help="Quantization format.", - default="full_prec", - choices=[ - "fp8", "int8_sq", "int4_awq", "w4a8_awq", "int8_wo", "int4_wo", - "full_prec" - ], - ) - parser.add_argument("--batch-size", - help="Batch size for calibration.", - type=int, - default=1) - parser.add_argument("--calib-size", - help="Number of samples for calibration.", - type=int, - default=512) - parser.add_argument("--output-dir", default="exported_model") - parser.add_argument("--tp-size", type=int, default=1) - parser.add_argument("--pp-size", type=int, default=1) - parser.add_argument("--awq-block-size", type=int, default=128) - parser.add_argument("--kv-cache-dtype", - help="KV Cache dtype.", - default=None, - choices=["int8", "fp8", None]) - args = parser.parse_args() - - main(args) diff --git a/tests/fp8_kv/llama2-70b-fp8-kv/kv_cache_scales.json b/tests/fp8_kv/llama2-70b-fp8-kv/kv_cache_scales.json deleted file mode 100644 index a548f0a9..00000000 --- a/tests/fp8_kv/llama2-70b-fp8-kv/kv_cache_scales.json +++ /dev/null @@ -1,90 +0,0 @@ -{ - "model_type": "llama", - "kv_cache": { - "dtype": "float8_e4m3fn", - "scaling_factor": { - "0": { - "0": 0.0230364128947258, - "1": 0.01979283057153225, - "2": 0.0241350457072258, - "3": 0.0308314748108387, - "4": 0.0430733822286129, - "5": 0.0370396226644516, - "6": 0.0306222103536129, - "7": 0.0357491634786129, - "8": 0.0358189195394516, - "9": 0.0443289652466774, - "10": 0.0433175228536129, - "11": 0.0416782945394516, - "12": 0.0366908498108387, - "13": 0.0432477705180645, - "14": 0.0410505048930645, - "15": 0.0457589291036129, - "16": 0.0418526791036129, - "17": 0.0432477705180645, - "18": 0.0469447560608387, - "19": 0.0514787957072258, - "20": 0.0541294664144516, - "21": 0.0587681382894516, - "22": 0.0625, - "23": 0.0585588738322258, - "24": 0.0600237175822258, - "25": 0.0588030144572258, - "26": 0.0531180277466774, - "27": 0.06396484375, - "28": 0.0603027381002903, - "29": 0.0582101047039032, - "30": 0.0625348836183548, - "31": 0.0585588738322258, - "32": 0.0582798570394516, - "33": 0.0575125589966774, - "34": 0.0590820349752903, - "35": 0.0614188089966774, - "36": 0.0631975457072258, - "37": 0.0615931935608387, - "38": 0.0601283498108387, - "39": 0.0571986623108387, - "40": 0.0670340433716774, - "41": 0.0523507259786129, - "42": 0.0547223798930645, - "43": 0.0631975457072258, - "44": 0.0663713738322258, - "45": 0.0603376142680645, - "46": 0.0652204304933548, - "47": 0.0734514519572258, - "48": 0.0693708211183548, - "49": 0.0725446492433548, - "50": 0.0627790242433548, - "51": 0.0691266804933548, - "52": 0.0688825398683548, - "53": 0.068429134786129, - "54": 0.0605119988322258, - "55": 0.0799386203289032, - "56": 0.0853097140789032, - "57": 0.0661969929933548, - "58": 0.0689871683716774, - "59": 0.0724051371216774, - "60": 0.0541643425822258, - "61": 0.0626743882894516, - "62": 0.0628487765789032, - "63": 0.0607212632894516, - "64": 0.0589076466858387, - "65": 0.0451660193502903, - "66": 0.0453055277466774, - "67": 0.0414341539144516, - "68": 0.0385044664144516, - "69": 0.0414341539144516, - "70": 0.0466308631002903, - "71": 0.0399693101644516, - "72": 0.0437011756002903, - "73": 0.0434221550822258, - "74": 0.0428989976644516, - "75": 0.0401785746216774, - "76": 0.0431082621216774, - "77": 0.0484444759786129, - "78": 0.0417829267680645, - "79": 0.0418178029358387 - } - } - } -} \ No newline at end of file diff --git a/tests/fp8_kv/llama2-7b-fp8-kv/kv_cache_scales.json b/tests/fp8_kv/llama2-7b-fp8-kv/kv_cache_scales.json deleted file mode 100644 index bb734039..00000000 --- a/tests/fp8_kv/llama2-7b-fp8-kv/kv_cache_scales.json +++ /dev/null @@ -1,42 +0,0 @@ -{ - "model_type": "llama", - "kv_cache": { - "dtype": "float8_e4m3fn", - "scaling_factor": { - "0": { - "0": 0.0152239128947258, - "1": 0.0188860222697258, - "2": 0.0354178324341774, - "3": 0.0376674123108387, - "4": 0.0418526791036129, - "5": 0.0433175228536129, - "6": 0.0397600457072258, - "7": 0.0424455925822258, - "8": 0.0415387861430645, - "9": 0.0408412404358387, - "10": 0.0395856611430645, - "11": 0.0377371683716774, - "12": 0.0400739423930645, - "13": 0.040771484375, - "14": 0.0393415205180645, - "15": 0.0369001142680645, - "16": 0.03857421875, - "17": 0.0387486070394516, - "18": 0.0403180830180645, - "19": 0.0396205373108387, - "20": 0.0375627800822258, - "21": 0.0407366082072258, - "22": 0.0432477705180645, - "23": 0.0377022884786129, - "24": 0.0399693101644516, - "25": 0.0374581478536129, - "26": 0.0413295216858387, - "27": 0.0442243330180645, - "28": 0.0424804724752903, - "29": 0.0456891767680645, - "30": 0.0409109964966774, - "31": 0.0482352152466774 - } - } - } -} diff --git a/tests/kernels/test_attention.py b/tests/kernels/test_attention.py index 124d5d29..574a0f22 100644 --- a/tests/kernels/test_attention.py +++ b/tests/kernels/test_attention.py @@ -182,7 +182,7 @@ def test_paged_attention( key_cache, value_cache = key_caches[0], value_caches[0] # Using default kv_scale - k_scale = v_scale = 1.0 + k_scale = v_scale = torch.tensor(1.0, dtype=torch.float32, device=device) # Call the paged attention kernel. output = torch.empty_like(query) diff --git a/tests/kernels/test_blocksparse_attention.py b/tests/kernels/test_blocksparse_attention.py index fad342d1..08f31219 100644 --- a/tests/kernels/test_blocksparse_attention.py +++ b/tests/kernels/test_blocksparse_attention.py @@ -210,7 +210,7 @@ def test_paged_attention( key_cache, value_cache = key_caches[0], value_caches[0] # Using default kv_scale - k_scale = v_scale = 1.0 + k_scale = v_scale = torch.tensor(1.0, dtype=torch.float32, device=device) tp_rank = 0 # Call the paged attention kernel. diff --git a/tests/kernels/test_cache.py b/tests/kernels/test_cache.py index 40550ed5..c848be4f 100644 --- a/tests/kernels/test_cache.py +++ b/tests/kernels/test_cache.py @@ -160,7 +160,7 @@ def test_reshape_and_cache( cloned_value_cache = value_cache.clone() # Using default kv_scale - k_scale = v_scale = 1.0 + k_scale = v_scale = torch.tensor(1.0, dtype=torch.float32, device=device) # Call the reshape_and_cache kernel. opcheck(torch.ops._C_cache_ops.reshape_and_cache, @@ -258,8 +258,8 @@ def test_reshape_and_cache_flash( del key_caches del value_caches - k_scale = key.amax().item() / 256 - v_scale = value.amax().item() / 256 + k_scale = (key.amax() / 256.0).to(torch.float32) + v_scale = (value.amax() / 256.0).to(torch.float32) # Clone the KV caches. if kv_cache_dtype == "fp8": @@ -284,12 +284,12 @@ def test_reshape_and_cache_flash( result_key_cache = torch.empty_like(key_cache, dtype=torch.float16) ops.convert_fp8(result_key_cache, key_cache, - k_scale, + k_scale.item(), kv_dtype=kv_cache_dtype) result_value_cache = torch.empty_like(value_cache, dtype=torch.float16) ops.convert_fp8(result_value_cache, value_cache, - v_scale, + v_scale.item(), kv_dtype=kv_cache_dtype) # Run the reference implementation. diff --git a/tests/kernels/test_prefix_prefill.py b/tests/kernels/test_prefix_prefill.py index 3fdb7996..10e73ab9 100644 --- a/tests/kernels/test_prefix_prefill.py +++ b/tests/kernels/test_prefix_prefill.py @@ -138,6 +138,7 @@ def test_contexted_kv_attention( # to V_cache[num_blocks, num_kv_heads, head_size, block_size] v_cache = v_cache.view(-1, block_size, num_kv_heads, head_size).permute(0, 2, 3, 1).contiguous() + k_scale = v_scale = torch.tensor(1.0, dtype=torch.float32, device=device) # Warm up the Triton kernel by calling it once before actually measuring # generation time @@ -153,6 +154,8 @@ def test_contexted_kv_attention( b_seq_len, b_ctx_len, max_input_len, + k_scale, + v_scale, sliding_window=sliding_window) torch.cuda.synchronize() start_time = time.time() @@ -168,6 +171,8 @@ def test_contexted_kv_attention( b_seq_len, b_ctx_len, max_input_len, + k_scale, + v_scale, sliding_window=sliding_window) torch.cuda.synchronize() end_time = time.time() @@ -366,6 +371,7 @@ def test_contexted_kv_attention_alibi( # to V_cache[num_blocks, num_kv_heads, head_size, block_size] v_cache = v_cache.view(-1, block_size, num_kv_heads, head_size).permute(0, 2, 3, 1).contiguous() + k_scale = v_scale = torch.tensor(1.0, dtype=torch.float32, device=device) # Warm up the Triton kernel by calling it once before actually measuring # generation time @@ -381,6 +387,8 @@ def test_contexted_kv_attention_alibi( b_seq_len, b_ctx_len, max_input_len, + k_scale, + v_scale, alibi_slopes=alibi_slopes) torch.cuda.synchronize() start_time = time.time() @@ -396,6 +404,8 @@ def test_contexted_kv_attention_alibi( b_seq_len, b_ctx_len, max_input_len, + k_scale, + v_scale, alibi_slopes=alibi_slopes) torch.cuda.synchronize() end_time = time.time() diff --git a/tests/kernels/utils.py b/tests/kernels/utils.py index 848eea7f..80113985 100644 --- a/tests/kernels/utils.py +++ b/tests/kernels/utils.py @@ -909,6 +909,7 @@ def make_test_metadata( num_prefills=num_prefills, slot_mapping=(None if kv_mmap is None else kv_mmap.slot_mapping), multi_modal_placeholder_index_maps=None, + enable_kv_scales_calculation=True, num_prefill_tokens=num_prefill_tokens, num_decode_tokens=num_decode_tokens, seq_lens=seq_lens, @@ -958,6 +959,7 @@ def make_test_metadata( num_prefills=num_prefills, slot_mapping=kv_mmap.slot_mapping, multi_modal_placeholder_index_maps=None, + enable_kv_scales_calculation=True, num_prefill_tokens=num_prefill_tokens, num_decode_tokens=num_decode_tokens, seq_lens=seq_lens, diff --git a/tests/models/decoder_only/language/test_fp8.py b/tests/models/decoder_only/language/test_fp8.py index 53f23e24..5f06f1e3 100644 --- a/tests/models/decoder_only/language/test_fp8.py +++ b/tests/models/decoder_only/language/test_fp8.py @@ -19,18 +19,17 @@ os.environ["TOKENIZERS_PARALLELISM"] = "true" @pytest.mark.skipif(not is_quant_method_supported("fp8"), reason="fp8 is not supported on this GPU type.") @pytest.mark.parametrize( - "kv_cache_dtype,base_model,test_model,scale_path", + "kv_cache_dtype,base_model,test_model", [ # Test FP8 checkpoint w. fp8_e4m3 kv-cache scaling factors. ("fp8_e4m3", "meta-llama/Llama-3.2-1B-Instruct", - "nm-testing/Llama-3.2-1B-Instruct-FP8-KV", None), + "nm-testing/Llama-3.2-1B-Instruct-FP8-KV"), # Test FP16 checkpoint w. fp8_e5m2 kv-cache. ("fp8_e5m2", "meta-llama/Llama-3.2-1B-Instruct", - "meta-llama/Llama-3.2-1B-Instruct", None), + "meta-llama/Llama-3.2-1B-Instruct"), # Test FP16 checkpoint w. fp8_e4m3 kv-cache scaling factors in json. ("fp8_e4m3", "meta-llama/Llama-2-7b-chat-hf", - "meta-llama/Llama-2-7b-chat-hf", - "./tests/fp8_kv/llama2-7b-fp8-kv/kv_cache_scales.json") + "meta-llama/Llama-2-7b-chat-hf") ]) # Due to low-precision numerical divergence, we only test logprob of 4 tokens @pytest.mark.parametrize("max_tokens", [4]) @@ -48,7 +47,6 @@ def test_models( kv_cache_dtype: str, base_model: str, test_model: str, - scale_path: Optional[str], max_tokens: int, enforce_eager: bool, backend: str, @@ -76,10 +74,6 @@ def test_models( baseline_outputs = vllm_model.generate_greedy_logprobs( example_prompts, max_tokens, NUM_LOG_PROBS) - extra_kwargs = {} - if scale_path is not None: - extra_kwargs["quantization_param_path"] = scale_path - with vllm_runner( test_model, max_model_len=MAX_MODEL_LEN, @@ -87,7 +81,6 @@ def test_models( enforce_eager=enforce_eager, kv_cache_dtype=kv_cache_dtype, disable_async_output_proc=disable_async_output_proc, - **extra_kwargs, ) as vllm_model: test_outputs = vllm_model.generate_greedy_logprobs( example_prompts, max_tokens, NUM_LOG_PROBS) diff --git a/tests/worker/test_model_input.py b/tests/worker/test_model_input.py index 309854e6..57f1fd47 100644 --- a/tests/worker/test_model_input.py +++ b/tests/worker/test_model_input.py @@ -74,6 +74,7 @@ def test_model_runner_input(): num_decode_tokens=3, slot_mapping=torch.zeros(1), multi_modal_placeholder_index_maps=None, + enable_kv_scales_calculation=True, ) model_input = ModelInputForGPUWithSamplingMetadata( input_tokens=torch.ones(10), @@ -126,6 +127,7 @@ def test_embedding_model_runner_input(): num_decode_tokens=3, slot_mapping=torch.zeros(1), multi_modal_placeholder_index_maps=None, + enable_kv_scales_calculation=True, ) model_input = ModelInputForGPUWithPoolingMetadata( input_tokens=torch.ones(10), @@ -177,6 +179,7 @@ def test_multi_step_model_runner_input(): num_decode_tokens=3, slot_mapping=torch.zeros(1), multi_modal_placeholder_index_maps=None, + enable_kv_scales_calculation=True, ) frozen_model_input = ModelInputForGPUWithSamplingMetadata( input_tokens=torch.ones(10), diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index d04cbbc0..440bc520 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -48,8 +48,8 @@ def paged_attention_v1( max_seq_len: int, alibi_slopes: Optional[torch.Tensor], kv_cache_dtype: str, - k_scale: float, - v_scale: float, + k_scale: torch.Tensor, + v_scale: torch.Tensor, tp_rank: int = 0, blocksparse_local_blocks: int = 0, blocksparse_vert_stride: int = 0, @@ -80,8 +80,8 @@ def paged_attention_v2( max_seq_len: int, alibi_slopes: Optional[torch.Tensor], kv_cache_dtype: str, - k_scale: float, - v_scale: float, + k_scale: torch.Tensor, + v_scale: torch.Tensor, tp_rank: int = 0, blocksparse_local_blocks: int = 0, blocksparse_vert_stride: int = 0, @@ -112,8 +112,8 @@ def paged_attention_rocm( max_seq_len: int, alibi_slopes: Optional[torch.Tensor], kv_cache_dtype: str, - k_scale: float, - v_scale: float, + k_scale: torch.Tensor, + v_scale: torch.Tensor, ) -> None: torch.ops._rocm_C.paged_attention(out, exp_sum, max_logits, tmp_out, query, key_cache, value_cache, num_kv_heads, @@ -956,8 +956,8 @@ def reshape_and_cache( value_cache: torch.Tensor, slot_mapping: torch.Tensor, kv_cache_dtype: str, - k_scale: float, - v_scale: float, + k_scale: torch.Tensor, + v_scale: torch.Tensor, ) -> None: torch.ops._C_cache_ops.reshape_and_cache(key, value, key_cache, value_cache, slot_mapping, @@ -971,8 +971,8 @@ def reshape_and_cache_flash( value_cache: torch.Tensor, slot_mapping: torch.Tensor, kv_cache_dtype: str, - k_scale: float, - v_scale: float, + k_scale: torch.Tensor, + v_scale: torch.Tensor, ) -> None: torch.ops._C_cache_ops.reshape_and_cache_flash(key, value, key_cache, value_cache, slot_mapping, diff --git a/vllm/attention/backends/abstract.py b/vllm/attention/backends/abstract.py index 2efe142a..8027a52b 100644 --- a/vllm/attention/backends/abstract.py +++ b/vllm/attention/backends/abstract.py @@ -123,6 +123,10 @@ class AttentionMetadata: multi_modal_placeholder_index_maps: Optional[Dict[ str, MultiModalPlaceholderMap.IndexMap]] + # Enable/disable KV scales calculation. This is so that we can disable the + # calculation until after prefill and cuda graph capture. + enable_kv_scales_calculation: bool + @property @abstractmethod def prefill_metadata(self) -> Optional["AttentionMetadata"]: @@ -226,8 +230,10 @@ class AttentionMetadataBuilder(ABC, Generic[T]): class AttentionLayer(Protocol): - _k_scale: float - _v_scale: float + _k_scale: torch.Tensor + _v_scale: torch.Tensor + _k_scale_float: float + _v_scale_float: float def forward( self, diff --git a/vllm/attention/backends/blocksparse_attn.py b/vllm/attention/backends/blocksparse_attn.py index 9089db11..20e9a3f1 100644 --- a/vllm/attention/backends/blocksparse_attn.py +++ b/vllm/attention/backends/blocksparse_attn.py @@ -222,6 +222,7 @@ class BlocksparseFlashAttentionMetadata(AttentionMetadata): slot_mapping=self.slot_mapping[:self.num_prefill_tokens], multi_modal_placeholder_index_maps=self. multi_modal_placeholder_index_maps, + enable_kv_scales_calculation=self.enable_kv_scales_calculation, seq_lens=self.seq_lens[:self.num_prefills], seq_lens_tensor=self.seq_lens_tensor[:self.num_prefills], max_query_len=self.max_query_len, @@ -251,6 +252,7 @@ class BlocksparseFlashAttentionMetadata(AttentionMetadata): num_decode_tokens=self.num_decode_tokens, slot_mapping=self.slot_mapping[self.num_prefill_tokens:], multi_modal_placeholder_index_maps=None, + enable_kv_scales_calculation=False, seq_lens=None, seq_lens_tensor=self.seq_lens_tensor[self.num_prefills:], max_query_len=None, diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index 18acfb82..1be09928 100644 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -230,6 +230,7 @@ class FlashAttentionMetadata(AttentionMetadata): slot_mapping=slot_mapping, multi_modal_placeholder_index_maps=self. multi_modal_placeholder_index_maps, + enable_kv_scales_calculation=self.enable_kv_scales_calculation, seq_lens=seq_lens, seq_lens_tensor=seq_lens_tensor, max_query_len=self.max_query_len, @@ -274,6 +275,7 @@ class FlashAttentionMetadata(AttentionMetadata): num_decode_tokens=self.num_decode_tokens, slot_mapping=slot_mapping, multi_modal_placeholder_index_maps=None, + enable_kv_scales_calculation=True, seq_lens=None, seq_lens_tensor=seq_lens_tensor, max_decode_query_len=self.max_decode_query_len, @@ -557,6 +559,7 @@ class FlashAttentionMetadataBuilder( num_decode_tokens=num_decode_tokens, seq_lens=seq_lens, multi_modal_placeholder_index_maps=placeholder_index_maps, + enable_kv_scales_calculation=True, seq_lens_tensor=seq_lens_tensor, max_query_len=max_query_len, max_decode_query_len=max_decode_query_len, @@ -675,7 +678,7 @@ class FlashAttentionImpl(AttentionImpl): NOTE: It in-place updates the output tensor. """ # NOTE(woosuk): FlashAttention does not support FP8 KV cache. - assert layer._k_scale == 1.0 and layer._v_scale == 1.0, ( + assert layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0, ( "key/v_scale is not supported in FlashAttention.") assert output is not None, "Output tensor must be provided." diff --git a/vllm/attention/backends/flashinfer.py b/vllm/attention/backends/flashinfer.py index b8ffbe6d..3135b0b4 100644 --- a/vllm/attention/backends/flashinfer.py +++ b/vllm/attention/backends/flashinfer.py @@ -219,6 +219,7 @@ class FlashInferState(AttentionState): num_prefills=0, slot_mapping=self._graph_slot_mapping[:batch_size], multi_modal_placeholder_index_maps=None, + enable_kv_scales_calculation=False, num_prefill_tokens=0, num_decode_tokens=batch_size, max_prefill_seq_len=0, @@ -733,6 +734,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): num_prefills=self.num_prefills, slot_mapping=slot_mapping_tensor, multi_modal_placeholder_index_maps=placeholder_index_maps, + enable_kv_scales_calculation=False, num_prefill_tokens=self.num_prefill_tokens, num_decode_tokens=num_decode_tokens, max_prefill_seq_len=max_prefill_seq_len, @@ -888,8 +890,8 @@ class FlashInferImpl(AttentionImpl): kv_cache, logits_soft_cap=logits_soft_cap, causal=True, - k_scale=layer._k_scale, - v_scale=layer._v_scale, + k_scale=layer._k_scale_float, + v_scale=layer._v_scale_float, window_left=window_left) if decode_meta := attn_metadata.decode_metadata: assert decode_meta is not None @@ -899,8 +901,8 @@ class FlashInferImpl(AttentionImpl): kv_cache, sm_scale=softmax_scale, logits_soft_cap=logits_soft_cap, - k_scale=layer._k_scale, - v_scale=layer._v_scale, + k_scale=layer._k_scale_float, + v_scale=layer._v_scale_float, window_left=window_left) if prefill_output is None and decode_output is not None: diff --git a/vllm/attention/backends/ipex_attn.py b/vllm/attention/backends/ipex_attn.py index cd729a1c..57916a3c 100644 --- a/vllm/attention/backends/ipex_attn.py +++ b/vllm/attention/backends/ipex_attn.py @@ -193,7 +193,7 @@ class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]): Returns: shape = [num_tokens, num_heads * head_size] """ - assert layer._k_scale == 1.0 and layer._v_scale == 1.0 + assert layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0 num_tokens, hidden_size = query.shape # Reshape the query, key, and value tensors. query = query.view(-1, self.num_heads, self.head_size) diff --git a/vllm/attention/backends/pallas.py b/vllm/attention/backends/pallas.py index f5bf390d..facdee6b 100644 --- a/vllm/attention/backends/pallas.py +++ b/vllm/attention/backends/pallas.py @@ -173,7 +173,7 @@ class PallasAttentionBackendImpl(AttentionImpl): Returns: shape = [batch_size, seq_len, num_heads * head_size] """ - assert layer._k_scale == 1.0 and layer._v_scale == 1.0 + assert layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0 batch_size, seq_len, hidden_size = query.shape query = query.view(batch_size, seq_len, self.num_heads, self.head_size) key = key.view(batch_size, seq_len, self.num_kv_heads, self.head_size) diff --git a/vllm/attention/backends/placeholder_attn.py b/vllm/attention/backends/placeholder_attn.py index 37860494..82631189 100644 --- a/vllm/attention/backends/placeholder_attn.py +++ b/vllm/attention/backends/placeholder_attn.py @@ -140,6 +140,7 @@ class PlaceholderAttentionMetadata(AttentionMetadata): slot_mapping=slot_mapping, multi_modal_placeholder_index_maps=self. multi_modal_placeholder_index_maps, + enable_kv_scales_calculation=self.enable_kv_scales_calculation, seq_lens=self.seq_lens[:self.num_prefills], seq_lens_tensor=self.seq_lens_tensor[:self.num_prefills], max_decode_query_len=0, @@ -173,6 +174,7 @@ class PlaceholderAttentionMetadata(AttentionMetadata): num_decode_tokens=self.num_decode_tokens, slot_mapping=slot_mapping, multi_modal_placeholder_index_maps=None, + enable_kv_scales_calculation=True, seq_lens=None, seq_lens_tensor=self.seq_lens_tensor[self.num_prefills:], max_decode_query_len=self.max_decode_query_len, @@ -380,6 +382,7 @@ class PlaceholderAttentionMetadataBuilder( num_prefills=self.num_prefills, slot_mapping=slot_mapping, multi_modal_placeholder_index_maps=placeholder_index_maps, + enable_kv_scales_calculation=True, num_prefill_tokens=self.num_prefill_tokens, num_decode_tokens=num_decode_tokens, seq_lens=seq_lens, diff --git a/vllm/attention/backends/rocm_flash_attn.py b/vllm/attention/backends/rocm_flash_attn.py index e9f2808f..ca6fa9ca 100644 --- a/vllm/attention/backends/rocm_flash_attn.py +++ b/vllm/attention/backends/rocm_flash_attn.py @@ -153,6 +153,7 @@ class ROCmFlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata): slot_mapping=self.slot_mapping[:self.num_prefill_tokens], multi_modal_placeholder_index_maps=self. multi_modal_placeholder_index_maps, + enable_kv_scales_calculation=self.enable_kv_scales_calculation, seq_lens=self.seq_lens[:self.num_prefills], seq_lens_tensor=self.seq_lens_tensor[:self.num_prefills], max_query_len=self.max_query_len, @@ -182,6 +183,7 @@ class ROCmFlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata): num_decode_tokens=self.num_decode_tokens, slot_mapping=self.slot_mapping[self.num_prefill_tokens:], multi_modal_placeholder_index_maps=None, + enable_kv_scales_calculation=True, seq_lens=None, seq_lens_tensor=self.seq_lens_tensor[self.num_prefills:], max_query_len=None, diff --git a/vllm/attention/backends/torch_sdpa.py b/vllm/attention/backends/torch_sdpa.py index 8722d737..c3b2398b 100644 --- a/vllm/attention/backends/torch_sdpa.py +++ b/vllm/attention/backends/torch_sdpa.py @@ -379,6 +379,7 @@ class TorchSDPAMetadataBuilder(AttentionMetadataBuilder[TorchSDPAMetadata]): prefill_block_tables=prefill_block_tables, slot_mapping=slot_mapping, multi_modal_placeholder_index_maps=placeholder_index_maps, + enable_kv_scales_calculation=False, ) return attn_metadata @@ -454,7 +455,6 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]): Returns: shape = [num_tokens, num_heads * head_size] """ - assert layer._k_scale == 1.0 and layer._v_scale == 1.0 attn_type = self.attn_type if (attn_type == AttentionType.ENCODER and (not attn_metadata.is_all_encoder_attn_metadata_set)): diff --git a/vllm/attention/backends/utils.py b/vllm/attention/backends/utils.py index 3df7f54c..84fe89b7 100644 --- a/vllm/attention/backends/utils.py +++ b/vllm/attention/backends/utils.py @@ -265,6 +265,7 @@ class CommonMetadataBuilder(AttentionMetadataBuilder[TAttentionMetadata]): num_prefills=self.num_prefills, slot_mapping=slot_mapping_tensor, multi_modal_placeholder_index_maps=placeholder_index_maps, + enable_kv_scales_calculation=True, num_prefill_tokens=self.num_prefill_tokens, num_decode_tokens=num_decode_tokens, seq_lens=seq_lens, @@ -317,6 +318,7 @@ class CommonAttentionState(AttentionState): num_decode_tokens=batch_size, slot_mapping=self._graph_slot_mapping[:batch_size], multi_modal_placeholder_index_maps=None, + enable_kv_scales_calculation=True, seq_lens=None, seq_lens_tensor=self._graph_seq_lens[:batch_size], max_query_len=1, diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index 38e27434..8c25dda7 100644 --- a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -218,6 +218,7 @@ class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata): slot_mapping=slot_mapping, multi_modal_placeholder_index_maps=self. multi_modal_placeholder_index_maps, + enable_kv_scales_calculation=self.enable_kv_scales_calculation, seq_lens=seq_lens, seq_lens_tensor=seq_lens_tensor, max_query_len=self.max_query_len, @@ -262,6 +263,7 @@ class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata): num_decode_tokens=self.num_decode_tokens, slot_mapping=slot_mapping, multi_modal_placeholder_index_maps=None, + enable_kv_scales_calculation=True, seq_lens_tensor=seq_lens_tensor, max_prefill_seq_len=0, max_decode_seq_len=self.max_decode_seq_len, diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py index c36f8d08..79ea9b66 100644 --- a/vllm/attention/layer.py +++ b/vllm/attention/layer.py @@ -5,6 +5,7 @@ import torch import torch.nn as nn import torch.nn.functional as F +import vllm.envs as envs from vllm.attention import AttentionMetadata, AttentionType from vllm.attention.selector import backend_name_to_enum, get_attn_backend from vllm.config import CacheConfig, get_current_vllm_config @@ -57,10 +58,12 @@ class Attention(nn.Module): kv_cache_dtype = cache_config.cache_dtype block_size = cache_config.block_size is_attention_free = cache_config.is_attention_free + calculate_kv_scales = cache_config.calculate_kv_scales else: kv_cache_dtype = "auto" block_size = 16 is_attention_free = False + calculate_kv_scales = False if num_kv_heads is None: num_kv_heads = num_heads @@ -70,8 +73,15 @@ class Attention(nn.Module): # expect the pre-quantized k/v_scale to be loaded along # with the model weights. self.kv_cache_dtype = kv_cache_dtype - self._k_scale = 1.0 - self._v_scale = 1.0 + self.calculate_kv_scales = calculate_kv_scales + self._k_scale = torch.tensor(1.0, dtype=torch.float32) + self._v_scale = torch.tensor(1.0, dtype=torch.float32) + + # We also keep the float32 versions of k/v_scale for attention + # backends that don't support tensors (Flashinfer) + self._k_scale_float = 1.0 + self._v_scale_float = 1.0 + quant_method = quant_config.get_quant_method( self, prefix=prefix) if quant_config else None if quant_method is not None: @@ -127,6 +137,9 @@ class Attention(nn.Module): ).parallel_config.pipeline_parallel_size) ] + self.k_range = torch.tensor(envs.K_SCALE_CONSTANT, dtype=torch.float32) + self.v_range = torch.tensor(envs.V_SCALE_CONSTANT, dtype=torch.float32) + def forward( self, query: torch.Tensor, @@ -135,6 +148,9 @@ class Attention(nn.Module): kv_cache: torch.Tensor, attn_metadata: AttentionMetadata, ) -> torch.Tensor: + if self.calculate_kv_scales and \ + attn_metadata.enable_kv_scales_calculation: + self.calc_kv_scales(key, value) if self.use_output: output = torch.empty_like(query) hidden_size = query.size(-1) @@ -161,6 +177,14 @@ class Attention(nn.Module): return torch.ops.vllm.unified_attention( query, key, value, self.layer_name) + def calc_kv_scales(self, key, value): + self._k_scale.copy_(torch.abs(key).max() / self.k_range) + self._v_scale.copy_(torch.abs(value).max() / self.v_range) + self._k_scale_float = self._k_scale.item() + self._v_scale_float = self._v_scale.item() + # We only calculate the scales once + self.calculate_kv_scales = False + def extra_repr(self) -> str: s = f"head_size={self.impl.head_size}" # type: ignore s += f", num_heads={self.impl.num_heads}" # type: ignore diff --git a/vllm/attention/ops/ipex_attn.py b/vllm/attention/ops/ipex_attn.py index cbc6c74a..3a07184e 100644 --- a/vllm/attention/ops/ipex_attn.py +++ b/vllm/attention/ops/ipex_attn.py @@ -52,8 +52,8 @@ class _PagedAttention: value_cache: torch.Tensor, slot_mapping: torch.Tensor, kv_cache_dtype: str, - k_scale: float, - v_scale: float, + k_scale: torch.Tensor, + v_scale: torch.Tensor, *args, ) -> None: ops.reshape_and_cache( @@ -80,8 +80,8 @@ class _PagedAttention: num_kv_heads: int, scale: float, alibi_slopes: Optional[torch.Tensor], - k_scale: float, - v_scale: float, + k_scale: torch.Tensor, + v_scale: torch.Tensor, *args, ) -> None: tp_rank: int = 0 @@ -149,8 +149,8 @@ class _IPEXPagedAttention(_PagedAttention): value_cache: torch.Tensor, slot_mapping: torch.Tensor, kv_cache_dtype: str, - k_scale: float, - v_scale: float, + k_scale: torch.Tensor, + v_scale: torch.Tensor, *args, ) -> None: ipex_modules.PagedAttention.reshape_and_cache( @@ -170,8 +170,8 @@ class _IPEXPagedAttention(_PagedAttention): num_kv_heads: int, scale: float, alibi_slopes: Optional[torch.Tensor], - k_scale: float, - v_scale: float, + k_scale: torch.Tensor, + v_scale: torch.Tensor, *args, ) -> None: block_size = value_cache.shape[2] diff --git a/vllm/attention/ops/paged_attn.py b/vllm/attention/ops/paged_attn.py index 076f151f..fd623291 100644 --- a/vllm/attention/ops/paged_attn.py +++ b/vllm/attention/ops/paged_attn.py @@ -69,8 +69,8 @@ class PagedAttention: value_cache: torch.Tensor, slot_mapping: torch.Tensor, kv_cache_dtype: str, - k_scale: float, - v_scale: float, + k_scale: torch.Tensor, + v_scale: torch.Tensor, ) -> None: ops.reshape_and_cache( key, @@ -95,8 +95,8 @@ class PagedAttention: num_kv_heads: int, scale: float, alibi_slopes: Optional[torch.Tensor], - k_scale: float, - v_scale: float, + k_scale: torch.Tensor, + v_scale: torch.Tensor, tp_rank: int = 0, blocksparse_local_blocks: int = 0, blocksparse_vert_stride: int = 0, @@ -204,8 +204,8 @@ class PagedAttention: max_query_len: int, alibi_slopes: Optional[torch.Tensor], sliding_window: Optional[int], - k_scale: float, - v_scale: float, + k_scale: torch.Tensor, + v_scale: torch.Tensor, ) -> torch.Tensor: output = torch.empty_like(query) context_attention_fwd( diff --git a/vllm/attention/ops/prefix_prefill.py b/vllm/attention/ops/prefix_prefill.py index 9c11a8df..e2f2b66d 100644 --- a/vllm/attention/ops/prefix_prefill.py +++ b/vllm/attention/ops/prefix_prefill.py @@ -133,7 +133,7 @@ if triton.__version__ >= "2.1.0": other=0.0) # [D,N] if k_load.dtype.is_fp8(): - k = (k_load.to(tl.float32) * k_scale).to(q.dtype) + k = (k_load.to(tl.float32) * tl.load(k_scale)).to(q.dtype) else: k = k_load @@ -181,7 +181,7 @@ if triton.__version__ >= "2.1.0": ((start_n + offs_n[:, None]) < cur_batch_ctx_len), other=0.0) # [N,D] if v_load.dtype.is_fp8(): - v = (v_load.to(tl.float32) * v_scale).to(q.dtype) + v = (v_load.to(tl.float32) * tl.load(v_scale)).to(q.dtype) else: v = v_load p = p.to(v.dtype) @@ -564,7 +564,7 @@ if triton.__version__ >= "2.1.0": other=0.0) # [D,N] if k_load.dtype.is_fp8(): - k = (k_load.to(tl.float32) * k_scale).to(q.dtype) + k = (k_load.to(tl.float32) * tl.load(k_scale)).to(q.dtype) else: k = k_load @@ -604,7 +604,7 @@ if triton.__version__ >= "2.1.0": ((start_n + offs_n[:, None]) < cur_batch_ctx_len), other=0.0) if v_load.dtype.is_fp8(): - v = (v_load.to(tl.float32) * v_scale).to(q.dtype) + v = (v_load.to(tl.float32) * tl.load(v_scale)).to(q.dtype) else: v = v_load p = p.to(v.dtype) @@ -713,8 +713,8 @@ if triton.__version__ >= "2.1.0": b_seq_len, b_ctx_len, max_input_len, - k_scale: float = 1.0, - v_scale: float = 1.0, + k_scale: torch.Tensor, + v_scale: torch.Tensor, alibi_slopes=None, sliding_window=None): diff --git a/vllm/config.py b/vllm/config.py index f7547921..efd81ad3 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -120,11 +120,6 @@ class ModelConfig: decoding draft models. quantization: Quantization method that was used to quantize the model weights. If None, we assume the model weights are not quantized. - quantization_param_path: Path to JSON file containing scaling factors. - Used to load KV cache scaling factors into the model when KV cache - type is FP8_E4M3 on ROCm (AMD GPU). In the future these will also - be used to load activation and weight scaling factors when the - model dtype is FP8_E4M3 on ROCm. enforce_eager: Whether to enforce eager execution. If True, we will disable CUDA graph and always execute the model in eager mode. If False, we will use CUDA graph and eager execution in hybrid. @@ -187,7 +182,6 @@ class ModelConfig: factors.append(self.model) factors.append(self.dtype) factors.append(self.quantization) - factors.append(self.quantization_param_path) factors.append(self.revision) factors.append(self.code_revision) factors.append(self.trust_remote_code) @@ -213,7 +207,6 @@ class ModelConfig: max_model_len: Optional[int] = None, spec_target_max_model_len: Optional[int] = None, quantization: Optional[str] = None, - quantization_param_path: Optional[str] = None, enforce_eager: Optional[bool] = None, max_seq_len_to_capture: Optional[int] = None, max_logprobs: int = 20, @@ -274,7 +267,6 @@ class ModelConfig: else: self.tokenizer_revision = tokenizer_revision self.quantization = quantization - self.quantization_param_path = quantization_param_path self.enforce_eager = enforce_eager self.max_seq_len_to_capture = max_seq_len_to_capture self.max_logprobs = max_logprobs @@ -1002,6 +994,7 @@ class CacheConfig: sliding_window: Optional[int] = None, enable_prefix_caching: bool = False, cpu_offload_gb: float = 0, + calculate_kv_scales: Optional[bool] = None, ) -> None: self.block_size = block_size self.gpu_memory_utilization = gpu_memory_utilization @@ -1012,7 +1005,7 @@ class CacheConfig: self.sliding_window = sliding_window self.enable_prefix_caching = enable_prefix_caching self.cpu_offload_gb = cpu_offload_gb - + self.calculate_kv_scales = calculate_kv_scales self._verify_args() self._verify_cache_dtype() self._verify_prefix_caching() @@ -1021,6 +1014,10 @@ class CacheConfig: self.num_gpu_blocks: Optional[int] = None self.num_cpu_blocks: Optional[int] = None + # Set calculate_kv_scales to False if the value is unset. + if self.calculate_kv_scales is None: + self.calculate_kv_scales = False + def metrics_info(self): # convert cache_config to dict(key: str, value: str) for prometheus # metrics info @@ -3297,7 +3294,6 @@ class VllmConfig: f"quantization={self.model_config.quantization}, " f"enforce_eager={self.model_config.enforce_eager}, " f"kv_cache_dtype={self.cache_config.cache_dtype}, " - f"quantization_param_path={self.model_config.quantization_param_path}," f" device_config={self.device_config.device}, " f"decoding_config={self.decoding_config!r}, " f"observability_config={self.observability_config!r}, " diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index f58c1b55..5d3aeb68 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -98,7 +98,6 @@ class EngineArgs: config_format: ConfigFormat = ConfigFormat.AUTO dtype: str = 'auto' kv_cache_dtype: str = 'auto' - quantization_param_path: Optional[str] = None seed: int = 0 max_model_len: Optional[int] = None worker_use_ray: bool = False @@ -199,6 +198,8 @@ class EngineArgs: generation_config: Optional[str] = None enable_sleep_mode: bool = False + calculate_kv_scales: Optional[bool] = None + def __post_init__(self): if not self.tokenizer: self.tokenizer = self.model @@ -350,17 +351,6 @@ class EngineArgs: help='Data type for kv cache storage. If "auto", will use model ' 'data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. ' 'ROCm (AMD GPU) supports fp8 (=fp8_e4m3)') - parser.add_argument( - '--quantization-param-path', - type=nullable_str, - default=None, - help='Path to the JSON file containing the KV cache ' - 'scaling factors. This should generally be supplied, when ' - 'KV cache dtype is FP8. Otherwise, KV cache scaling factors ' - 'default to 1.0, which may cause accuracy issues. ' - 'FP8_E5M2 (without scaling) is only supported on cuda version ' - 'greater than 11.8. On ROCm (AMD GPU), FP8_E4M3 is instead ' - 'supported for common inference criteria.') parser.add_argument('--max-model-len', type=int, default=EngineArgs.max_model_len, @@ -962,6 +952,15 @@ class EngineArgs: help="Enable sleep mode for the engine. " "(only cuda platform is supported)") + parser.add_argument( + '--calculate-kv-scales', + action='store_true', + help='This enables dynamic calculation of ' + 'k_scale and v_scale when kv-cache-dtype is fp8. ' + 'If calculate-kv-scales is false, the scales will ' + 'be loaded from the model checkpoint if available. ' + 'Otherwise, the scales will default to 1.0.') + return parser @classmethod @@ -991,7 +990,6 @@ class EngineArgs: tokenizer_revision=self.tokenizer_revision, max_model_len=self.max_model_len, quantization=self.quantization, - quantization_param_path=self.quantization_param_path, enforce_eager=self.enforce_eager, max_seq_len_to_capture=self.max_seq_len_to_capture, max_logprobs=self.max_logprobs, @@ -1068,6 +1066,7 @@ class EngineArgs: sliding_window=model_config.get_sliding_window(), enable_prefix_caching=self.enable_prefix_caching, cpu_offload_gb=self.cpu_offload_gb, + calculate_kv_scales=self.calculate_kv_scales, ) parallel_config = ParallelConfig( pipeline_parallel_size=self.pipeline_parallel_size, diff --git a/vllm/envs.py b/vllm/envs.py index b72e9141..8627caec 100644 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -73,6 +73,8 @@ if TYPE_CHECKING: VLLM_ENABLE_V1_MULTIPROCESSING: bool = True VLLM_LOG_BATCHSIZE_INTERVAL: float = -1 VLLM_DISABLE_COMPILE_CACHE: bool = False + K_SCALE_CONSTANT: int = 200 + V_SCALE_CONSTANT: int = 100 VLLM_SERVER_DEV_MODE: bool = False VLLM_V1_OUTPUT_PROC_CHUNK_SIZE: int = 128 @@ -474,6 +476,13 @@ environment_variables: Dict[str, Callable[[], Any]] = { "VLLM_USE_V1": lambda: bool(int(os.getenv("VLLM_USE_V1", "0"))), + # Divisor for dynamic key scale factor calculation for FP8 KV Cache + "K_SCALE_CONSTANT": + lambda: int(os.getenv("K_SCALE_CONSTANT", "200")), + + # Divisor for dynamic value scale factor calculation for FP8 KV Cache + "V_SCALE_CONSTANT": + lambda: int(os.getenv("V_SCALE_CONSTANT", "100")), # If set, enable multiprocessing in LLM for the V1 code path. "VLLM_ENABLE_V1_MULTIPROCESSING": lambda: bool(int(os.getenv("VLLM_ENABLE_V1_MULTIPROCESSING", "1"))), diff --git a/vllm/model_executor/layers/quantization/kv_cache.py b/vllm/model_executor/layers/quantization/kv_cache.py index a74f5415..e1870c73 100644 --- a/vllm/model_executor/layers/quantization/kv_cache.py +++ b/vllm/model_executor/layers/quantization/kv_cache.py @@ -3,6 +3,7 @@ import torch from vllm.logger import init_logger from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig, QuantizeMethodBase) +from vllm.platforms import current_platform logger = init_logger(__name__) @@ -40,11 +41,16 @@ class BaseKVCacheMethod(QuantizeMethodBase): def process_weights_after_loading(self, layer: torch.nn.Module) -> None: # If the kv-cache dtype is auto, we enforce the k/v_scale to be 1.0 # regardless whether the kv-scale is available in the checkpoint. - if layer.kv_cache_dtype != "auto": + # No need to process kv scales after loading if we are going to + # calculate them on the fly. + if layer.kv_cache_dtype != "auto" and not layer.calculate_kv_scales: if layer.k_scale > 0.0 and layer.v_scale > 0.0: # We prefer to use separate k_scale and v_scale if present k_scale = layer.k_scale.to("cpu").tolist() v_scale = layer.v_scale.to("cpu").tolist() + if current_platform.is_rocm(): + k_scale *= 2 + v_scale *= 2 elif layer.k_scale < 0.0 and layer.v_scale < 0.0: # If no scales were loaded (both scales are invalid negative # values), use the default value of 1.0 @@ -58,6 +64,9 @@ class BaseKVCacheMethod(QuantizeMethodBase): scale_to_duplicate = max(layer.k_scale, layer.v_scale) k_scale = scale_to_duplicate.to("cpu").tolist() v_scale = scale_to_duplicate.to("cpu").tolist() + if current_platform.is_rocm(): + k_scale *= 2 + v_scale *= 2 if not isinstance(k_scale, float) or not isinstance( v_scale, float): @@ -65,9 +74,11 @@ class BaseKVCacheMethod(QuantizeMethodBase): "for fp8 KV cache") # These are used in the final Attention.forward() - layer._k_scale = k_scale - layer._v_scale = v_scale - if (layer._k_scale == 1.0 and layer._v_scale == 1.0 + layer._k_scale.copy_(k_scale) + layer._v_scale.copy_(v_scale) + layer._k_scale_float = k_scale + layer._v_scale_float = v_scale + if (k_scale == 1.0 and v_scale == 1.0 and "e5m2" not in layer.kv_cache_dtype): logger.warning_once( "Using KV cache scaling factor 1.0 for fp8_e4m3. This " diff --git a/vllm/model_executor/model_loader/weight_utils.py b/vllm/model_executor/model_loader/weight_utils.py index 9cfcdbf6..b7040722 100644 --- a/vllm/model_executor/model_loader/weight_utils.py +++ b/vllm/model_executor/model_loader/weight_utils.py @@ -6,8 +6,7 @@ import json import os import tempfile from collections import defaultdict -from typing import (Any, Callable, Dict, Generator, Iterable, List, Optional, - Tuple, Union) +from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Union import filelock import gguf @@ -23,7 +22,6 @@ from vllm.distributed import get_tensor_model_parallel_rank from vllm.logger import init_logger from vllm.model_executor.layers.quantization import (QuantizationConfig, get_quantization_config) -from vllm.model_executor.layers.quantization.schema import QuantParamSchema from vllm.platforms import current_platform from vllm.utils import PlaceholderModule @@ -496,47 +494,6 @@ def gguf_quant_weights_iterator( yield name, param -def kv_cache_scales_loader( - filename: str, tp_rank: int, tp_size: int, num_hidden_layers: int, - model_type: Optional[str]) -> Iterable[Tuple[int, float]]: - """ - A simple utility to read in KV cache scaling factors that have been - previously serialized to disk. Used by the model to populate the appropriate - KV cache scaling factors. The serialization should represent a dictionary - whose keys are the TP ranks and values are another dictionary mapping layers - to their KV cache scaling factors. - Keep this function in sync with the output of - examples/other/fp8/extract_scales.py - """ - try: - with open(filename) as f: - context = { - "model_type": model_type, - "num_hidden_layers": num_hidden_layers, - "tp_rank": tp_rank, - "tp_size": tp_size, - } - schema_dct = json.load(f) - schema = QuantParamSchema.model_validate(schema_dct, - context=context) - layer_scales_map = schema.kv_cache.scaling_factor[tp_rank] - return layer_scales_map.items() - - except FileNotFoundError: - logger.error("File or directory '%s' not found.", filename) - except json.JSONDecodeError: - logger.error("Error decoding JSON in file '%s'.", filename) - except Exception: - logger.exception("An error occurred while reading '%s'.", filename) - # This section is reached if and only if any of the excepts are hit - # Return an empty iterable (list) => no KV cache scales are loaded - # which ultimately defaults to 1.0 scales - logger.warning( - "Defaulting to KV cache scaling factors = 1.0 for all " - "layers in TP rank %d as an error occurred during loading.", tp_rank) - return [] - - def convert_pyslice_to_tensor(x: Any) -> torch.Tensor: """convert PySafeSlice object from safetensors to torch.Tensor diff --git a/vllm/model_executor/models/exaone.py b/vllm/model_executor/models/exaone.py index eab3bf07..bc3295da 100644 --- a/vllm/model_executor/models/exaone.py +++ b/vllm/model_executor/models/exaone.py @@ -30,8 +30,7 @@ from torch import nn from vllm.attention import Attention, AttentionMetadata from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, VllmConfig -from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, - get_tensor_model_parallel_world_size) +from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, @@ -44,9 +43,8 @@ from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import ( - default_weight_loader, kv_cache_scales_loader, maybe_remap_kv_scale_name) + default_weight_loader, maybe_remap_kv_scale_name) from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.platforms import current_platform from vllm.sequence import IntermediateTensors from vllm.transformers_utils.configs.exaone import ExaoneConfig @@ -576,32 +574,3 @@ class ExaoneForCausalLM(nn.Module, SupportsLoRA, SupportsPP): weight_loader(param, loaded_weight) loaded_params.add(name) return loaded_params - - # If this function is called, it should always initialize KV cache scale - # factors (or else raise an exception). Thus, handled exceptions should - # make sure to leave KV cache scale factors in a known good (dummy) state - def load_kv_cache_scales(self, quantization_param_path: str) -> None: - tp_size = get_tensor_model_parallel_world_size() - tp_rank = get_tensor_model_parallel_rank() - for layer_idx, scaling_factor in kv_cache_scales_loader( - quantization_param_path, - tp_rank, - tp_size, - self.config.num_hidden_layers, - self.config.__class__.model_type, - ): - if not isinstance(self.transformer.h[layer_idx], nn.Identity): - layer_self_attn = self.transformer.h[layer_idx].attn - - if current_platform.is_rocm(): - # The scaling factor convention we are assuming is - # quantized_value * scaling_factor ~= true_value - # which is consistent with the practice of setting - # scaling_factor = tensor_amax / FPtype_max - scaling_factor *= 2 - if hasattr(layer_self_attn.attn, "_k_scale"): - layer_self_attn.attn._k_scale = scaling_factor - layer_self_attn.attn._v_scale = scaling_factor - else: - raise RuntimeError("Self attention has no KV cache scaling " - "factor attribute!") diff --git a/vllm/model_executor/models/granite.py b/vllm/model_executor/models/granite.py index ddd2d7a1..543b4e2f 100644 --- a/vllm/model_executor/models/granite.py +++ b/vllm/model_executor/models/granite.py @@ -29,8 +29,7 @@ from transformers import GraniteConfig from vllm.attention import Attention, AttentionMetadata from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, VllmConfig -from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, - get_tensor_model_parallel_world_size) +from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, @@ -44,9 +43,8 @@ from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import ( - default_weight_loader, kv_cache_scales_loader, maybe_remap_kv_scale_name) + default_weight_loader, maybe_remap_kv_scale_name) from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.platforms import current_platform from vllm.sequence import IntermediateTensors from .interfaces import SupportsLoRA, SupportsPP @@ -518,29 +516,3 @@ class GraniteForCausalLM(nn.Module, SupportsLoRA, SupportsPP): weight_loader(param, loaded_weight) loaded_params.add(name) return loaded_params - - # If this function is called, it should always initialize KV cache scale - # factors (or else raise an exception). Thus, handled exceptions should - # make sure to leave KV cache scale factors in a known good (dummy) state - def load_kv_cache_scales(self, quantization_param_path: str) -> None: - tp_size = get_tensor_model_parallel_world_size() - tp_rank = get_tensor_model_parallel_rank() - for layer_idx, scaling_factor in kv_cache_scales_loader( - quantization_param_path, tp_rank, tp_size, - self.config.num_hidden_layers, - self.config.__class__.model_type): - if not isinstance(self.model.layers[layer_idx], nn.Identity): - layer_self_attn = self.model.layers[layer_idx].self_attn - - if current_platform.is_rocm(): - # The scaling factor convention we are assuming is - # quantized_value * scaling_factor ~= true_value - # which is consistent with the practice of setting - # scaling_factor = tensor_amax / FPtype_max - scaling_factor *= 2 - if hasattr(layer_self_attn.attn, "_k_scale"): - layer_self_attn.attn._k_scale = scaling_factor - layer_self_attn.attn._v_scale = scaling_factor - else: - raise RuntimeError("Self attention has no KV cache scaling " - "factor attribute!") diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index a5bd4188..e214c30f 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -29,8 +29,7 @@ from transformers import LlamaConfig from vllm.attention import Attention, AttentionMetadata from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, VllmConfig -from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, - get_tensor_model_parallel_world_size) +from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, @@ -43,9 +42,8 @@ from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import ( - default_weight_loader, kv_cache_scales_loader, maybe_remap_kv_scale_name) + default_weight_loader, maybe_remap_kv_scale_name) from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.platforms import current_platform from vllm.sequence import IntermediateTensors from .interfaces import SupportsLoRA, SupportsPP @@ -440,32 +438,6 @@ class LlamaModel(nn.Module): loaded_params.add(name) return loaded_params - # If this function is called, it should always initialize KV cache scale - # factors (or else raise an exception). Thus, handled exceptions should - # make sure to leave KV cache scale factors in a known good (dummy) state - def load_kv_cache_scales(self, quantization_param_path: str) -> None: - tp_size = get_tensor_model_parallel_world_size() - tp_rank = get_tensor_model_parallel_rank() - for layer_idx, scaling_factor in kv_cache_scales_loader( - quantization_param_path, tp_rank, tp_size, - self.config.num_hidden_layers, - self.config.__class__.model_type): - if not isinstance(self.layers[layer_idx], nn.Identity): - layer_self_attn = self.layers[layer_idx].self_attn - - if current_platform.is_rocm(): - # The scaling factor convention we are assuming is - # quantized_value * scaling_factor ~= true_value - # which is consistent with the practice of setting - # scaling_factor = tensor_amax / FPtype_max - scaling_factor *= 2 - if hasattr(layer_self_attn.attn, "_k_scale"): - layer_self_attn.attn._k_scale = scaling_factor - layer_self_attn.attn._v_scale = scaling_factor - else: - raise RuntimeError("Self attention has no KV cache scaling " - "factor attribute!") - class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP): packed_modules_mapping = { @@ -593,9 +565,6 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP): self.maybe_remap_mistral(name, loaded_weight) for name, loaded_weight in weights) - def load_kv_cache_scales(self, quantization_param_path: str) -> None: - self.model.load_kv_cache_scales(quantization_param_path) - # This function is used to remap the mistral format as # used by Mistral and Llama <=2 def maybe_remap_mistral( diff --git a/vllm/model_executor/models/mllama.py b/vllm/model_executor/models/mllama.py index 25542816..61baa8e5 100644 --- a/vllm/model_executor/models/mllama.py +++ b/vllm/model_executor/models/mllama.py @@ -831,6 +831,7 @@ class MllamaTextCrossAttention(nn.Module): ) -> torch.Tensor: # Skip writing kv-cache for the initial profiling run. if len(kv_cache.shape) > 1: + i = torch.ones(1, dtype=torch.float32) if self.attn.backend in (_Backend.FLASH_ATTN, _Backend.FLASH_ATTN_VLLM_V1): cached_k = torch.cat([k[s:e] for s, e in kv_range_for_decode]) @@ -843,8 +844,8 @@ class MllamaTextCrossAttention(nn.Module): attn_metadata. cross_slot_mapping, # type: ignore[union-attr] "auto", - 1.0, - 1.0, + i, + i, ) elif self.attn.backend in (_Backend.XFORMERS, _Backend.TORCH_SDPA): key_cache, value_cache = PagedAttention.split_kv_cache( @@ -853,7 +854,7 @@ class MllamaTextCrossAttention(nn.Module): cached_v = torch.cat([v[s:e] for s, e in kv_range_for_decode]) PagedAttention.write_to_paged_cache( cached_k, cached_v, key_cache, value_cache, - attn_metadata.cross_slot_mapping, "auto", 1.0, 1.0) + attn_metadata.cross_slot_mapping, "auto", i, i) else: raise ValueError( f"Unsupported Attention backend {self.attn.backend} " diff --git a/vllm/model_executor/models/solar.py b/vllm/model_executor/models/solar.py index 37c5a4b5..e6d919f2 100644 --- a/vllm/model_executor/models/solar.py +++ b/vllm/model_executor/models/solar.py @@ -30,8 +30,7 @@ from transformers import PretrainedConfig from vllm.attention import Attention, AttentionMetadata from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, VllmConfig -from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, - get_tensor_model_parallel_world_size) +from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, @@ -44,9 +43,8 @@ from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import ( - default_weight_loader, kv_cache_scales_loader, maybe_remap_kv_scale_name) + default_weight_loader, maybe_remap_kv_scale_name) from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.platforms import current_platform from vllm.sequence import IntermediateTensors from .interfaces import SupportsLoRA, SupportsPP @@ -535,32 +533,3 @@ class SolarForCausalLM(nn.Module, SupportsLoRA, SupportsPP): weight_loader(param, loaded_weight) loaded_params.add(name) return loaded_params - - # If this function is called, it should always initialize KV cache scale - # factors (or else raise an exception). Thus, handled exceptions should - # make sure to leave KV cache scale factors in a known good (dummy) state - def load_kv_cache_scales(self, quantization_param_path: str) -> None: - tp_size = get_tensor_model_parallel_world_size() - tp_rank = get_tensor_model_parallel_rank() - for layer_idx, scaling_factor in kv_cache_scales_loader( - quantization_param_path, - tp_rank, - tp_size, - self.config.num_hidden_layers, - self.config.__class__.model_type, - ): - if not isinstance(self.model.layers[layer_idx], nn.Identity): - layer_self_attn = self.model.layers[layer_idx].self_attn - - if current_platform.is_rocm(): - # The scaling factor convention we are assuming is - # quantized_value * scaling_factor ~= true_value - # which is consistent with the practice of setting - # scaling_factor = tensor_amax / FPtype_max - scaling_factor *= 2 - if hasattr(layer_self_attn.attn, "_k_scale"): - layer_self_attn.attn._k_scale = scaling_factor - layer_self_attn.attn._v_scale = scaling_factor - else: - raise RuntimeError("Self attention has no KV cache scaling " - "factor attribute!") diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py index 1806fec8..7fe9b3a8 100644 --- a/vllm/v1/attention/backends/flash_attn.py +++ b/vllm/v1/attention/backends/flash_attn.py @@ -166,10 +166,6 @@ class FlashAttentionImpl(AttentionImpl): Returns: shape = [num_tokens, num_heads * head_size] """ - # NOTE(woosuk): FlashAttention does not support FP8 KV cache. - assert layer._k_scale == 1.0 and layer._v_scale == 1.0, ( - "key/v_scale is not supported in FlashAttention.") - assert output is not None, "Output tensor must be provided." if attn_metadata is None: diff --git a/vllm/worker/hpu_model_runner.py b/vllm/worker/hpu_model_runner.py index 4c8f69e4..a339c97a 100644 --- a/vllm/worker/hpu_model_runner.py +++ b/vllm/worker/hpu_model_runner.py @@ -903,7 +903,8 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]): num_decode_tokens=0, slot_mapping=slot_mapping, multi_modal_placeholder_index_maps= - None # FIXME(kzawora): mutli-modality will not work here + None, # FIXME(kzawora): mutli-modality will not work here + enable_kv_scales_calculation=False, ) multi_modal_kwargs = MultiModalKwargs.batch(multi_modal_kwargs_list) @@ -1057,7 +1058,9 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]): num_prefill_tokens=0, num_decode_tokens=num_decode_tokens, slot_mapping=slot_mapping, - multi_modal_placeholder_index_maps=None) + multi_modal_placeholder_index_maps=None, + enable_kv_scales_calculation=False, + ) return PrepareDecodeMetadata(input_tokens=input_tokens, input_positions=input_positions, attn_metadata=attn_metadata, diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index fe504821..cf2f1c6b 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -3,7 +3,6 @@ import gc import inspect import itertools import time -import warnings import weakref from contextlib import contextmanager from dataclasses import dataclass @@ -41,7 +40,6 @@ from vllm.model_executor.models.utils import set_cpu_offload_max_bytes from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs, MultiModalKwargs, MultiModalPlaceholderMap, MultiModalRegistry) -from vllm.platforms import current_platform from vllm.prompt_adapter.layers import PromptAdapterMapping from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.prompt_adapter.worker_manager import ( @@ -1151,34 +1149,6 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]): self.prompt_adapter_manager.create_prompt_adapter_manager( self.model)) - if self.kv_cache_dtype == "fp8" and (current_platform.is_rocm() - or current_platform.is_cuda()): - # Currently only ROCm accepts kv-cache scaling factors - # via quantization_param_path and this will be deprecated - # in the future. - if self.model_config.quantization_param_path is not None: - if callable(getattr(self.model, "load_kv_cache_scales", None)): - warnings.warn( - "Loading kv cache scaling factor from JSON is " - "deprecated and will be removed. Please include " - "kv cache scaling factors in the model checkpoint.", - FutureWarning, - stacklevel=2) - self.model.load_kv_cache_scales( - self.model_config.quantization_param_path) - logger.info("Loaded KV cache scaling factors from %s", - self.model_config.quantization_param_path) - else: - raise RuntimeError( - "Using FP8 KV cache and scaling factors provided but " - "model %s does not support loading scaling factors.", - self.model.__class__) - else: - logger.warning( - "Using FP8 KV cache but no scaling factors " - "provided. Defaulting to scaling factors of 1.0. " - "This may lead to less accurate results!") - if self.vllm_config.compilation_config.level ==\ CompilationLevel.DYNAMO_AS_IS and supports_dynamo(): backend = self.vllm_config.compilation_config.init_backend( @@ -1366,6 +1336,10 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]): dtype=self.model_config.dtype, device=self.device) + # Disable KV Scale Calculation for dummy data during profile run + if model_input.attn_metadata is not None: + model_input.attn_metadata.enable_kv_scales_calculation = False + self.execute_model(model_input, kv_caches, intermediate_tensors) torch.cuda.synchronize() return @@ -1510,7 +1484,8 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]): batch_size, is_encoder_decoder_model=self.model_config. is_encoder_decoder)) - + # Disable KV Scale Calculation for graph capture + attn_metadata.enable_kv_scales_calculation = False if self.lora_config: lora_mapping = LoRAMapping( **dict(index_mapping=[0] * batch_size, diff --git a/vllm/worker/openvino_model_runner.py b/vllm/worker/openvino_model_runner.py index 9d0a759c..42fe2cf6 100644 --- a/vllm/worker/openvino_model_runner.py +++ b/vllm/worker/openvino_model_runner.py @@ -282,6 +282,7 @@ class OpenVINOModelRunner(ModelRunnerBase): block_indices_begins=block_indices_begins_tensor, max_context_len=max_context_len_tensor, multi_modal_placeholder_index_maps=placeholder_index_maps, + enable_kv_scales_calculation=False, ) multi_modal_kwargs = MultiModalKwargs.batch(multi_modal_kwargs_list) diff --git a/vllm/worker/tpu_model_runner.py b/vllm/worker/tpu_model_runner.py index f5c7bc95..a3f648f4 100644 --- a/vllm/worker/tpu_model_runner.py +++ b/vllm/worker/tpu_model_runner.py @@ -190,6 +190,7 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]): num_decode_tokens=0, slot_mapping=slot_mapping, multi_modal_placeholder_index_maps=None, + enable_kv_scales_calculation=False, block_tables=None, context_lens=None, effective_query_lens=None, @@ -208,6 +209,7 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]): num_decode_tokens=0, slot_mapping=slot_mapping, multi_modal_placeholder_index_maps=None, + enable_kv_scales_calculation=False, block_tables=block_tables, context_lens=context_lens, effective_query_lens=effective_query_lens, @@ -239,6 +241,7 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]): num_decode_tokens=batch_size * seq_len, slot_mapping=slot_mapping, multi_modal_placeholder_index_maps=None, + enable_kv_scales_calculation=False, block_tables=block_tables, context_lens=context_lens, ) @@ -425,6 +428,7 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]): num_decode_tokens=0, slot_mapping=slot_mapping, multi_modal_placeholder_index_maps=None, + enable_kv_scales_calculation=False, block_tables=block_tables, context_lens=context_lens, effective_query_lens=prompt_lens, @@ -496,6 +500,7 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]): num_decode_tokens=batch_size, slot_mapping=slot_mapping, multi_modal_placeholder_index_maps=None, + enable_kv_scales_calculation=False, block_tables=block_tables, context_lens=context_lens, ) diff --git a/vllm/worker/xpu_model_runner.py b/vllm/worker/xpu_model_runner.py index 053658d0..b7b7b722 100644 --- a/vllm/worker/xpu_model_runner.py +++ b/vllm/worker/xpu_model_runner.py @@ -261,6 +261,7 @@ class ModelInputForXPUBuilder(ModelRunnerInputBuilderBase[ModelInputForXPU]): is_prompt=True, slot_mapping=slot_mapping, multi_modal_placeholder_index_maps=placeholder_index_maps, + enable_kv_scales_calculation=False, seq_lens=seq_lens, seqlen_q=seqlen_q, max_seqlen=max_seqlen, @@ -345,6 +346,7 @@ class ModelInputForXPUBuilder(ModelRunnerInputBuilderBase[ModelInputForXPU]): is_prompt=False, slot_mapping=slot_mapping, multi_modal_placeholder_index_maps=None, + enable_kv_scales_calculation=False, seq_lens=seq_lens, seqlen_q=torch.tensor([]), max_seqlen=0,