[FP8][Kernel] Dynamic kv cache scaling factors computation (#11906)

Signed-off-by: Gregory Shtrasberg <Gregory.Shtrasberg@amd.com>
Co-authored-by: Micah Williamson <micah.williamson@amd.com>
Gregory Shtrasberg 2025-01-23 13:04:03 -05:00 committed by GitHub
parent 6e650f56a1
commit e97f802b2d
60 changed files with 276 additions and 1365 deletions
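
The core interface change in this commit: `k_scale`/`v_scale` are no longer host-side floats/doubles but 1-element float32 tensors on the same device as the KV cache, which the kernels read through a `const float*`. A minimal sketch of the new convention (shapes are illustrative, and the amax/256 divisor mirrors the updated test rather than any specification):

```python
import torch

# Hedged sketch, not vLLM's exact code: scales are now tiny float32 tensors
# resident on the KV-cache device, so dynamically computed scales never need
# a round-trip to the host.
device = "cuda" if torch.cuda.is_available() else "cpu"
key = torch.randn(16, 8, 128, dtype=torch.float16, device=device)  # [num_tokens, num_heads, head_size]
value = torch.randn_like(key)

# Default (unit) scales, as in the updated benchmarks and tests:
k_scale = v_scale = torch.tensor(1.0, dtype=torch.float32, device=device)

# Dynamically computed scales, following the amax-based pattern from the
# updated reshape_and_cache_flash test (the divisor is the test's choice):
k_scale = (key.amax() / 256.0).to(torch.float32)
v_scale = (value.amax() / 256.0).to(torch.float32)
```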

View File

@ -98,7 +98,9 @@ def main(
start_time = time.perf_counter()
# Using default kv_scale
k_scale = v_scale = 1.0
k_scale = v_scale = torch.tensor(1.0,
dtype=torch.float32,
device=device)
for _ in range(num_iters):
if version == "v1":

View File

@ -105,7 +105,7 @@ __device__ void paged_attention_kernel(
const int max_num_blocks_per_seq,
const float* __restrict__ alibi_slopes, // [num_heads]
const int q_stride, const int kv_block_stride, const int kv_head_stride,
const float k_scale, const float v_scale, const int tp_rank,
const float* k_scale, const float* v_scale, const int tp_rank,
const int blocksparse_local_blocks, const int blocksparse_vert_stride,
const int blocksparse_block_size, const int blocksparse_head_sliding_step) {
const int seq_idx = blockIdx.y;
@ -285,7 +285,7 @@ __device__ void paged_attention_kernel(
Quant_vec k_vec_quant = *reinterpret_cast<const Quant_vec*>(
k_ptr + offset1 * BLOCK_SIZE * x + offset2);
k_vecs[j] = fp8::scaled_convert<K_vec, Quant_vec, KV_DTYPE>(
k_vec_quant, k_scale);
k_vec_quant, *k_scale);
}
}
@ -415,7 +415,7 @@ __device__ void paged_attention_kernel(
*reinterpret_cast<const V_quant_vec*>(v_ptr + offset);
// Vector conversion from V_quant_vec to V_vec.
v_vec = fp8::scaled_convert<V_vec, V_quant_vec, KV_DTYPE>(v_quant_vec,
v_scale);
*v_scale);
}
if (block_idx == num_seq_blocks - 1) {
// NOTE(woosuk): When v_vec contains the tokens that are out of the
@ -513,7 +513,7 @@ __global__ void paged_attention_v1_kernel(
const int max_num_blocks_per_seq,
const float* __restrict__ alibi_slopes, // [num_heads]
const int q_stride, const int kv_block_stride, const int kv_head_stride,
const float k_scale, const float v_scale, const int tp_rank,
const float* k_scale, const float* v_scale, const int tp_rank,
const int blocksparse_local_blocks, const int blocksparse_vert_stride,
const int blocksparse_block_size, const int blocksparse_head_sliding_step) {
paged_attention_kernel<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS,
@ -549,7 +549,7 @@ __global__ void paged_attention_v2_kernel(
const int max_num_blocks_per_seq,
const float* __restrict__ alibi_slopes, // [num_heads]
const int q_stride, const int kv_block_stride, const int kv_head_stride,
const float k_scale, const float v_scale, const int tp_rank,
const float* k_scale, const float* v_scale, const int tp_rank,
const int blocksparse_local_blocks, const int blocksparse_vert_stride,
const int blocksparse_block_size, const int blocksparse_head_sliding_step) {
paged_attention_kernel<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS,

View File

@ -41,7 +41,7 @@
out_ptr, query_ptr, key_cache_ptr, value_cache_ptr, num_kv_heads, \
scale, block_tables_ptr, seq_lens_ptr, max_num_blocks_per_seq, \
alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride, \
k_scale, v_scale, tp_rank, blocksparse_local_blocks, \
k_scale_ptr, v_scale_ptr, tp_rank, blocksparse_local_blocks, \
blocksparse_vert_stride, blocksparse_block_size, \
blocksparse_head_sliding_step);
@ -53,10 +53,10 @@ void paged_attention_v1_launcher(
torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache,
torch::Tensor& value_cache, int num_kv_heads, float scale,
torch::Tensor& block_tables, torch::Tensor& seq_lens, int max_seq_len,
const std::optional<torch::Tensor>& alibi_slopes, float k_scale,
float v_scale, const int tp_rank, const int blocksparse_local_blocks,
const int blocksparse_vert_stride, const int blocksparse_block_size,
const int blocksparse_head_sliding_step) {
const std::optional<torch::Tensor>& alibi_slopes, torch::Tensor& k_scale,
torch::Tensor& v_scale, const int tp_rank,
const int blocksparse_local_blocks, const int blocksparse_vert_stride,
const int blocksparse_block_size, const int blocksparse_head_sliding_step) {
int num_seqs = query.size(0);
int num_heads = query.size(1);
int head_size = query.size(2);
@ -80,6 +80,8 @@ void paged_attention_v1_launcher(
CACHE_T* value_cache_ptr = reinterpret_cast<CACHE_T*>(value_cache.data_ptr());
int* block_tables_ptr = block_tables.data_ptr<int>();
int* seq_lens_ptr = seq_lens.data_ptr<int>();
const float* k_scale_ptr = reinterpret_cast<const float*>(k_scale.data_ptr());
const float* v_scale_ptr = reinterpret_cast<const float*>(v_scale.data_ptr());
constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
int padded_max_seq_len =
@ -177,8 +179,9 @@ void paged_attention_v1(
torch::Tensor& seq_lens, // [num_seqs]
int64_t block_size, int64_t max_seq_len,
const std::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype, double k_scale, double v_scale,
const int64_t tp_rank, const int64_t blocksparse_local_blocks,
const std::string& kv_cache_dtype, torch::Tensor& k_scale,
torch::Tensor& v_scale, const int64_t tp_rank,
const int64_t blocksparse_local_blocks,
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
const int64_t blocksparse_head_sliding_step) {
const bool is_block_sparse = (blocksparse_vert_stride > 1);

View File

@ -37,7 +37,7 @@
exp_sums_ptr, max_logits_ptr, tmp_out_ptr, query_ptr, key_cache_ptr, \
value_cache_ptr, num_kv_heads, scale, block_tables_ptr, \
seq_lens_ptr, max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, \
kv_block_stride, kv_head_stride, k_scale, v_scale, tp_rank, \
kv_block_stride, kv_head_stride, k_scale_ptr, v_scale_ptr, tp_rank, \
blocksparse_local_blocks, blocksparse_vert_stride, \
blocksparse_block_size, blocksparse_head_sliding_step); \
vllm::paged_attention_v2_reduce_kernel<T, HEAD_SIZE, NUM_THREADS, \
@ -54,10 +54,10 @@ void paged_attention_v2_launcher(
torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache,
torch::Tensor& value_cache, int num_kv_heads, float scale,
torch::Tensor& block_tables, torch::Tensor& seq_lens, int max_seq_len,
const std::optional<torch::Tensor>& alibi_slopes, float k_scale,
float v_scale, const int tp_rank, const int blocksparse_local_blocks,
const int blocksparse_vert_stride, const int blocksparse_block_size,
const int blocksparse_head_sliding_step) {
const std::optional<torch::Tensor>& alibi_slopes, torch::Tensor& k_scale,
torch::Tensor& v_scale, const int tp_rank,
const int blocksparse_local_blocks, const int blocksparse_vert_stride,
const int blocksparse_block_size, const int blocksparse_head_sliding_step) {
int num_seqs = query.size(0);
int num_heads = query.size(1);
int head_size = query.size(2);
@ -84,6 +84,8 @@ void paged_attention_v2_launcher(
CACHE_T* value_cache_ptr = reinterpret_cast<CACHE_T*>(value_cache.data_ptr());
int* block_tables_ptr = block_tables.data_ptr<int>();
int* seq_lens_ptr = seq_lens.data_ptr<int>();
const float* k_scale_ptr = reinterpret_cast<const float*>(k_scale.data_ptr());
const float* v_scale_ptr = reinterpret_cast<const float*>(v_scale.data_ptr());
constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
int max_num_partitions = DIVIDE_ROUND_UP(max_seq_len, PARTITION_SIZE);
@ -188,8 +190,9 @@ void paged_attention_v2(
torch::Tensor& seq_lens, // [num_seqs]
int64_t block_size, int64_t max_seq_len,
const std::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype, double k_scale, double v_scale,
const int64_t tp_rank, const int64_t blocksparse_local_blocks,
const std::string& kv_cache_dtype, torch::Tensor& k_scale,
torch::Tensor& v_scale, const int64_t tp_rank,
const int64_t blocksparse_local_blocks,
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
const int64_t blocksparse_head_sliding_step) {
const bool is_block_sparse = (blocksparse_vert_stride > 1);

View File

@ -18,15 +18,15 @@ void copy_blocks(std::vector<torch::Tensor> const& key_caches,
void reshape_and_cache(torch::Tensor& key, torch::Tensor& value,
torch::Tensor& key_cache, torch::Tensor& value_cache,
torch::Tensor& slot_mapping,
const std::string& kv_cache_dtype, const double k_scale,
const double v_scale);
const std::string& kv_cache_dtype,
torch::Tensor& k_scale, torch::Tensor& v_scale);
void reshape_and_cache_flash(torch::Tensor& key, torch::Tensor& value,
torch::Tensor& key_cache,
torch::Tensor& value_cache,
torch::Tensor& slot_mapping,
const std::string& kv_cache_dtype,
const double k_scale, const double v_scale);
torch::Tensor& k_scale, torch::Tensor& v_scale);
// Just for unittest
void convert_fp8(torch::Tensor& dst_cache, torch::Tensor& src_cache,

View File

@ -159,8 +159,8 @@ __global__ void reshape_and_cache_kernel(
// block_size]
const int64_t* __restrict__ slot_mapping, // [num_tokens]
const int key_stride, const int value_stride, const int num_heads,
const int head_size, const int block_size, const int x, const float k_scale,
const float v_scale) {
const int head_size, const int block_size, const int x,
const float* k_scale, const float* v_scale) {
const int64_t token_idx = blockIdx.x;
const int64_t slot_idx = slot_mapping[token_idx];
if (slot_idx < 0) {
@ -196,9 +196,9 @@ __global__ void reshape_and_cache_kernel(
value_cache[tgt_value_idx] = tgt_value;
} else {
key_cache[tgt_key_idx] =
fp8::scaled_convert<cache_t, scalar_t, kv_dt>(tgt_key, k_scale);
fp8::scaled_convert<cache_t, scalar_t, kv_dt>(tgt_key, *k_scale);
value_cache[tgt_value_idx] =
fp8::scaled_convert<cache_t, scalar_t, kv_dt>(tgt_value, v_scale);
fp8::scaled_convert<cache_t, scalar_t, kv_dt>(tgt_value, *v_scale);
}
}
}
@ -214,7 +214,7 @@ __global__ void reshape_and_cache_flash_kernel(
const int64_t* __restrict__ slot_mapping, // [num_tokens]
const int block_stride, const int key_stride, const int value_stride,
const int num_heads, const int head_size, const int block_size,
const float k_scale, const float v_scale) {
const float* k_scale, const float* v_scale) {
const int64_t token_idx = blockIdx.x;
const int64_t slot_idx = slot_mapping[token_idx];
// NOTE: slot_idx can be -1 if the token is padded
@ -239,9 +239,9 @@ __global__ void reshape_and_cache_flash_kernel(
value_cache[tgt_key_value_idx] = tgt_value;
} else {
key_cache[tgt_key_value_idx] =
fp8::scaled_convert<cache_t, scalar_t, kv_dt>(tgt_key, k_scale);
fp8::scaled_convert<cache_t, scalar_t, kv_dt>(tgt_key, *k_scale);
value_cache[tgt_key_value_idx] =
fp8::scaled_convert<cache_t, scalar_t, kv_dt>(tgt_value, v_scale);
fp8::scaled_convert<cache_t, scalar_t, kv_dt>(tgt_value, *v_scale);
}
}
}
@ -258,7 +258,9 @@ __global__ void reshape_and_cache_flash_kernel(
reinterpret_cast<CACHE_T*>(key_cache.data_ptr()), \
reinterpret_cast<CACHE_T*>(value_cache.data_ptr()), \
slot_mapping.data_ptr<int64_t>(), key_stride, value_stride, \
num_heads, head_size, block_size, x, k_scale, v_scale);
num_heads, head_size, block_size, x, \
reinterpret_cast<const float*>(k_scale.data_ptr()), \
reinterpret_cast<const float*>(v_scale.data_ptr()));
void reshape_and_cache(
torch::Tensor& key, // [num_tokens, num_heads, head_size]
@ -268,8 +270,8 @@ void reshape_and_cache(
torch::Tensor&
value_cache, // [num_blocks, num_heads, head_size, block_size]
torch::Tensor& slot_mapping, // [num_tokens]
const std::string& kv_cache_dtype, const double k_scale,
const double v_scale) {
const std::string& kv_cache_dtype, torch::Tensor& k_scale,
torch::Tensor& v_scale) {
int num_tokens = key.size(0);
int num_heads = key.size(1);
int head_size = key.size(2);
@ -299,7 +301,9 @@ void reshape_and_cache(
reinterpret_cast<CACHE_T*>(key_cache.data_ptr()), \
reinterpret_cast<CACHE_T*>(value_cache.data_ptr()), \
slot_mapping.data_ptr<int64_t>(), block_stride, key_stride, \
value_stride, num_heads, head_size, block_size, k_scale, v_scale);
value_stride, num_heads, head_size, block_size, \
reinterpret_cast<const float*>(k_scale.data_ptr()), \
reinterpret_cast<const float*>(v_scale.data_ptr()));
void reshape_and_cache_flash(
torch::Tensor& key, // [num_tokens, num_heads, head_size]
@ -308,8 +312,8 @@ void reshape_and_cache_flash(
torch::Tensor&
value_cache, // [num_blocks, block_size, num_heads, head_size]
torch::Tensor& slot_mapping, // [num_tokens] or [num_actual_tokens]
const std::string& kv_cache_dtype, const double k_scale,
const double v_scale) {
const std::string& kv_cache_dtype, torch::Tensor& k_scale,
torch::Tensor& v_scale) {
// NOTE(woosuk): In vLLM V1, key.size(0) can be different from
// slot_mapping.size(0) because of padding for CUDA graphs.
// In vLLM V0, key.size(0) is always equal to slot_mapping.size(0) because

View File

@ -460,11 +460,11 @@ void paged_attention_v1(
torch::Tensor& value_cache, int64_t num_kv_heads, double scale,
torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size,
int64_t max_seq_len, const std::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype, double k_scale, double v_scale,
const int64_t tp_rank, const int64_t blocksparse_local_blocks,
const std::string& kv_cache_dtype, torch::Tensor& k_scale,
torch::Tensor& v_scale, const int64_t tp_rank,
const int64_t blocksparse_local_blocks,
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
const int64_t blocksparse_head_sliding_step) {
TORCH_CHECK(k_scale == 1.0f && v_scale == 1.0f);
TORCH_CHECK(blocksparse_vert_stride <= 1,
"CPU backend does not support blocksparse attention yet.");
VLLM_DISPATCH_FLOATING_TYPES(query.scalar_type(), "paged_attention_v1_impl",
@ -782,11 +782,11 @@ void paged_attention_v2(
torch::Tensor& value_cache, int64_t num_kv_heads, double scale,
torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size,
int64_t max_seq_len, const std::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype, double k_scale, double v_scale,
const int64_t tp_rank, const int64_t blocksparse_local_blocks,
const std::string& kv_cache_dtype, torch::Tensor& k_scale,
torch::Tensor& v_scale, const int64_t tp_rank,
const int64_t blocksparse_local_blocks,
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
const int64_t blocksparse_head_sliding_step) {
TORCH_CHECK(k_scale == 1.0f && v_scale == 1.0f);
TORCH_CHECK(blocksparse_vert_stride <= 1,
"CPU backend does not support blocksparse attention yet.");
VLLM_DISPATCH_FLOATING_TYPES(query.scalar_type(), "paged_attention_v2_impl",

View File

@ -107,10 +107,8 @@ void copy_blocks(std::vector<torch::Tensor> const& key_caches,
void reshape_and_cache(torch::Tensor& key, torch::Tensor& value,
torch::Tensor& key_cache, torch::Tensor& value_cache,
torch::Tensor& slot_mapping,
const std::string& kv_cache_dtype, double k_scale,
double v_scale) {
TORCH_CHECK(k_scale == 1.0f && v_scale == 1.0f);
const std::string& kv_cache_dtype,
torch::Tensor& k_scale, torch::Tensor& v_scale) {
int num_tokens = key.size(0);
int num_heads = key.size(1);
int head_size = key.size(2);

View File

@ -30,7 +30,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
" Tensor value_cache, int num_kv_heads, float scale,"
" Tensor block_tables, Tensor seq_lens, int block_size,"
" int max_seq_len, Tensor? alibi_slopes,"
" str kv_cache_dtype, float k_scale, float v_scale,"
" str kv_cache_dtype, Tensor k_scale, Tensor v_scale,"
" int tp_rank, int blocksparse_local_blocks,"
" int blocksparse_vert_stride, int blocksparse_block_size,"
" int blocksparse_head_sliding_step) -> ()");
@ -44,7 +44,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
" Tensor value_cache, int num_kv_heads, float scale,"
" Tensor block_tables, Tensor seq_lens, int block_size,"
" int max_seq_len, Tensor? alibi_slopes,"
" str kv_cache_dtype, float k_scale, float v_scale,"
" str kv_cache_dtype, Tensor k_scale, Tensor v_scale,"
" int tp_rank, int blocksparse_local_blocks,"
" int blocksparse_vert_stride, int blocksparse_block_size,"
" int blocksparse_head_sliding_step) -> ()");
@ -148,7 +148,7 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) {
" Tensor! key_cache, Tensor! value_cache,"
" Tensor slot_mapping,"
" str kv_cache_dtype,"
" float k_scale, float v_scale) -> ()");
" Tensor k_scale, Tensor v_scale) -> ()");
cache_ops.impl("reshape_and_cache", torch::kCPU, &reshape_and_cache);
}

View File

@ -34,8 +34,9 @@ void paged_attention_v1(
torch::Tensor& value_cache, int64_t num_kv_heads, double scale,
torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size,
int64_t max_seq_len, const std::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype, double k_scale, double v_scale,
const int64_t tp_rank, const int64_t blocksparse_local_blocks,
const std::string& kv_cache_dtype, torch::Tensor& k_scale,
torch::Tensor& v_scale, const int64_t tp_rank,
const int64_t blocksparse_local_blocks,
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
const int64_t blocksparse_head_sliding_step);
@ -45,8 +46,9 @@ void paged_attention_v2(
torch::Tensor& value_cache, int64_t num_kv_heads, double scale,
torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size,
int64_t max_seq_len, const std::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype, double k_scale, double v_scale,
const int64_t tp_rank, const int64_t blocksparse_local_blocks,
const std::string& kv_cache_dtype, torch::Tensor& k_scale,
torch::Tensor& v_scale, const int64_t tp_rank,
const int64_t blocksparse_local_blocks,
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
const int64_t blocksparse_head_sliding_step);

View File

@ -218,7 +218,7 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel(
scalar_t* __restrict__ out, // [num_seqs, num_heads, max_num_partitions,
// head_size]
scalar_t* __restrict__ final_out, // [num_seqs, num_heads, head_size]
int max_ctx_blocks, float k_scale, float v_scale) {
int max_ctx_blocks, const float* k_scale_ptr, const float* v_scale_ptr) {
constexpr int NWARPS = NUM_THREADS / WARP_SIZE;
const int warpid = threadIdx.x / WARP_SIZE;
const int laneid = threadIdx.x % WARP_SIZE;
@ -406,7 +406,7 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel(
// Vlocalb8[h][b * BLOCK_SIZE / 8 + d] = v_ptrh8be[d];
const _B8x8 Vlocalb8 = v_ptrh8be[d];
Vlocal[h][b * BLOCK_SIZE / 8 + d] =
scaled_convert_b8x8<scalar_t, KV_DTYPE>(Vlocalb8, v_scale);
scaled_convert_b8x8<scalar_t, KV_DTYPE>(Vlocalb8, *v_scale_ptr);
}
}
}
@ -416,7 +416,7 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel(
#pragma unroll
for (int d = 0; d < KHELOOP; d++) {
Klocal[d] =
scaled_convert_b8x8<scalar_t, KV_DTYPE>(Klocalb8[d], k_scale);
scaled_convert_b8x8<scalar_t, KV_DTYPE>(Klocalb8[d], *k_scale_ptr);
}
}
@ -890,7 +890,7 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel(
scalar_t* __restrict__ out, // [num_seqs, num_heads, max_num_partitions,
// head_size]
scalar_t* __restrict__ final_out, // [num_seqs, num_heads, head_size]
int max_ctx_blocks, float k_scale, float v_scale) {
int max_ctx_blocks, const float* k_scale, const float* v_scale) {
UNREACHABLE_CODE
}
@ -919,7 +919,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel(
block_tables_ptr, context_lens_ptr, max_num_blocks_per_seq, \
alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride, \
exp_sums_ptr, max_logits_ptr, tmp_out_ptr, out_ptr, max_ctx_blocks, \
k_scale, v_scale);
k_scale_ptr, v_scale_ptr);
template <typename T, typename KVT, vllm::Fp8KVCacheDataType KV_DTYPE,
int BLOCK_SIZE, int HEAD_SIZE, int PARTITION_SIZE = 512>
@ -929,7 +929,7 @@ void paged_attention_custom_launcher(
torch::Tensor& value_cache, const int num_kv_heads, float scale,
torch::Tensor& block_tables, torch::Tensor& context_lens,
int max_context_len, const std::optional<torch::Tensor>& alibi_slopes,
float k_scale, float v_scale) {
torch::Tensor& k_scale, torch::Tensor& v_scale) {
int num_seqs = query.size(0);
int num_heads = query.size(1);
int head_size = query.size(2);
@ -953,6 +953,8 @@ void paged_attention_custom_launcher(
KVT* value_cache_ptr = reinterpret_cast<KVT*>(value_cache.data_ptr());
int* block_tables_ptr = block_tables.data_ptr<int>();
int* context_lens_ptr = context_lens.data_ptr<int>();
const float* k_scale_ptr = reinterpret_cast<const float*>(k_scale.data_ptr());
const float* v_scale_ptr = reinterpret_cast<const float*>(v_scale.data_ptr());
const int max_ctx_blocks = DIVIDE_ROUND_UP(max_context_len, BLOCK_SIZE);
const int max_num_partitions =
@ -1087,7 +1089,8 @@ void paged_attention(
torch::Tensor& context_lens, // [num_seqs]
int64_t block_size, int64_t max_context_len,
const std::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype, double k_scale, double v_scale) {
const std::string& kv_cache_dtype, torch::Tensor& k_scale,
torch::Tensor& v_scale) {
const int head_size = query.size(2);
if (kv_cache_dtype == "auto") {
if (query.dtype() == at::ScalarType::Half) {

View File

@ -10,5 +10,5 @@ void paged_attention(torch::Tensor& out, torch::Tensor& exp_sums,
torch::Tensor& context_lens, int64_t block_size,
int64_t max_context_len,
const std::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype, double k_scale,
double v_scale);
const std::string& kv_cache_dtype, torch::Tensor& k_scale,
torch::Tensor& v_scale);

View File

@ -27,7 +27,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, rocm_ops) {
" int max_context_len,"
" Tensor? alibi_slopes,"
" str kv_cache_dtype,"
" float k_scale, float v_scale) -> ()");
" Tensor k_scale, Tensor v_scale) -> ()");
rocm_ops.impl("paged_attention", torch::kCUDA, &paged_attention);
}

View File

@ -30,7 +30,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
" Tensor value_cache, int num_kv_heads, float scale,"
" Tensor block_tables, Tensor seq_lens, int block_size,"
" int max_seq_len, Tensor? alibi_slopes,"
" str kv_cache_dtype, float k_scale, float v_scale,"
" str kv_cache_dtype, Tensor k_scale, Tensor v_scale,"
" int tp_rank, int blocksparse_local_blocks,"
" int blocksparse_vert_stride, int blocksparse_block_size,"
" int blocksparse_head_sliding_step) -> ()");
@ -44,7 +44,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
" Tensor value_cache, int num_kv_heads, float scale,"
" Tensor block_tables, Tensor seq_lens, int block_size,"
" int max_seq_len, Tensor? alibi_slopes,"
" str kv_cache_dtype, float k_scale, float v_scale,"
" str kv_cache_dtype, Tensor k_scale, Tensor v_scale,"
" int tp_rank, int blocksparse_local_blocks,"
" int blocksparse_vert_stride, int blocksparse_block_size,"
" int blocksparse_head_sliding_step) -> ()");
@ -449,7 +449,7 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) {
" Tensor! key_cache, Tensor! value_cache,"
" Tensor slot_mapping,"
" str kv_cache_dtype,"
" float k_scale, float v_scale) -> ()");
" Tensor k_scale, Tensor v_scale) -> ()");
cache_ops.impl("reshape_and_cache", torch::kCUDA, &reshape_and_cache);
// Reshape the key and value tensors and cache them.
@ -459,7 +459,7 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) {
" Tensor! value_cache,"
" Tensor slot_mapping,"
" str kv_cache_dtype,"
" float k_scale, float v_scale) -> ()");
" Tensor k_scale, Tensor v_scale) -> ()");
cache_ops.impl("reshape_and_cache_flash", torch::kCUDA,
&reshape_and_cache_flash);
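
Since the registered schemas now declare `Tensor k_scale, Tensor v_scale`, Python callers pass the scale tensors straight through. A hedged sketch of driving the re-registered `reshape_and_cache_flash` op (assumes a CUDA build of vLLM with this change and its `vllm._custom_ops` wrapper; shapes are illustrative, and the cache layout follows the `[num_blocks, block_size, num_heads, head_size]` comment in the kernel above):

```python
import torch
from vllm import _custom_ops as ops  # thin wrapper over torch.ops._C_cache_ops

num_tokens, num_heads, head_size = 16, 8, 128
num_blocks, block_size = 32, 16
device = "cuda"

key = torch.randn(num_tokens, num_heads, head_size, dtype=torch.float16, device=device)
value = torch.randn_like(key)
key_cache = torch.zeros(num_blocks, block_size, num_heads, head_size,
                        dtype=torch.float16, device=device)
value_cache = torch.zeros_like(key_cache)
slot_mapping = torch.arange(num_tokens, dtype=torch.int64, device=device)

# Scales are tensors now; with kv_cache_dtype="auto" they are effectively unused,
# but the same call accepts on-device scales produced by calculate_kv_scales.
k_scale = v_scale = torch.tensor(1.0, dtype=torch.float32, device=device)
ops.reshape_and_cache_flash(key, value, key_cache, value_cache, slot_mapping,
                            "auto", k_scale, v_scale)
```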

View File

@ -35,16 +35,18 @@ Studies have shown that FP8 E4M3 quantization typically only minimally degrades
Here is an example of how to enable FP8 quantization:
```python
# To calculate kv cache scales on the fly enable the calculate_kv_scales
# parameter
from vllm import LLM, SamplingParams
sampling_params = SamplingParams(temperature=0.7, top_p=0.8)
llm = LLM(model="meta-llama/Llama-3.1-8B-Instruct", kv_cache_dtype="fp8")
llm = LLM(model="meta-llama/Llama-2-7b-chat-hf",
kv_cache_dtype="fp8",
calculate_kv_scales=True)
prompt = "London is the capital of"
out = llm.generate(prompt, sampling_params)[0].outputs[0].text
print(out)
# output w/ scaling factors: England, the United Kingdom, and one of the world's leading financial,
# output w/o scaling factors: England, located in the southeastern part of the country. It is known
```
The `kv_cache_dtype` argument specifies the data type for KV cache storage:

View File

@ -1,96 +0,0 @@
# FP8 KV Cache
This utility extracts the KV cache scaling factors from a quantized HF (Hugging Face) model. The extracted scaling factors are saved to a JSON file, which can later be used by vLLM during runtime. This tool is particularly useful when the KV cache data type is FP8 and is intended for use on ROCm (AMD GPU) platforms.
## Prerequisites
- Python 3.x
- PyTorch
- NumPy
- Hugging Face Transformers
- Hugging Face Hub
- AMMO
Before incorporating the FP8 datatype for inference workloads, you must adhere to the following steps:
1. Install all necessary prerequisites and dependencies.
2. Convert HF model into a quantized HF model.
3. Extract KV Cache Scaling Factors from quantized HF model.
4. Load KV Cache Scaling Factors into vLLM.
### 2. Convert HF model into a quantized HF model.
Note: The following steps are adapted from the [TensorRT-LLM repository](https://github.com/NVIDIA/TensorRT-LLM/blob/main/examples/quantization/README.md).
`quantize.py` (examples/other/fp8/quantizer/quantize.py) uses the quantization toolkit (AMMO) to calibrate the PyTorch models and export TensorRT-LLM checkpoints. Each TensorRT-LLM checkpoint contains a config file (in .json format) and one or several rank weight files (in .safetensors format).
The detailed quantization toolkit (AMMO) conversion guide for FP8 can be found at `examples/other/fp8/quantizer/README.md`.
### 3. Extract KV Cache Scaling Factors from quantized HF model.
`extract_scales.py` (examples/other/fp8/extract_scales.py) can be utilized to extract the KV cache scaling factors from your quantized HF model; however, at the moment this tool exclusively supports Llama 2 models. It is also important to note the following:
1. **File Structure**: The utility operates under the assumption that all parameters, including KV cache scaling factors, corresponding to a particular Tensor Parallelism (TP) rank are stored in a single file. These files must adhere to a specific naming convention where the TP rank is immediately identified after a specific keyword (e.g., "rank") in the filename.
2. **TP Decomposition**: The utility assumes consistency between the TP decomposition employed by the quantizer tool and that used by vLLM.
3. **AMMO Compatibility**: Currently, the generated KV cache scaling factors for AMMO remain uniform across all TP ranks.
```python
# prerequisites:
# - Quantized HF LLaMa 2 model
python3 examples/other/fp8/extract_scales.py --help
Usage: extract_scales.py [-h] --quantized_model QUANTIZED_MODEL [--load_format {auto,safetensors,npz,pt}] [--output_dir OUTPUT_DIR] [--output_name OUTPUT_NAME] [--tp_size TP_SIZE]
KV Scale Extraction Example
Required arguments:
--quantized_model: Specify either the local path to, or name of, a quantized HF model. It is expected that the quantization format is FP8_E4M3, for use on ROCm (AMD GPU).
Optional arguments:
--cache_dir: Specify a cache directory to use in the event of a HF model download. (Default: None)
--load_format: Specify the format of the model's tensor files containing the KV cache scaling factors. (Choices: auto, safetensors, npz, pt; Default: auto)
--revision: Specify the model's revision number. (Default: None)
--output_dir: Specify the output directory. By default the KV cache scaling factors will be saved in the model directory. (Default: None)
--output_name: Specify the output filename. (Default: kv_cache_scales.json)
--tp_size: Specify the tensor-parallel (TP) size that the quantized model should correspond to. If specified, during KV cache scaling factor extraction the observed TP size will be checked against this and an error will be raised if there is a mismatch. (Default: None)
```
```python
Example:
python3 examples/other/fp8/extract_scales.py --quantized_model <QUANTIZED_MODEL_DIR> --tp_size <TENSOR_PARALLEL_SIZE> --output_dir <PATH_TO_OUTPUT_DIR>
```
### 4. Load KV Cache Scaling Factors into vLLM.
This script evaluates the inference throughput of language models using various backends such as vLLM. It measures the time taken to process a given number of prompts and generate sequences for each prompt. The recently generated KV cache scaling factors are now integrated into the benchmarking process, allowing them to be utilized for FP8.
```
# prerequisites:
# - LLaMa 2 kv_cache_scales.json file
python3 benchmarks/benchmark_throughput.py --help
usage: benchmark_throughput.py [-h] [--backend {vllm,hf,mii}] [--dataset DATASET] [--input-len INPUT_LEN] [--output-len OUTPUT_LEN] [--model MODEL]
[--tokenizer TOKENIZER] [--quantization {awq,gptq,None}] [--tensor-parallel-size TENSOR_PARALLEL_SIZE] [--n N]
[--use-beam-search] [--num-prompts NUM_PROMPTS] [--seed SEED] [--hf-max-batch-size HF_MAX_BATCH_SIZE] [--trust-remote-code]
[--max-model-len MAX_MODEL_LEN] [--dtype {auto,half,float16,bfloat16,float,float32}] [--enforce-eager] [--kv-cache-dtype {auto,fp8}]
[--quantization-param-path KV_CACHE_quantization_param_path]
Benchmark Throughput Example
optional arguments:
-h, --help show this help message and exit
--backend {vllm,hf,mii}
--dataset DATASET Path to the dataset.
--input-len INPUT_LEN Input prompt length for each request
--output-len OUTPUT_LEN Output length for each request. Overrides the output length from the dataset.
--model MODEL
--tokenizer TOKENIZER
--quantization {awq,gptq,None}, -q {awq,gptq,None}
--tensor-parallel-size TENSOR_PARALLEL_SIZE, -tp TENSOR_PARALLEL_SIZE
--n N Number of generated sequences per prompt.
--use-beam-search
--num-prompts NUM_PROMPTS Number of prompts to process.
--seed SEED
--hf-max-batch-size HF_MAX_BATCH_SIZE Maximum batch size for HF backend.
--trust-remote-code trust remote code from huggingface
--max-model-len MAX_MODEL_LEN Maximum length of a sequence (including prompt and output). If None, will be derived from the model.
--dtype {auto,half,float16,bfloat16,float,float32} data type for model weights and activations. The "auto" option will use FP16 precision for FP32 and FP16 models, and BF16 precision for BF16 models.
--enforce-eager enforce eager execution
--kv-cache-dtype {auto,fp8} Data type for kv cache storage. If "auto", will use model data type. FP8_E5M2 (without scaling) is only supported on cuda version greater than 11.8. On ROCm (AMD GPU), FP8_E4M3 is instead supported for common inference criteria.
--quantization-param-path QUANT_PARAM_JSON Path to the JSON file containing the KV cache scaling factors. This should generally be supplied when KV cache dtype is FP8. Otherwise, KV cache scaling factors default to 1.0, which may cause accuracy issues. FP8_E5M2 (without scaling) is only supported on cuda version greater than 11.8. On ROCm (AMD GPU), FP8_E4M3 is instead supported for common inference criteria.
```
Example:
```console
python3 benchmarks/benchmark_throughput.py --input-len <INPUT_LEN> --output-len <OUTPUT_LEN> -tp <TENSOR_PARALLEL_SIZE> --kv-cache-dtype fp8 --quantization-param-path <path/to/kv_cache_scales.json> --model <path-to-llama2>
```

View File

@ -1,367 +0,0 @@
import argparse
import glob
import json
import os
from typing import Any, Callable, Dict, List, Optional, Tuple
import numpy as np
import torch
from safetensors.torch import safe_open
from vllm.model_executor.layers.quantization.schema import QuantParamSchema
# Adapted from vllm/model_executor/model_loader/weight_utils.py
# The main differences are that we add the NPZ format and simplify
# its functionality drastically for our purposes (e.g. we assume that
# the quantized model exists locally and there is no need to download it)
def _prepare_hf_weights(
quantized_model_dir: str,
load_format: str = "auto",
fall_back_to_pt: bool = True,
) -> Tuple[List[str], bool]:
if not os.path.isdir(quantized_model_dir):
raise FileNotFoundError(
f"The quantized model directory `{quantized_model_dir}` "
"does not exist.")
use_safetensors = False
# Some quantized models use .pt files for storing the weights.
if load_format == "auto":
allow_patterns = ["*.safetensors", "*.bin"]
elif load_format == "safetensors":
use_safetensors = True
allow_patterns = ["*.safetensors"]
elif load_format == "pt":
allow_patterns = ["*.pt"]
elif load_format == "npz":
allow_patterns = ["*.npz"]
else:
raise ValueError(f"Unknown load_format: {load_format}")
if fall_back_to_pt:
allow_patterns += ["*.pt"]
hf_weights_files: List[str] = []
for pattern in allow_patterns:
hf_weights_files += glob.glob(
os.path.join(quantized_model_dir, pattern))
if len(hf_weights_files) > 0:
if pattern == "*.safetensors":
use_safetensors = True
break
if not use_safetensors:
# Exclude files that are not needed for inference.
# https://github.com/huggingface/transformers/blob/v4.34.0/src/transformers/trainer.py#L227-L233
blacklist = [
"training_args.bin",
"optimizer.bin",
"optimizer.pt",
"scheduler.pt",
"scaler.pt",
]
hf_weights_files = [
f for f in hf_weights_files
if not any(f.endswith(x) for x in blacklist)
]
if len(hf_weights_files) == 0:
raise RuntimeError(
f"Cannot find any model weights with `{quantized_model_dir}`")
return hf_weights_files, use_safetensors
# Adapted from vllm/model_executor/model_loader/weight_utils.py
def _hf_tensorfile_iterator(filename: str, load_format: str,
use_safetensors: bool):
if load_format == "npz":
assert not use_safetensors
with np.load(filename) as data:
for name in data.files:
param = torch.from_numpy(data[name])
yield name, param
elif use_safetensors:
with safe_open(filename, framework="pt") as f:
for name in f.keys(): # NOQA: SIM118
param = f.get_tensor(name)
yield name, param
else:
state = torch.load(filename, map_location="cpu")
for name, param in state.items():
yield name, param
del state
torch.cuda.empty_cache()
def _kv_scales_extractor(
hf_tensor_files: List[str],
use_safetensors: bool,
rank_keyword: str = "rank",
expected_tp_size: Optional[int] = None) -> Dict[int, Dict[int, float]]:
"""
Given a list of files containing tensor data, attempt to extract KV cache
scales from these files. Intended as a helper function taking in the output
from _prepare_hf_weights.
Args:
rank_keyword Matches the number immediately after this keyword in the
tensor filename to determine the TP rank corresponding
to said tensor file
expected_tp_size If specified, the TP size of the tensor files is checked
against this and an error is raised if they don't match.
Returns a dictionary mapping TP ranks to their relevant KV cache scales.
The per-rank scales are themselves represented as a dictionary of layer
indices to the respective per-layer scale.
"""
for char in rank_keyword:
assert not char.isdecimal(
), f"Rank keyword {rank_keyword} contains a numeric character!"
rank_scales_map: Dict[int, Dict[int, float]] = {}
for tensor_file in hf_tensor_files:
try:
rank_idx = tensor_file.find(rank_keyword)
if rank_idx != -1:
start_idx = rank_idx + len(rank_keyword)
stop_idx = start_idx
while stop_idx < len(
tensor_file) and tensor_file[stop_idx].isdecimal():
stop_idx += 1
if stop_idx == start_idx:
raise RuntimeError("Did not find rank # in filename.")
rank = int(tensor_file[start_idx:stop_idx])
elif len(hf_tensor_files) == 1:
# Since there is only one tensor file, we can assume
# that it's intended for TP rank 0
rank = 0
else:
raise RuntimeError(
f"Filename does not contain '{rank_keyword}'.")
except RuntimeError:
print("Unable to determine TP rank "
f"corresponding to file '{tensor_file}'")
raise
if rank not in rank_scales_map:
layer_scales_map: Dict[int, float] = {}
rank_scales_map[rank] = layer_scales_map
else:
raise RuntimeError(
f"Tensor file '{tensor_file}' shares TP rank {rank} "
"with another tensor file.")
module_delimiter = ":" if args.load_format == "npz" else "."
for name, param in _hf_tensorfile_iterator(tensor_file,
args.load_format,
use_safetensors):
if "kv_cache_scaling_factor" in name:
nums = [
int(s) for s in name.split(module_delimiter)
if s.isdecimal()
]
assert len(
nums) == 1, f"Could not determine layer idx for {name}"
layer_idx = nums[0]
assert layer_idx not in layer_scales_map, f"Duplicate scaling"\
f" factor corresponding to layer {layer_idx}"
try:
layer_scales_map[layer_idx] = param.item()
except RuntimeError:
print(
"This utility supports only per-tensor scalar scales "
f"for now. The tensor\n {name} = {param} \nis an "
"invalid scale factor.")
raise
if all(
len(layer_scales_map) == 0
for layer_scales_map in rank_scales_map.values()):
# Note: this is true even if the rank_scales_map is empty
print("WARNING: No KV cache scale factors found. No output saved.")
return None
empirical_tp_world_size = max(rank_scales_map.keys()) + 1
if expected_tp_size is not None:
assert expected_tp_size == empirical_tp_world_size, \
f"User expected TP world size = {expected_tp_size} " \
"from model but tool is expecting TP world size = " \
f"{empirical_tp_world_size} from model instead."
for i in range(empirical_tp_world_size):
assert i in rank_scales_map, "Expected TP world size = "\
f"{empirical_tp_world_size} but did not find KV " \
f"cache scaling factors for TP rank {i}"
print(f"Found TP world size = {empirical_tp_world_size} "
"when extracting KV cache scales!")
return rank_scales_map
def _metadata_extractor(quantized_model_dir: str,
metadata_extract_fns: \
Dict[str, Callable[[Dict[str, Any]], Any]]) \
-> Dict[str, Any]:
"""
Given a directory containing quantized model files, this function
aims to extract metadata from the JSON files within this directory.
Each JSON file is expected to represent a dictionary in JSON
format (referred to as a "JSON-dictionary"). Metadata extraction is
defined by a dictionary called metadata_extract_fns, where each
metadata field name is mapped to an extraction function.
These extraction functions are designed to take a JSON-dictionary
as their only argument and return the corresponding metadata.
While extraction functions are permitted to raise exceptions, they
should only raise a KeyError or ValueError if the metadata field
cannot be extracted from the current JSON-dictionary, yet there's
a possibility of finding it in another JSON-dictionary.
The function returns a dictionary that maps metadata fields to
their extracted data. The keys of this dictionary correspond exactly
to those in metadata_extract_fns. If any fields fail to be extracted,
their corresponding values are set to None, and a warning is printed.
"""
if not os.path.isdir(quantized_model_dir):
raise FileNotFoundError(
f"The quantized model directory `{quantized_model_dir}` "
"does not exist.")
metadata_files = glob.glob(os.path.join(quantized_model_dir, "*.json"))
result: Dict[str, Any] = {}
for file in metadata_files:
with open(file) as f:
try:
metadata = json.load(f)
except json.JSONDecodeError:
print(f"Could not parse `{file}` as a valid metadata file,"
" skipping it.")
continue
if not isinstance(metadata, dict):
print(f"The file `{file}` does not correspond to a "
"JSON-serialized dictionary, skipping it.")
continue
for metadata_name, extract_fn in metadata_extract_fns.items():
try:
metadata_info = extract_fn(metadata)
if metadata_name not in result:
result[metadata_name] = metadata_info
elif metadata_info != result[metadata_name]:
raise RuntimeError(
"Metadata mismatch! Originally found "
f"{metadata_name} = {result[metadata_name]} but "
f"now found {metadata_name} = {metadata_info} in "
f"`{file}`")
except KeyError:
# It is possible that a given file does not contain some
# of our selected metadata as it could be located in some
# other metadata file.
# 'EFINAE': extract_fn failure is not an error.
pass
except ValueError:
# See above.
pass
# Warn if we cannot find any of the requested metadata
for metadata_name in metadata_extract_fns:
if metadata_name not in result:
print("WARNING: Unable to find requested metadata field "
f"`{metadata_name}`, setting it to None.")
result[metadata_name] = None
return result
def main(args):
metadata_extract_fns = {
"model_type": lambda json_dict: json_dict["layers"][0]["decoder_type"],
"tp_size": lambda json_dict: int(json_dict["tensor_parallel"]),
"model_dtype": lambda json_dict: json_dict["dtype"]
}
recovered_metadata = _metadata_extractor(args.quantized_model,
metadata_extract_fns)
if args.tp_size is not None:
metadata_tp_size = recovered_metadata["tp_size"]
if metadata_tp_size is not None:
assert args.tp_size == metadata_tp_size, \
f"User expected TP world size = {args.tp_size} " \
f"but found TP world size = {metadata_tp_size} from metadata!"
expected_tp_size = args.tp_size or recovered_metadata["tp_size"]
rank_keyword = "rank"
hf_tensor_files, use_safetensors = _prepare_hf_weights(
args.quantized_model, args.load_format)
rank_scales_map = _kv_scales_extractor(hf_tensor_files, use_safetensors,
rank_keyword, expected_tp_size)
# Postprocess: formatting to the current schema. Consider pulling it
# out into a dedicated function should it ever become more complicated.
rank_scales_map = {
rank: {k: scale[k]
for k in sorted(scale.keys())}
for rank, scale in rank_scales_map.items()
}
# TODO: Expand this with activation and weights scaling factors when
# they are used in the future
schema = QuantParamSchema(
model_type=recovered_metadata["model_type"],
kv_cache={
"dtype": ("float8_e4m3fn" if len(rank_scales_map) > 0 else
recovered_metadata["model_dtype"]),
"scaling_factor":
rank_scales_map
},
)
if args.output_dir is None:
output_file = os.path.join(args.quantized_model, args.output_name)
else:
if not os.path.isdir(args.output_dir):
os.makedirs(args.output_dir, exist_ok=True)
output_file = os.path.join(args.output_dir, args.output_name)
with open(output_file, 'w') as f:
f.write(schema.model_dump_json(indent=4))
print(f"Completed! KV cache scaling factors saved to {output_file}")
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="This simple utility extracts the "
"KV cache scaling factors from a quantized HF model "
"and saves them to a JSON file compatible with later "
"use by vLLM (pass this file to the appropriate "
"runtime typically using the argument "
"--quantization-param-path <filename>). This is only used "
"if the KV cache dtype is FP8 and on ROCm (AMD GPU).")
parser.add_argument(
"--quantized-model",
help="Specify the directory containing a single quantized HF model. "
"It is expected that the quantization format is FP8_E4M3, for use "
"on ROCm (AMD GPU).",
required=True)
parser.add_argument(
"--load_format",
help="Optionally specify the format of the model's tensor files "
"containing the KV cache scaling factors.",
choices=["auto", "safetensors", "npz", "pt"],
default="auto")
parser.add_argument(
"--output-dir",
help="Optionally specify the output directory. By default the "
"KV cache scaling factors will be saved in the model directory, "
"however you can override this behavior here.",
default=None)
parser.add_argument(
"--output-name",
help="Optionally specify the output filename.",
# TODO: Change this once additional scaling factors are enabled
default="kv_cache_scales.json")
parser.add_argument(
"--tp-size",
help="Optionally specify the tensor-parallel (TP) size that the "
"quantized model should correspond to. If specified, during KV "
"cache scaling factor extraction the observed TP size will be "
"checked against this and an error will be raised if there is "
"a mismatch. If not specified, the quantized model's expected "
"TP size is instead inferred from the largest TP rank observed. "
"The expected TP size is cross-checked against the TP ranks "
"observed in the quantized model and an error is raised if any "
"discrepancies are found.",
default=None,
type=int)
args = parser.parse_args()
main(args)

View File

@ -1,32 +0,0 @@
### Quantizer Utilities
`quantize.py`: NVIDIA Quantization utilities using TensorRT-Model-Optimizer, ported
from TensorRT-LLM: [`examples/quantization/quantize.py`](https://github.com/NVIDIA/TensorRT-LLM/blob/main/examples/quantization/quantize.py)
### Prerequisite
#### AMMO (AlgorithMic Model Optimization) Installation: nvidia-ammo 0.7.1 or later
`pip install --no-cache-dir --extra-index-url https://pypi.nvidia.com nvidia-ammo`
#### AMMO Download (code and docs)
`https://developer.nvidia.com/downloads/assets/cuda/files/nvidia-ammo/nvidia_ammo-0.5.0.tar.gz`
`https://developer.nvidia.com/downloads/assets/cuda/files/nvidia-ammo/nvidia_ammo-0.7.1.tar.gz`
### Usage
#### Run on H100 system for speed if FP8; number of GPUs depends on the model size
#### Example: quantize Llama2-7b model from HF to FP8 with FP8 KV Cache:
`python quantize.py --model-dir ./ll2-7b --dtype float16 --qformat fp8 --kv-cache-dtype fp8 --output-dir ./ll2_7b_fp8 --calib-size 512 --tp-size 1`
Outputs: the model structure and quantized model parameters (with scaling factors) are written as JSON and Safetensors (the npz file is generated only for reference)
```
# ll ./ll2_7b_fp8/
total 19998244
drwxr-xr-x 2 root root 4096 Feb 7 01:08 ./
drwxrwxr-x 8 1060 1061 4096 Feb 7 01:08 ../
-rw-r--r-- 1 root root 176411 Feb 7 01:08 llama_tp1.json
-rw-r--r-- 1 root root 13477087480 Feb 7 01:09 llama_tp1_rank0.npz
-rw-r--r-- 1 root root 7000893272 Feb 7 01:08 rank0.safetensors
#
```

View File

@ -1,367 +0,0 @@
# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # noqa: E501
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Adapted from examples/quantization/hf_ptq.py
"""
import argparse
import copy
import json
import random
import time
import ammo.torch.quantization as atq
import numpy as np
import torch
from ammo.torch.export import export_model_config
from datasets import load_dataset
from torch.utils.data import DataLoader
from transformers import AutoModelForCausalLM, AutoTokenizer
RAND_SEED = 1234
MAX_SEQ_LEN = 2048
EMPTY_CFG = {
"quant_cfg": {
"*weight_quantizer": {
"enable": False,
},
"*input_quantizer": {
"enable": False
},
"*lm_head*": {
"enable": False
},
"*output_layer*": {
"enable": False
},
"default": {
"enable": False
},
},
"algorithm": "max",
}
KV_CACHE_CFG = {
"*.query_key_value.output_quantizer": {
"num_bits": 8,
"axis": None,
"enable": True
},
"*.Wqkv.output_quantizer": {
"num_bits": 8,
"axis": None,
"enable": True
},
"*.W_pack.output_quantizer": {
"num_bits": 8,
"axis": None,
"enable": True
},
"*.c_attn.output_quantizer": {
"num_bits": 8,
"axis": None,
"enable": True
},
"*.k_proj.output_quantizer": {
"num_bits": 8,
"axis": None,
"enable": True
},
"*.v_proj.output_quantizer": {
"num_bits": 8,
"axis": None,
"enable": True
},
}
QUANT_CFG_CHOICES = {
"int8_sq": atq.INT8_SMOOTHQUANT_CFG,
"fp8": atq.FP8_DEFAULT_CFG,
"int4_awq": atq.INT4_AWQ_CFG,
"w4a8_awq": atq.W4A8_AWQ_BETA_CFG,
"int8_wo": EMPTY_CFG,
"int4_wo": EMPTY_CFG,
"full_prec": EMPTY_CFG,
}
MODEL_NAME_PATTERN_MAP = {
"GPT2": "gpt2",
"Xverse": "llama",
"Llama": "llama",
"Mistral": "llama",
"GPTJ": "gptj",
"FalconForCausalLM": "falcon",
"RWForCausalLM": "falcon",
"baichuan": "baichuan",
"MPT": "mpt",
"Bloom": "bloom",
"ChatGLM": "chatglm",
"QWen": "qwen",
}
def get_tokenizer(ckpt_path, max_seq_len=MAX_SEQ_LEN, model_type=None):
print(f"Initializing tokenizer from {ckpt_path}")
tokenizer = AutoTokenizer.from_pretrained(
ckpt_path,
model_max_length=max_seq_len,
padding_side="left",
trust_remote_code=True,
)
if model_type and model_type == "qwen":
# qwen use token id 151643 as pad and eos tokens
tokenizer.pad_token = tokenizer.convert_ids_to_tokens(151643)
tokenizer.eos_token = tokenizer.convert_ids_to_tokens(151643)
# can't set attribute 'pad_token' for "<unk>"
if tokenizer.pad_token != "<unk>":
tokenizer.pad_token = tokenizer.eos_token
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
assert (tokenizer.pad_token
is not None), f"Pad token for {model_type} cannot be set!"
return tokenizer
def get_model(ckpt_path, dtype="fp16", device="cuda"):
print(f"Initializing model from {ckpt_path}")
if dtype == "bf16" or dtype == "bfloat16":
dtype = torch.bfloat16
elif dtype == "fp16" or dtype == "float16":
dtype = torch.float16
elif dtype == "fp32" or dtype == "float32":
dtype = torch.float32
else:
raise NotImplementedError(f"Unknown dtype {dtype}")
# model_kwargs = {"torch_dtype": dtype}
model_kwargs = {"torch_dtype": "auto"}
model = AutoModelForCausalLM.from_pretrained(ckpt_path,
device_map="auto",
**model_kwargs,
trust_remote_code=True)
model.eval()
model_dtype = next(model.parameters()).dtype
if dtype != model_dtype:
print("[TensorRT-LLM][WARNING] The manually set model data type is "
f"{dtype}, but the data type of the HuggingFace model is "
f"{model_dtype}.")
return model
def get_model_type(model):
for k, v in MODEL_NAME_PATTERN_MAP.items():
if k.lower() in type(model).__name__.lower():
return v
return None
def get_calib_dataloader(data="cnn_dailymail",
tokenizer=None,
batch_size=1,
calib_size=512,
block_size=512,
device=None):
print("Loading calibration dataset")
if data == "pileval":
dataset = load_dataset(
"json",
data_files="https://the-eye.eu/public/AI/pile/val.jsonl.zst",
split="train")
dataset = dataset["text"][:calib_size]
elif data == "cnn_dailymail":
dataset = load_dataset("cnn_dailymail", name="3.0.0", split="train")
dataset = dataset["article"][:calib_size]
else:
raise NotImplementedError
batch_encoded = tokenizer.batch_encode_plus(dataset,
return_tensors="pt",
padding="max_length",
truncation=True,
max_length=block_size)
if device:
batch_encoded = batch_encoded.to(device)
batch_encoded = batch_encoded["input_ids"]
calib_dataloader = DataLoader(batch_encoded,
batch_size=batch_size,
shuffle=False)
return calib_dataloader
def quantize_model(model, quant_cfg, calib_dataloader=None):
def calibrate_loop():
if calib_dataloader is None:
return
"""Adjusts weights and scaling factors based on selected algorithms."""
for idx, data in enumerate(calib_dataloader):
print(f"Calibrating batch {idx}")
model(data)
print("Starting quantization...")
start_time = time.time()
atq.quantize(model, quant_cfg, forward_loop=calibrate_loop)
end_time = time.time()
print("Quantization done. Total time used: {:.2f} s.".format(end_time -
start_time))
return model
def main(args):
if not torch.cuda.is_available():
raise OSError("GPU is required for inference.")
random.seed(RAND_SEED)
np.random.seed(RAND_SEED)
model = get_model(args.model_dir, args.dtype, args.device)
model_type = get_model_type(model)
tokenizer = get_tokenizer(args.model_dir, model_type=model_type)
if args.qformat in ["full_prec", "int8_wo", "int4_wo"
] and args.kv_cache_dtype is None:
print(f"No quantization applied, export {args.dtype} model")
else:
if "awq" in args.qformat:
if args.calib_size > 32:
print("AWQ calibration could take longer with calib_size = "
f"{args.calib_size}, Using calib_size=32 instead")
args.calib_size = 32
print("\nAWQ calibration could take longer than other calibration "
"methods. Please increase the batch size to speed up the "
"calibration process. Batch size can be set by adding the "
"argument --batch_size <batch_size> to the command line.\n")
calib_dataloader = get_calib_dataloader(
tokenizer=tokenizer,
batch_size=args.batch_size,
calib_size=args.calib_size,
device=args.device,
)
if args.qformat in QUANT_CFG_CHOICES:
quant_cfg = QUANT_CFG_CHOICES[args.qformat]
else:
raise ValueError(
f"Unsupported quantization format: {args.qformat}")
if "awq" in args.qformat:
quant_cfg = copy.deepcopy(QUANT_CFG_CHOICES[args.qformat])
weight_quantizer = quant_cfg["quant_cfg"][
"*weight_quantizer"] # type: ignore
if isinstance(weight_quantizer, list):
weight_quantizer = weight_quantizer[0]
weight_quantizer["block_sizes"][-1] = args.awq_block_size
if args.kv_cache_dtype is not None:
if args.kv_cache_dtype == "fp8":
for value in KV_CACHE_CFG.values():
value.update({"num_bits": (4, 3)}) # type: ignore
quant_cfg["quant_cfg"].update(KV_CACHE_CFG) # type: ignore
print(quant_cfg)
model = quantize_model(model, quant_cfg, calib_dataloader)
with torch.inference_mode():
if model_type is None:
print(f"Unknown model type {type(model).__name__}. Continue "
"exporting...")
model_type = f"unknown:{type(model).__name__}"
export_path = args.output_dir
start_time = time.time()
if args.qformat == "int4_awq" and model_type == "qwen":
torch.save(model.state_dict(), export_path)
else:
export_npz = (model_type not in [
'gptj', 'falcon', 'chatglm', 'mpt', 'llama', 'baichuan'
])
# export safetensors
export_model_config(
model,
model_type,
getattr(torch, args.dtype),
export_dir=export_path,
inference_tensor_parallel=args.tp_size,
inference_pipeline_parallel=args.pp_size,
# export_tensorrt_llm_config=(not export_npz),
export_tensorrt_llm_config=False,
export_npz=export_npz)
# Workaround for wo quantization
if args.qformat in ["int8_wo", "int4_wo", "full_prec"]:
with open(f"{export_path}/config.json") as f:
tensorrt_llm_config = json.load(f)
if args.qformat == "int8_wo":
tensorrt_llm_config["quantization"]["quant_algo"] = 'W8A16'
elif args.qformat == "int4_wo":
tensorrt_llm_config["quantization"]["quant_algo"] = 'W4A16'
else:
tensorrt_llm_config["quantization"]["quant_algo"] = None
with open(f"{export_path}/config.json", "w") as f:
json.dump(tensorrt_llm_config, f, indent=4)
end_time = time.time()
print("Quantized model exported to {} \nTotal time used {:.2f} s.".
format(export_path, end_time - start_time))
if __name__ == "__main__":
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument("--model-dir",
help="Specify where the HuggingFace model is",
required=True)
parser.add_argument("--device", default="cuda")
parser.add_argument("--dtype", help="Model data type.", default="float16")
parser.add_argument(
"--qformat",
help="Quantization format.",
default="full_prec",
choices=[
"fp8", "int8_sq", "int4_awq", "w4a8_awq", "int8_wo", "int4_wo",
"full_prec"
],
)
parser.add_argument("--batch-size",
help="Batch size for calibration.",
type=int,
default=1)
parser.add_argument("--calib-size",
help="Number of samples for calibration.",
type=int,
default=512)
parser.add_argument("--output-dir", default="exported_model")
parser.add_argument("--tp-size", type=int, default=1)
parser.add_argument("--pp-size", type=int, default=1)
parser.add_argument("--awq-block-size", type=int, default=128)
parser.add_argument("--kv-cache-dtype",
help="KV Cache dtype.",
default=None,
choices=["int8", "fp8", None])
args = parser.parse_args()
main(args)

View File

@ -1,90 +0,0 @@
{
"model_type": "llama",
"kv_cache": {
"dtype": "float8_e4m3fn",
"scaling_factor": {
"0": {
"0": 0.0230364128947258,
"1": 0.01979283057153225,
"2": 0.0241350457072258,
"3": 0.0308314748108387,
"4": 0.0430733822286129,
"5": 0.0370396226644516,
"6": 0.0306222103536129,
"7": 0.0357491634786129,
"8": 0.0358189195394516,
"9": 0.0443289652466774,
"10": 0.0433175228536129,
"11": 0.0416782945394516,
"12": 0.0366908498108387,
"13": 0.0432477705180645,
"14": 0.0410505048930645,
"15": 0.0457589291036129,
"16": 0.0418526791036129,
"17": 0.0432477705180645,
"18": 0.0469447560608387,
"19": 0.0514787957072258,
"20": 0.0541294664144516,
"21": 0.0587681382894516,
"22": 0.0625,
"23": 0.0585588738322258,
"24": 0.0600237175822258,
"25": 0.0588030144572258,
"26": 0.0531180277466774,
"27": 0.06396484375,
"28": 0.0603027381002903,
"29": 0.0582101047039032,
"30": 0.0625348836183548,
"31": 0.0585588738322258,
"32": 0.0582798570394516,
"33": 0.0575125589966774,
"34": 0.0590820349752903,
"35": 0.0614188089966774,
"36": 0.0631975457072258,
"37": 0.0615931935608387,
"38": 0.0601283498108387,
"39": 0.0571986623108387,
"40": 0.0670340433716774,
"41": 0.0523507259786129,
"42": 0.0547223798930645,
"43": 0.0631975457072258,
"44": 0.0663713738322258,
"45": 0.0603376142680645,
"46": 0.0652204304933548,
"47": 0.0734514519572258,
"48": 0.0693708211183548,
"49": 0.0725446492433548,
"50": 0.0627790242433548,
"51": 0.0691266804933548,
"52": 0.0688825398683548,
"53": 0.068429134786129,
"54": 0.0605119988322258,
"55": 0.0799386203289032,
"56": 0.0853097140789032,
"57": 0.0661969929933548,
"58": 0.0689871683716774,
"59": 0.0724051371216774,
"60": 0.0541643425822258,
"61": 0.0626743882894516,
"62": 0.0628487765789032,
"63": 0.0607212632894516,
"64": 0.0589076466858387,
"65": 0.0451660193502903,
"66": 0.0453055277466774,
"67": 0.0414341539144516,
"68": 0.0385044664144516,
"69": 0.0414341539144516,
"70": 0.0466308631002903,
"71": 0.0399693101644516,
"72": 0.0437011756002903,
"73": 0.0434221550822258,
"74": 0.0428989976644516,
"75": 0.0401785746216774,
"76": 0.0431082621216774,
"77": 0.0484444759786129,
"78": 0.0417829267680645,
"79": 0.0418178029358387
}
}
}
}

View File

@ -1,42 +0,0 @@
{
"model_type": "llama",
"kv_cache": {
"dtype": "float8_e4m3fn",
"scaling_factor": {
"0": {
"0": 0.0152239128947258,
"1": 0.0188860222697258,
"2": 0.0354178324341774,
"3": 0.0376674123108387,
"4": 0.0418526791036129,
"5": 0.0433175228536129,
"6": 0.0397600457072258,
"7": 0.0424455925822258,
"8": 0.0415387861430645,
"9": 0.0408412404358387,
"10": 0.0395856611430645,
"11": 0.0377371683716774,
"12": 0.0400739423930645,
"13": 0.040771484375,
"14": 0.0393415205180645,
"15": 0.0369001142680645,
"16": 0.03857421875,
"17": 0.0387486070394516,
"18": 0.0403180830180645,
"19": 0.0396205373108387,
"20": 0.0375627800822258,
"21": 0.0407366082072258,
"22": 0.0432477705180645,
"23": 0.0377022884786129,
"24": 0.0399693101644516,
"25": 0.0374581478536129,
"26": 0.0413295216858387,
"27": 0.0442243330180645,
"28": 0.0424804724752903,
"29": 0.0456891767680645,
"30": 0.0409109964966774,
"31": 0.0482352152466774
}
}
}
}

View File

@ -182,7 +182,7 @@ def test_paged_attention(
key_cache, value_cache = key_caches[0], value_caches[0]
# Using default kv_scale
k_scale = v_scale = 1.0
k_scale = v_scale = torch.tensor(1.0, dtype=torch.float32, device=device)
# Call the paged attention kernel.
output = torch.empty_like(query)

View File

@ -210,7 +210,7 @@ def test_paged_attention(
key_cache, value_cache = key_caches[0], value_caches[0]
# Using default kv_scale
k_scale = v_scale = 1.0
k_scale = v_scale = torch.tensor(1.0, dtype=torch.float32, device=device)
tp_rank = 0
# Call the paged attention kernel.

View File

@ -160,7 +160,7 @@ def test_reshape_and_cache(
cloned_value_cache = value_cache.clone()
# Using default kv_scale
k_scale = v_scale = 1.0
k_scale = v_scale = torch.tensor(1.0, dtype=torch.float32, device=device)
# Call the reshape_and_cache kernel.
opcheck(torch.ops._C_cache_ops.reshape_and_cache,
@ -258,8 +258,8 @@ def test_reshape_and_cache_flash(
del key_caches
del value_caches
k_scale = key.amax().item() / 256
v_scale = value.amax().item() / 256
k_scale = (key.amax() / 256.0).to(torch.float32)
v_scale = (value.amax() / 256.0).to(torch.float32)
# Clone the KV caches.
if kv_cache_dtype == "fp8":
@ -284,12 +284,12 @@ def test_reshape_and_cache_flash(
result_key_cache = torch.empty_like(key_cache, dtype=torch.float16)
ops.convert_fp8(result_key_cache,
key_cache,
k_scale,
k_scale.item(),
kv_dtype=kv_cache_dtype)
result_value_cache = torch.empty_like(value_cache, dtype=torch.float16)
ops.convert_fp8(result_value_cache,
value_cache,
v_scale,
v_scale.item(),
kv_dtype=kv_cache_dtype)
# Run the reference implementation.
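Editor's note: these test hunks switch the default scales from Python floats to float32 tensors and, for the flash variant, derive them from the tensor's absolute maximum. A minimal self-contained sketch of that per-tensor convention, assuming a PyTorch build that provides torch.float8_e4m3fn (the tests themselves go through ops.convert_fp8 with k_scale.item()):

import torch

# amax / 256 keeps quantized magnitudes well inside the E4M3 range (max 448);
# the scale is kept as a float32 tensor so kernels can read it by pointer.
key = torch.randn(16, 8, 128, dtype=torch.float16)
k_scale = (key.amax() / 256.0).to(torch.float32)

key_fp8 = (key.float() / k_scale).to(torch.float8_e4m3fn)   # quantize
key_back = key_fp8.to(torch.float32) * k_scale              # dequantize
max_abs_err = (key.float() - key_back).abs().max()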

View File

@ -138,6 +138,7 @@ def test_contexted_kv_attention(
# to V_cache[num_blocks, num_kv_heads, head_size, block_size]
v_cache = v_cache.view(-1, block_size, num_kv_heads,
head_size).permute(0, 2, 3, 1).contiguous()
k_scale = v_scale = torch.tensor(1.0, dtype=torch.float32, device=device)
# Warm up the Triton kernel by calling it once before actually measuring
# generation time
@ -153,6 +154,8 @@ def test_contexted_kv_attention(
b_seq_len,
b_ctx_len,
max_input_len,
k_scale,
v_scale,
sliding_window=sliding_window)
torch.cuda.synchronize()
start_time = time.time()
@ -168,6 +171,8 @@ def test_contexted_kv_attention(
b_seq_len,
b_ctx_len,
max_input_len,
k_scale,
v_scale,
sliding_window=sliding_window)
torch.cuda.synchronize()
end_time = time.time()
@ -366,6 +371,7 @@ def test_contexted_kv_attention_alibi(
# to V_cache[num_blocks, num_kv_heads, head_size, block_size]
v_cache = v_cache.view(-1, block_size, num_kv_heads,
head_size).permute(0, 2, 3, 1).contiguous()
k_scale = v_scale = torch.tensor(1.0, dtype=torch.float32, device=device)
# Warm up the Triton kernel by calling it once before actually measuring
# generation time
@ -381,6 +387,8 @@ def test_contexted_kv_attention_alibi(
b_seq_len,
b_ctx_len,
max_input_len,
k_scale,
v_scale,
alibi_slopes=alibi_slopes)
torch.cuda.synchronize()
start_time = time.time()
@ -396,6 +404,8 @@ def test_contexted_kv_attention_alibi(
b_seq_len,
b_ctx_len,
max_input_len,
k_scale,
v_scale,
alibi_slopes=alibi_slopes)
torch.cuda.synchronize()
end_time = time.time()

View File

@ -909,6 +909,7 @@ def make_test_metadata(
num_prefills=num_prefills,
slot_mapping=(None if kv_mmap is None else kv_mmap.slot_mapping),
multi_modal_placeholder_index_maps=None,
enable_kv_scales_calculation=True,
num_prefill_tokens=num_prefill_tokens,
num_decode_tokens=num_decode_tokens,
seq_lens=seq_lens,
@ -958,6 +959,7 @@ def make_test_metadata(
num_prefills=num_prefills,
slot_mapping=kv_mmap.slot_mapping,
multi_modal_placeholder_index_maps=None,
enable_kv_scales_calculation=True,
num_prefill_tokens=num_prefill_tokens,
num_decode_tokens=num_decode_tokens,
seq_lens=seq_lens,

View File

@ -19,18 +19,17 @@ os.environ["TOKENIZERS_PARALLELISM"] = "true"
@pytest.mark.skipif(not is_quant_method_supported("fp8"),
reason="fp8 is not supported on this GPU type.")
@pytest.mark.parametrize(
"kv_cache_dtype,base_model,test_model,scale_path",
"kv_cache_dtype,base_model,test_model",
[
# Test FP8 checkpoint w. fp8_e4m3 kv-cache scaling factors.
("fp8_e4m3", "meta-llama/Llama-3.2-1B-Instruct",
"nm-testing/Llama-3.2-1B-Instruct-FP8-KV", None),
"nm-testing/Llama-3.2-1B-Instruct-FP8-KV"),
# Test FP16 checkpoint w. fp8_e5m2 kv-cache.
("fp8_e5m2", "meta-llama/Llama-3.2-1B-Instruct",
"meta-llama/Llama-3.2-1B-Instruct", None),
"meta-llama/Llama-3.2-1B-Instruct"),
# Test FP16 checkpoint w. fp8_e4m3 kv-cache scaling factors in json.
("fp8_e4m3", "meta-llama/Llama-2-7b-chat-hf",
"meta-llama/Llama-2-7b-chat-hf",
"./tests/fp8_kv/llama2-7b-fp8-kv/kv_cache_scales.json")
"meta-llama/Llama-2-7b-chat-hf")
])
# Due to low-precision numerical divergence, we only test logprob of 4 tokens
@pytest.mark.parametrize("max_tokens", [4])
@ -48,7 +47,6 @@ def test_models(
kv_cache_dtype: str,
base_model: str,
test_model: str,
scale_path: Optional[str],
max_tokens: int,
enforce_eager: bool,
backend: str,
@ -76,10 +74,6 @@ def test_models(
baseline_outputs = vllm_model.generate_greedy_logprobs(
example_prompts, max_tokens, NUM_LOG_PROBS)
extra_kwargs = {}
if scale_path is not None:
extra_kwargs["quantization_param_path"] = scale_path
with vllm_runner(
test_model,
max_model_len=MAX_MODEL_LEN,
@ -87,7 +81,6 @@ def test_models(
enforce_eager=enforce_eager,
kv_cache_dtype=kv_cache_dtype,
disable_async_output_proc=disable_async_output_proc,
**extra_kwargs,
) as vllm_model:
test_outputs = vllm_model.generate_greedy_logprobs(
example_prompts, max_tokens, NUM_LOG_PROBS)

View File

@ -74,6 +74,7 @@ def test_model_runner_input():
num_decode_tokens=3,
slot_mapping=torch.zeros(1),
multi_modal_placeholder_index_maps=None,
enable_kv_scales_calculation=True,
)
model_input = ModelInputForGPUWithSamplingMetadata(
input_tokens=torch.ones(10),
@ -126,6 +127,7 @@ def test_embedding_model_runner_input():
num_decode_tokens=3,
slot_mapping=torch.zeros(1),
multi_modal_placeholder_index_maps=None,
enable_kv_scales_calculation=True,
)
model_input = ModelInputForGPUWithPoolingMetadata(
input_tokens=torch.ones(10),
@ -177,6 +179,7 @@ def test_multi_step_model_runner_input():
num_decode_tokens=3,
slot_mapping=torch.zeros(1),
multi_modal_placeholder_index_maps=None,
enable_kv_scales_calculation=True,
)
frozen_model_input = ModelInputForGPUWithSamplingMetadata(
input_tokens=torch.ones(10),

View File

@ -48,8 +48,8 @@ def paged_attention_v1(
max_seq_len: int,
alibi_slopes: Optional[torch.Tensor],
kv_cache_dtype: str,
k_scale: float,
v_scale: float,
k_scale: torch.Tensor,
v_scale: torch.Tensor,
tp_rank: int = 0,
blocksparse_local_blocks: int = 0,
blocksparse_vert_stride: int = 0,
@ -80,8 +80,8 @@ def paged_attention_v2(
max_seq_len: int,
alibi_slopes: Optional[torch.Tensor],
kv_cache_dtype: str,
k_scale: float,
v_scale: float,
k_scale: torch.Tensor,
v_scale: torch.Tensor,
tp_rank: int = 0,
blocksparse_local_blocks: int = 0,
blocksparse_vert_stride: int = 0,
@ -112,8 +112,8 @@ def paged_attention_rocm(
max_seq_len: int,
alibi_slopes: Optional[torch.Tensor],
kv_cache_dtype: str,
k_scale: float,
v_scale: float,
k_scale: torch.Tensor,
v_scale: torch.Tensor,
) -> None:
torch.ops._rocm_C.paged_attention(out, exp_sum, max_logits, tmp_out, query,
key_cache, value_cache, num_kv_heads,
@ -956,8 +956,8 @@ def reshape_and_cache(
value_cache: torch.Tensor,
slot_mapping: torch.Tensor,
kv_cache_dtype: str,
k_scale: float,
v_scale: float,
k_scale: torch.Tensor,
v_scale: torch.Tensor,
) -> None:
torch.ops._C_cache_ops.reshape_and_cache(key, value, key_cache,
value_cache, slot_mapping,
@ -971,8 +971,8 @@ def reshape_and_cache_flash(
value_cache: torch.Tensor,
slot_mapping: torch.Tensor,
kv_cache_dtype: str,
k_scale: float,
v_scale: float,
k_scale: torch.Tensor,
v_scale: torch.Tensor,
) -> None:
torch.ops._C_cache_ops.reshape_and_cache_flash(key, value, key_cache,
value_cache, slot_mapping,

View File

@ -123,6 +123,10 @@ class AttentionMetadata:
multi_modal_placeholder_index_maps: Optional[Dict[
str, MultiModalPlaceholderMap.IndexMap]]
# Enable/disable KV scales calculation. This is so that we can disable the
# calculation until after prefill and cuda graph capture.
enable_kv_scales_calculation: bool
@property
@abstractmethod
def prefill_metadata(self) -> Optional["AttentionMetadata"]:
@ -226,8 +230,10 @@ class AttentionMetadataBuilder(ABC, Generic[T]):
class AttentionLayer(Protocol):
_k_scale: float
_v_scale: float
_k_scale: torch.Tensor
_v_scale: torch.Tensor
_k_scale_float: float
_v_scale_float: float
def forward(
self,

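Editor's note: the new metadata field and the retyped scales on the layer protocol work together. Scales are recomputed only when the engine-level calculate_kv_scales option and the per-batch enable_kv_scales_calculation flag are both set; the flag is cleared for dummy profiling batches and CUDA graph capture. A compact sketch of that gating with hypothetical stand-in classes, not the real vLLM types:

import torch
from dataclasses import dataclass

@dataclass
class Meta:                        # stand-in for AttentionMetadata
    enable_kv_scales_calculation: bool = True

class Layer:                       # stand-in for the Attention layer protocol
    _k_scale: torch.Tensor         # tensor scales, readable from kernels
    _v_scale: torch.Tensor
    _k_scale_float: float          # float mirrors for scalar-only backends
    _v_scale_float: float
    calculate_kv_scales: bool = True

    def should_calc_scales(self, meta: Meta) -> bool:
        # Dummy/profile batches and graph capture construct Meta(False).
        return self.calculate_kv_scales and meta.enable_kv_scales_calculation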
View File

@ -222,6 +222,7 @@ class BlocksparseFlashAttentionMetadata(AttentionMetadata):
slot_mapping=self.slot_mapping[:self.num_prefill_tokens],
multi_modal_placeholder_index_maps=self.
multi_modal_placeholder_index_maps,
enable_kv_scales_calculation=self.enable_kv_scales_calculation,
seq_lens=self.seq_lens[:self.num_prefills],
seq_lens_tensor=self.seq_lens_tensor[:self.num_prefills],
max_query_len=self.max_query_len,
@ -251,6 +252,7 @@ class BlocksparseFlashAttentionMetadata(AttentionMetadata):
num_decode_tokens=self.num_decode_tokens,
slot_mapping=self.slot_mapping[self.num_prefill_tokens:],
multi_modal_placeholder_index_maps=None,
enable_kv_scales_calculation=False,
seq_lens=None,
seq_lens_tensor=self.seq_lens_tensor[self.num_prefills:],
max_query_len=None,

View File

@ -230,6 +230,7 @@ class FlashAttentionMetadata(AttentionMetadata):
slot_mapping=slot_mapping,
multi_modal_placeholder_index_maps=self.
multi_modal_placeholder_index_maps,
enable_kv_scales_calculation=self.enable_kv_scales_calculation,
seq_lens=seq_lens,
seq_lens_tensor=seq_lens_tensor,
max_query_len=self.max_query_len,
@ -274,6 +275,7 @@ class FlashAttentionMetadata(AttentionMetadata):
num_decode_tokens=self.num_decode_tokens,
slot_mapping=slot_mapping,
multi_modal_placeholder_index_maps=None,
enable_kv_scales_calculation=True,
seq_lens=None,
seq_lens_tensor=seq_lens_tensor,
max_decode_query_len=self.max_decode_query_len,
@ -557,6 +559,7 @@ class FlashAttentionMetadataBuilder(
num_decode_tokens=num_decode_tokens,
seq_lens=seq_lens,
multi_modal_placeholder_index_maps=placeholder_index_maps,
enable_kv_scales_calculation=True,
seq_lens_tensor=seq_lens_tensor,
max_query_len=max_query_len,
max_decode_query_len=max_decode_query_len,
@ -675,7 +678,7 @@ class FlashAttentionImpl(AttentionImpl):
NOTE: It in-place updates the output tensor.
"""
# NOTE(woosuk): FlashAttention does not support FP8 KV cache.
assert layer._k_scale == 1.0 and layer._v_scale == 1.0, (
assert layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0, (
"key/v_scale is not supported in FlashAttention.")
assert output is not None, "Output tensor must be provided."

View File

@ -219,6 +219,7 @@ class FlashInferState(AttentionState):
num_prefills=0,
slot_mapping=self._graph_slot_mapping[:batch_size],
multi_modal_placeholder_index_maps=None,
enable_kv_scales_calculation=False,
num_prefill_tokens=0,
num_decode_tokens=batch_size,
max_prefill_seq_len=0,
@ -733,6 +734,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
num_prefills=self.num_prefills,
slot_mapping=slot_mapping_tensor,
multi_modal_placeholder_index_maps=placeholder_index_maps,
enable_kv_scales_calculation=False,
num_prefill_tokens=self.num_prefill_tokens,
num_decode_tokens=num_decode_tokens,
max_prefill_seq_len=max_prefill_seq_len,
@ -888,8 +890,8 @@ class FlashInferImpl(AttentionImpl):
kv_cache,
logits_soft_cap=logits_soft_cap,
causal=True,
k_scale=layer._k_scale,
v_scale=layer._v_scale,
k_scale=layer._k_scale_float,
v_scale=layer._v_scale_float,
window_left=window_left)
if decode_meta := attn_metadata.decode_metadata:
assert decode_meta is not None
@ -899,8 +901,8 @@ class FlashInferImpl(AttentionImpl):
kv_cache,
sm_scale=softmax_scale,
logits_soft_cap=logits_soft_cap,
k_scale=layer._k_scale,
v_scale=layer._v_scale,
k_scale=layer._k_scale_float,
v_scale=layer._v_scale_float,
window_left=window_left)
if prefill_output is None and decode_output is not None:

View File

@ -193,7 +193,7 @@ class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]):
Returns:
shape = [num_tokens, num_heads * head_size]
"""
assert layer._k_scale == 1.0 and layer._v_scale == 1.0
assert layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0
num_tokens, hidden_size = query.shape
# Reshape the query, key, and value tensors.
query = query.view(-1, self.num_heads, self.head_size)

View File

@ -173,7 +173,7 @@ class PallasAttentionBackendImpl(AttentionImpl):
Returns:
shape = [batch_size, seq_len, num_heads * head_size]
"""
assert layer._k_scale == 1.0 and layer._v_scale == 1.0
assert layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0
batch_size, seq_len, hidden_size = query.shape
query = query.view(batch_size, seq_len, self.num_heads, self.head_size)
key = key.view(batch_size, seq_len, self.num_kv_heads, self.head_size)

View File

@ -140,6 +140,7 @@ class PlaceholderAttentionMetadata(AttentionMetadata):
slot_mapping=slot_mapping,
multi_modal_placeholder_index_maps=self.
multi_modal_placeholder_index_maps,
enable_kv_scales_calculation=self.enable_kv_scales_calculation,
seq_lens=self.seq_lens[:self.num_prefills],
seq_lens_tensor=self.seq_lens_tensor[:self.num_prefills],
max_decode_query_len=0,
@ -173,6 +174,7 @@ class PlaceholderAttentionMetadata(AttentionMetadata):
num_decode_tokens=self.num_decode_tokens,
slot_mapping=slot_mapping,
multi_modal_placeholder_index_maps=None,
enable_kv_scales_calculation=True,
seq_lens=None,
seq_lens_tensor=self.seq_lens_tensor[self.num_prefills:],
max_decode_query_len=self.max_decode_query_len,
@ -380,6 +382,7 @@ class PlaceholderAttentionMetadataBuilder(
num_prefills=self.num_prefills,
slot_mapping=slot_mapping,
multi_modal_placeholder_index_maps=placeholder_index_maps,
enable_kv_scales_calculation=True,
num_prefill_tokens=self.num_prefill_tokens,
num_decode_tokens=num_decode_tokens,
seq_lens=seq_lens,

View File

@ -153,6 +153,7 @@ class ROCmFlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata):
slot_mapping=self.slot_mapping[:self.num_prefill_tokens],
multi_modal_placeholder_index_maps=self.
multi_modal_placeholder_index_maps,
enable_kv_scales_calculation=self.enable_kv_scales_calculation,
seq_lens=self.seq_lens[:self.num_prefills],
seq_lens_tensor=self.seq_lens_tensor[:self.num_prefills],
max_query_len=self.max_query_len,
@ -182,6 +183,7 @@ class ROCmFlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata):
num_decode_tokens=self.num_decode_tokens,
slot_mapping=self.slot_mapping[self.num_prefill_tokens:],
multi_modal_placeholder_index_maps=None,
enable_kv_scales_calculation=True,
seq_lens=None,
seq_lens_tensor=self.seq_lens_tensor[self.num_prefills:],
max_query_len=None,

View File

@ -379,6 +379,7 @@ class TorchSDPAMetadataBuilder(AttentionMetadataBuilder[TorchSDPAMetadata]):
prefill_block_tables=prefill_block_tables,
slot_mapping=slot_mapping,
multi_modal_placeholder_index_maps=placeholder_index_maps,
enable_kv_scales_calculation=False,
)
return attn_metadata
@ -454,7 +455,6 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
Returns:
shape = [num_tokens, num_heads * head_size]
"""
assert layer._k_scale == 1.0 and layer._v_scale == 1.0
attn_type = self.attn_type
if (attn_type == AttentionType.ENCODER
and (not attn_metadata.is_all_encoder_attn_metadata_set)):

View File

@ -265,6 +265,7 @@ class CommonMetadataBuilder(AttentionMetadataBuilder[TAttentionMetadata]):
num_prefills=self.num_prefills,
slot_mapping=slot_mapping_tensor,
multi_modal_placeholder_index_maps=placeholder_index_maps,
enable_kv_scales_calculation=True,
num_prefill_tokens=self.num_prefill_tokens,
num_decode_tokens=num_decode_tokens,
seq_lens=seq_lens,
@ -317,6 +318,7 @@ class CommonAttentionState(AttentionState):
num_decode_tokens=batch_size,
slot_mapping=self._graph_slot_mapping[:batch_size],
multi_modal_placeholder_index_maps=None,
enable_kv_scales_calculation=True,
seq_lens=None,
seq_lens_tensor=self._graph_seq_lens[:batch_size],
max_query_len=1,

View File

@ -218,6 +218,7 @@ class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata):
slot_mapping=slot_mapping,
multi_modal_placeholder_index_maps=self.
multi_modal_placeholder_index_maps,
enable_kv_scales_calculation=self.enable_kv_scales_calculation,
seq_lens=seq_lens,
seq_lens_tensor=seq_lens_tensor,
max_query_len=self.max_query_len,
@ -262,6 +263,7 @@ class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata):
num_decode_tokens=self.num_decode_tokens,
slot_mapping=slot_mapping,
multi_modal_placeholder_index_maps=None,
enable_kv_scales_calculation=True,
seq_lens_tensor=seq_lens_tensor,
max_prefill_seq_len=0,
max_decode_seq_len=self.max_decode_seq_len,

View File

@ -5,6 +5,7 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
import vllm.envs as envs
from vllm.attention import AttentionMetadata, AttentionType
from vllm.attention.selector import backend_name_to_enum, get_attn_backend
from vllm.config import CacheConfig, get_current_vllm_config
@ -57,10 +58,12 @@ class Attention(nn.Module):
kv_cache_dtype = cache_config.cache_dtype
block_size = cache_config.block_size
is_attention_free = cache_config.is_attention_free
calculate_kv_scales = cache_config.calculate_kv_scales
else:
kv_cache_dtype = "auto"
block_size = 16
is_attention_free = False
calculate_kv_scales = False
if num_kv_heads is None:
num_kv_heads = num_heads
@ -70,8 +73,15 @@ class Attention(nn.Module):
# expect the pre-quantized k/v_scale to be loaded along
# with the model weights.
self.kv_cache_dtype = kv_cache_dtype
self._k_scale = 1.0
self._v_scale = 1.0
self.calculate_kv_scales = calculate_kv_scales
self._k_scale = torch.tensor(1.0, dtype=torch.float32)
self._v_scale = torch.tensor(1.0, dtype=torch.float32)
# We also keep the float32 versions of k/v_scale for attention
# backends that don't support tensors (Flashinfer)
self._k_scale_float = 1.0
self._v_scale_float = 1.0
quant_method = quant_config.get_quant_method(
self, prefix=prefix) if quant_config else None
if quant_method is not None:
@ -127,6 +137,9 @@ class Attention(nn.Module):
).parallel_config.pipeline_parallel_size)
]
self.k_range = torch.tensor(envs.K_SCALE_CONSTANT, dtype=torch.float32)
self.v_range = torch.tensor(envs.V_SCALE_CONSTANT, dtype=torch.float32)
def forward(
self,
query: torch.Tensor,
@ -135,6 +148,9 @@ class Attention(nn.Module):
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
if self.calculate_kv_scales and \
attn_metadata.enable_kv_scales_calculation:
self.calc_kv_scales(key, value)
if self.use_output:
output = torch.empty_like(query)
hidden_size = query.size(-1)
@ -161,6 +177,14 @@ class Attention(nn.Module):
return torch.ops.vllm.unified_attention(
query, key, value, self.layer_name)
def calc_kv_scales(self, key, value):
self._k_scale.copy_(torch.abs(key).max() / self.k_range)
self._v_scale.copy_(torch.abs(value).max() / self.v_range)
self._k_scale_float = self._k_scale.item()
self._v_scale_float = self._v_scale.item()
# We only calculate the scales once
self.calculate_kv_scales = False
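Editor's note: in effect, the first forward pass over real activations replaces the unit scales with amax-based values and then freezes them. A standalone sketch of that update, with K_SCALE_CONSTANT / V_SCALE_CONSTANT written out as the default divisors from vllm.envs (200 and 100); the in-place copy_ matters because kernels, including ones captured in CUDA graphs, keep pointers to these tensors:

import torch

K_SCALE_CONSTANT = 200.0   # default key-scale divisor (vllm.envs)
V_SCALE_CONSTANT = 100.0   # default value-scale divisor (vllm.envs)

k_scale = torch.tensor(1.0, dtype=torch.float32)
v_scale = torch.tensor(1.0, dtype=torch.float32)

def calc_kv_scales(key: torch.Tensor, value: torch.Tensor) -> None:
    # Overwrite the existing storage rather than rebinding the names.
    k_scale.copy_(key.abs().max() / K_SCALE_CONSTANT)
    v_scale.copy_(value.abs().max() / V_SCALE_CONSTANT)

calc_kv_scales(torch.randn(32, 8, 128), torch.randn(32, 8, 128))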
def extra_repr(self) -> str:
s = f"head_size={self.impl.head_size}" # type: ignore
s += f", num_heads={self.impl.num_heads}" # type: ignore

View File

@ -52,8 +52,8 @@ class _PagedAttention:
value_cache: torch.Tensor,
slot_mapping: torch.Tensor,
kv_cache_dtype: str,
k_scale: float,
v_scale: float,
k_scale: torch.Tensor,
v_scale: torch.Tensor,
*args,
) -> None:
ops.reshape_and_cache(
@ -80,8 +80,8 @@ class _PagedAttention:
num_kv_heads: int,
scale: float,
alibi_slopes: Optional[torch.Tensor],
k_scale: float,
v_scale: float,
k_scale: torch.Tensor,
v_scale: torch.Tensor,
*args,
) -> None:
tp_rank: int = 0
@ -149,8 +149,8 @@ class _IPEXPagedAttention(_PagedAttention):
value_cache: torch.Tensor,
slot_mapping: torch.Tensor,
kv_cache_dtype: str,
k_scale: float,
v_scale: float,
k_scale: torch.Tensor,
v_scale: torch.Tensor,
*args,
) -> None:
ipex_modules.PagedAttention.reshape_and_cache(
@ -170,8 +170,8 @@ class _IPEXPagedAttention(_PagedAttention):
num_kv_heads: int,
scale: float,
alibi_slopes: Optional[torch.Tensor],
k_scale: float,
v_scale: float,
k_scale: torch.Tensor,
v_scale: torch.Tensor,
*args,
) -> None:
block_size = value_cache.shape[2]

View File

@ -69,8 +69,8 @@ class PagedAttention:
value_cache: torch.Tensor,
slot_mapping: torch.Tensor,
kv_cache_dtype: str,
k_scale: float,
v_scale: float,
k_scale: torch.Tensor,
v_scale: torch.Tensor,
) -> None:
ops.reshape_and_cache(
key,
@ -95,8 +95,8 @@ class PagedAttention:
num_kv_heads: int,
scale: float,
alibi_slopes: Optional[torch.Tensor],
k_scale: float,
v_scale: float,
k_scale: torch.Tensor,
v_scale: torch.Tensor,
tp_rank: int = 0,
blocksparse_local_blocks: int = 0,
blocksparse_vert_stride: int = 0,
@ -204,8 +204,8 @@ class PagedAttention:
max_query_len: int,
alibi_slopes: Optional[torch.Tensor],
sliding_window: Optional[int],
k_scale: float,
v_scale: float,
k_scale: torch.Tensor,
v_scale: torch.Tensor,
) -> torch.Tensor:
output = torch.empty_like(query)
context_attention_fwd(

View File

@ -133,7 +133,7 @@ if triton.__version__ >= "2.1.0":
other=0.0) # [D,N]
if k_load.dtype.is_fp8():
k = (k_load.to(tl.float32) * k_scale).to(q.dtype)
k = (k_load.to(tl.float32) * tl.load(k_scale)).to(q.dtype)
else:
k = k_load
@ -181,7 +181,7 @@ if triton.__version__ >= "2.1.0":
((start_n + offs_n[:, None]) < cur_batch_ctx_len),
other=0.0) # [N,D]
if v_load.dtype.is_fp8():
v = (v_load.to(tl.float32) * v_scale).to(q.dtype)
v = (v_load.to(tl.float32) * tl.load(v_scale)).to(q.dtype)
else:
v = v_load
p = p.to(v.dtype)
@ -564,7 +564,7 @@ if triton.__version__ >= "2.1.0":
other=0.0) # [D,N]
if k_load.dtype.is_fp8():
k = (k_load.to(tl.float32) * k_scale).to(q.dtype)
k = (k_load.to(tl.float32) * tl.load(k_scale)).to(q.dtype)
else:
k = k_load
@ -604,7 +604,7 @@ if triton.__version__ >= "2.1.0":
((start_n + offs_n[:, None]) < cur_batch_ctx_len),
other=0.0)
if v_load.dtype.is_fp8():
v = (v_load.to(tl.float32) * v_scale).to(q.dtype)
v = (v_load.to(tl.float32) * tl.load(v_scale)).to(q.dtype)
else:
v = v_load
p = p.to(v.dtype)
@ -713,8 +713,8 @@ if triton.__version__ >= "2.1.0":
b_seq_len,
b_ctx_len,
max_input_len,
k_scale: float = 1.0,
v_scale: float = 1.0,
k_scale: torch.Tensor,
v_scale: torch.Tensor,
alibi_slopes=None,
sliding_window=None):
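Editor's note: because k_scale and v_scale are now single-element float32 tensors instead of Python floats, the Triton kernels read them with tl.load rather than receiving them as scalar arguments. A minimal illustration of that pattern with a hypothetical dequantization kernel (not the prefix-prefill kernel above):

import torch
import triton
import triton.language as tl

@triton.jit
def dequant_kernel(x_ptr, scale_ptr, out_ptr, n_elements, BLOCK: tl.constexpr):
    pid = tl.program_id(0)
    offs = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n_elements
    x = tl.load(x_ptr + offs, mask=mask, other=0.0)
    scale = tl.load(scale_ptr)          # scalar scale read from device memory
    tl.store(out_ptr + offs, x.to(tl.float32) * scale, mask=mask)

x = torch.randn(1024, device="cuda", dtype=torch.float16)
scale = torch.tensor(0.5, dtype=torch.float32, device="cuda")
out = torch.empty_like(x, dtype=torch.float32)
dequant_kernel[(triton.cdiv(x.numel(), 256), )](x, scale, out, x.numel(), BLOCK=256)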

View File

@ -120,11 +120,6 @@ class ModelConfig:
decoding draft models.
quantization: Quantization method that was used to quantize the model
weights. If None, we assume the model weights are not quantized.
quantization_param_path: Path to JSON file containing scaling factors.
Used to load KV cache scaling factors into the model when KV cache
type is FP8_E4M3 on ROCm (AMD GPU). In the future these will also
be used to load activation and weight scaling factors when the
model dtype is FP8_E4M3 on ROCm.
enforce_eager: Whether to enforce eager execution. If True, we will
disable CUDA graph and always execute the model in eager mode.
If False, we will use CUDA graph and eager execution in hybrid.
@ -187,7 +182,6 @@ class ModelConfig:
factors.append(self.model)
factors.append(self.dtype)
factors.append(self.quantization)
factors.append(self.quantization_param_path)
factors.append(self.revision)
factors.append(self.code_revision)
factors.append(self.trust_remote_code)
@ -213,7 +207,6 @@ class ModelConfig:
max_model_len: Optional[int] = None,
spec_target_max_model_len: Optional[int] = None,
quantization: Optional[str] = None,
quantization_param_path: Optional[str] = None,
enforce_eager: Optional[bool] = None,
max_seq_len_to_capture: Optional[int] = None,
max_logprobs: int = 20,
@ -274,7 +267,6 @@ class ModelConfig:
else:
self.tokenizer_revision = tokenizer_revision
self.quantization = quantization
self.quantization_param_path = quantization_param_path
self.enforce_eager = enforce_eager
self.max_seq_len_to_capture = max_seq_len_to_capture
self.max_logprobs = max_logprobs
@ -1002,6 +994,7 @@ class CacheConfig:
sliding_window: Optional[int] = None,
enable_prefix_caching: bool = False,
cpu_offload_gb: float = 0,
calculate_kv_scales: Optional[bool] = None,
) -> None:
self.block_size = block_size
self.gpu_memory_utilization = gpu_memory_utilization
@ -1012,7 +1005,7 @@ class CacheConfig:
self.sliding_window = sliding_window
self.enable_prefix_caching = enable_prefix_caching
self.cpu_offload_gb = cpu_offload_gb
self.calculate_kv_scales = calculate_kv_scales
self._verify_args()
self._verify_cache_dtype()
self._verify_prefix_caching()
@ -1021,6 +1014,10 @@ class CacheConfig:
self.num_gpu_blocks: Optional[int] = None
self.num_cpu_blocks: Optional[int] = None
# Set calculate_kv_scales to False if the value is unset.
if self.calculate_kv_scales is None:
self.calculate_kv_scales = False
def metrics_info(self):
# convert cache_config to dict(key: str, value: str) for prometheus
# metrics info
@ -3297,7 +3294,6 @@ class VllmConfig:
f"quantization={self.model_config.quantization}, "
f"enforce_eager={self.model_config.enforce_eager}, "
f"kv_cache_dtype={self.cache_config.cache_dtype}, "
f"quantization_param_path={self.model_config.quantization_param_path},"
f" device_config={self.device_config.device}, "
f"decoding_config={self.decoding_config!r}, "
f"observability_config={self.observability_config!r}, "

View File

@ -98,7 +98,6 @@ class EngineArgs:
config_format: ConfigFormat = ConfigFormat.AUTO
dtype: str = 'auto'
kv_cache_dtype: str = 'auto'
quantization_param_path: Optional[str] = None
seed: int = 0
max_model_len: Optional[int] = None
worker_use_ray: bool = False
@ -199,6 +198,8 @@ class EngineArgs:
generation_config: Optional[str] = None
enable_sleep_mode: bool = False
calculate_kv_scales: Optional[bool] = None
def __post_init__(self):
if not self.tokenizer:
self.tokenizer = self.model
@ -350,17 +351,6 @@ class EngineArgs:
help='Data type for kv cache storage. If "auto", will use model '
'data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. '
'ROCm (AMD GPU) supports fp8 (=fp8_e4m3)')
parser.add_argument(
'--quantization-param-path',
type=nullable_str,
default=None,
help='Path to the JSON file containing the KV cache '
'scaling factors. This should generally be supplied, when '
'KV cache dtype is FP8. Otherwise, KV cache scaling factors '
'default to 1.0, which may cause accuracy issues. '
'FP8_E5M2 (without scaling) is only supported on cuda version '
'greater than 11.8. On ROCm (AMD GPU), FP8_E4M3 is instead '
'supported for common inference criteria.')
parser.add_argument('--max-model-len',
type=int,
default=EngineArgs.max_model_len,
@ -962,6 +952,15 @@ class EngineArgs:
help="Enable sleep mode for the engine. "
"(only cuda platform is supported)")
parser.add_argument(
'--calculate-kv-scales',
action='store_true',
help='This enables dynamic calculation of '
'k_scale and v_scale when kv-cache-dtype is fp8. '
'If calculate-kv-scales is false, the scales will '
'be loaded from the model checkpoint if available. '
'Otherwise, the scales will default to 1.0.')
return parser
@classmethod
@ -991,7 +990,6 @@ class EngineArgs:
tokenizer_revision=self.tokenizer_revision,
max_model_len=self.max_model_len,
quantization=self.quantization,
quantization_param_path=self.quantization_param_path,
enforce_eager=self.enforce_eager,
max_seq_len_to_capture=self.max_seq_len_to_capture,
max_logprobs=self.max_logprobs,
@ -1068,6 +1066,7 @@ class EngineArgs:
sliding_window=model_config.get_sliding_window(),
enable_prefix_caching=self.enable_prefix_caching,
cpu_offload_gb=self.cpu_offload_gb,
calculate_kv_scales=self.calculate_kv_scales,
)
parallel_config = ParallelConfig(
pipeline_parallel_size=self.pipeline_parallel_size,
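Editor's note: a hedged end-to-end usage sketch for the new --calculate-kv-scales option with an FP8 KV cache. The model name is only an example, and calculate_kv_scales is assumed to be forwarded from the LLM constructor to EngineArgs like other engine arguments:

from vllm import LLM, SamplingParams

llm = LLM(model="meta-llama/Llama-3.2-1B-Instruct",
          kv_cache_dtype="fp8",
          calculate_kv_scales=True)     # scales computed on the first real batch
outputs = llm.generate(["Hello, my name is"],
                       SamplingParams(max_tokens=16))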

View File

@ -73,6 +73,8 @@ if TYPE_CHECKING:
VLLM_ENABLE_V1_MULTIPROCESSING: bool = True
VLLM_LOG_BATCHSIZE_INTERVAL: float = -1
VLLM_DISABLE_COMPILE_CACHE: bool = False
K_SCALE_CONSTANT: int = 200
V_SCALE_CONSTANT: int = 100
VLLM_SERVER_DEV_MODE: bool = False
VLLM_V1_OUTPUT_PROC_CHUNK_SIZE: int = 128
@ -474,6 +476,13 @@ environment_variables: Dict[str, Callable[[], Any]] = {
"VLLM_USE_V1":
lambda: bool(int(os.getenv("VLLM_USE_V1", "0"))),
# Divisor for dynamic key scale factor calculation for FP8 KV Cache
"K_SCALE_CONSTANT":
lambda: int(os.getenv("K_SCALE_CONSTANT", "200")),
# Divisor for dynamic value scale factor calculation for FP8 KV Cache
"V_SCALE_CONSTANT":
lambda: int(os.getenv("V_SCALE_CONSTANT", "100")),
# If set, enable multiprocessing in LLM for the V1 code path.
"VLLM_ENABLE_V1_MULTIPROCESSING":
lambda: bool(int(os.getenv("VLLM_ENABLE_V1_MULTIPROCESSING", "1"))),
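Editor's note: the divisors used for the dynamic scales can themselves be overridden through the environment. A small sketch with arbitrary example values, which must be set before the engine reads vllm.envs:

import os

# Defaults are 200 (key) and 100 (value); larger divisors yield smaller scales.
os.environ["K_SCALE_CONSTANT"] = "400"
os.environ["V_SCALE_CONSTANT"] = "200"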

View File

@ -3,6 +3,7 @@ import torch
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase)
from vllm.platforms import current_platform
logger = init_logger(__name__)
@ -40,11 +41,16 @@ class BaseKVCacheMethod(QuantizeMethodBase):
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
# If the kv-cache dtype is auto, we enforce the k/v_scale to be 1.0
# regardless whether the kv-scale is available in the checkpoint.
if layer.kv_cache_dtype != "auto":
# No need to process kv scales after loading if we are going to
# calculate them on the fly.
if layer.kv_cache_dtype != "auto" and not layer.calculate_kv_scales:
if layer.k_scale > 0.0 and layer.v_scale > 0.0:
# We prefer to use separate k_scale and v_scale if present
k_scale = layer.k_scale.to("cpu").tolist()
v_scale = layer.v_scale.to("cpu").tolist()
if current_platform.is_rocm():
k_scale *= 2
v_scale *= 2
elif layer.k_scale < 0.0 and layer.v_scale < 0.0:
# If no scales were loaded (both scales are invalid negative
# values), use the default value of 1.0
@ -58,6 +64,9 @@ class BaseKVCacheMethod(QuantizeMethodBase):
scale_to_duplicate = max(layer.k_scale, layer.v_scale)
k_scale = scale_to_duplicate.to("cpu").tolist()
v_scale = scale_to_duplicate.to("cpu").tolist()
if current_platform.is_rocm():
k_scale *= 2
v_scale *= 2
if not isinstance(k_scale, float) or not isinstance(
v_scale, float):
@ -65,9 +74,11 @@ class BaseKVCacheMethod(QuantizeMethodBase):
"for fp8 KV cache")
# These are used in the final Attention.forward()
layer._k_scale = k_scale
layer._v_scale = v_scale
if (layer._k_scale == 1.0 and layer._v_scale == 1.0
layer._k_scale.copy_(k_scale)
layer._v_scale.copy_(v_scale)
layer._k_scale_float = k_scale
layer._v_scale_float = v_scale
if (k_scale == 1.0 and v_scale == 1.0
and "e5m2" not in layer.kv_cache_dtype):
logger.warning_once(
"Using KV cache scaling factor 1.0 for fp8_e4m3. This "

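Editor's note: when the scales come from the checkpoint instead of being computed on the fly, the method above copies them into the layer's tensor buffers, keeps the float mirrors in sync, and doubles them on ROCm (reportedly because the ROCm FP8 format, e4m3fnuz, covers roughly half the range of OCP e4m3). A hedged sketch of that path, where layer is any object carrying the attributes shown:

import torch

def apply_checkpoint_scales(layer, k_scale: float, v_scale: float,
                            is_rocm: bool) -> None:
    if is_rocm:
        k_scale *= 2       # same doubling convention as the hunks above
        v_scale *= 2
    # In-place copies keep the tensors already handed to the kernels current.
    layer._k_scale.copy_(k_scale)
    layer._v_scale.copy_(v_scale)
    layer._k_scale_float = k_scale
    layer._v_scale_float = v_scale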
View File

@ -6,8 +6,7 @@ import json
import os
import tempfile
from collections import defaultdict
from typing import (Any, Callable, Dict, Generator, Iterable, List, Optional,
Tuple, Union)
from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Union
import filelock
import gguf
@ -23,7 +22,6 @@ from vllm.distributed import get_tensor_model_parallel_rank
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization import (QuantizationConfig,
get_quantization_config)
from vllm.model_executor.layers.quantization.schema import QuantParamSchema
from vllm.platforms import current_platform
from vllm.utils import PlaceholderModule
@ -496,47 +494,6 @@ def gguf_quant_weights_iterator(
yield name, param
def kv_cache_scales_loader(
filename: str, tp_rank: int, tp_size: int, num_hidden_layers: int,
model_type: Optional[str]) -> Iterable[Tuple[int, float]]:
"""
A simple utility to read in KV cache scaling factors that have been
previously serialized to disk. Used by the model to populate the appropriate
KV cache scaling factors. The serialization should represent a dictionary
whose keys are the TP ranks and values are another dictionary mapping layers
to their KV cache scaling factors.
Keep this function in sync with the output of
examples/other/fp8/extract_scales.py
"""
try:
with open(filename) as f:
context = {
"model_type": model_type,
"num_hidden_layers": num_hidden_layers,
"tp_rank": tp_rank,
"tp_size": tp_size,
}
schema_dct = json.load(f)
schema = QuantParamSchema.model_validate(schema_dct,
context=context)
layer_scales_map = schema.kv_cache.scaling_factor[tp_rank]
return layer_scales_map.items()
except FileNotFoundError:
logger.error("File or directory '%s' not found.", filename)
except json.JSONDecodeError:
logger.error("Error decoding JSON in file '%s'.", filename)
except Exception:
logger.exception("An error occurred while reading '%s'.", filename)
# This section is reached if and only if any of the excepts are hit
# Return an empty iterable (list) => no KV cache scales are loaded
# which ultimately defaults to 1.0 scales
logger.warning(
"Defaulting to KV cache scaling factors = 1.0 for all "
"layers in TP rank %d as an error occurred during loading.", tp_rank)
return []
def convert_pyslice_to_tensor(x: Any) -> torch.Tensor:
"""convert PySafeSlice object from safetensors to torch.Tensor

View File

@ -30,8 +30,7 @@ from torch import nn
from vllm.attention import Attention, AttentionMetadata
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size)
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
@ -44,9 +43,8 @@ from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader, kv_cache_scales_loader, maybe_remap_kv_scale_name)
default_weight_loader, maybe_remap_kv_scale_name)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.configs.exaone import ExaoneConfig
@ -576,32 +574,3 @@ class ExaoneForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params
# If this function is called, it should always initialize KV cache scale
# factors (or else raise an exception). Thus, handled exceptions should
# make sure to leave KV cache scale factors in a known good (dummy) state
def load_kv_cache_scales(self, quantization_param_path: str) -> None:
tp_size = get_tensor_model_parallel_world_size()
tp_rank = get_tensor_model_parallel_rank()
for layer_idx, scaling_factor in kv_cache_scales_loader(
quantization_param_path,
tp_rank,
tp_size,
self.config.num_hidden_layers,
self.config.__class__.model_type,
):
if not isinstance(self.transformer.h[layer_idx], nn.Identity):
layer_self_attn = self.transformer.h[layer_idx].attn
if current_platform.is_rocm():
# The scaling factor convention we are assuming is
# quantized_value * scaling_factor ~= true_value
# which is consistent with the practice of setting
# scaling_factor = tensor_amax / FPtype_max
scaling_factor *= 2
if hasattr(layer_self_attn.attn, "_k_scale"):
layer_self_attn.attn._k_scale = scaling_factor
layer_self_attn.attn._v_scale = scaling_factor
else:
raise RuntimeError("Self attention has no KV cache scaling "
"factor attribute!")

View File

@ -29,8 +29,7 @@ from transformers import GraniteConfig
from vllm.attention import Attention, AttentionMetadata
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size)
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
@ -44,9 +43,8 @@ from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader, kv_cache_scales_loader, maybe_remap_kv_scale_name)
default_weight_loader, maybe_remap_kv_scale_name)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors
from .interfaces import SupportsLoRA, SupportsPP
@ -518,29 +516,3 @@ class GraniteForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params
# If this function is called, it should always initialize KV cache scale
# factors (or else raise an exception). Thus, handled exceptions should
# make sure to leave KV cache scale factors in a known good (dummy) state
def load_kv_cache_scales(self, quantization_param_path: str) -> None:
tp_size = get_tensor_model_parallel_world_size()
tp_rank = get_tensor_model_parallel_rank()
for layer_idx, scaling_factor in kv_cache_scales_loader(
quantization_param_path, tp_rank, tp_size,
self.config.num_hidden_layers,
self.config.__class__.model_type):
if not isinstance(self.model.layers[layer_idx], nn.Identity):
layer_self_attn = self.model.layers[layer_idx].self_attn
if current_platform.is_rocm():
# The scaling factor convention we are assuming is
# quantized_value * scaling_factor ~= true_value
# which is consistent with the practice of setting
# scaling_factor = tensor_amax / FPtype_max
scaling_factor *= 2
if hasattr(layer_self_attn.attn, "_k_scale"):
layer_self_attn.attn._k_scale = scaling_factor
layer_self_attn.attn._v_scale = scaling_factor
else:
raise RuntimeError("Self attention has no KV cache scaling "
"factor attribute!")

View File

@ -29,8 +29,7 @@ from transformers import LlamaConfig
from vllm.attention import Attention, AttentionMetadata
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size)
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
@ -43,9 +42,8 @@ from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader, kv_cache_scales_loader, maybe_remap_kv_scale_name)
default_weight_loader, maybe_remap_kv_scale_name)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors
from .interfaces import SupportsLoRA, SupportsPP
@ -440,32 +438,6 @@ class LlamaModel(nn.Module):
loaded_params.add(name)
return loaded_params
# If this function is called, it should always initialize KV cache scale
# factors (or else raise an exception). Thus, handled exceptions should
# make sure to leave KV cache scale factors in a known good (dummy) state
def load_kv_cache_scales(self, quantization_param_path: str) -> None:
tp_size = get_tensor_model_parallel_world_size()
tp_rank = get_tensor_model_parallel_rank()
for layer_idx, scaling_factor in kv_cache_scales_loader(
quantization_param_path, tp_rank, tp_size,
self.config.num_hidden_layers,
self.config.__class__.model_type):
if not isinstance(self.layers[layer_idx], nn.Identity):
layer_self_attn = self.layers[layer_idx].self_attn
if current_platform.is_rocm():
# The scaling factor convention we are assuming is
# quantized_value * scaling_factor ~= true_value
# which is consistent with the practice of setting
# scaling_factor = tensor_amax / FPtype_max
scaling_factor *= 2
if hasattr(layer_self_attn.attn, "_k_scale"):
layer_self_attn.attn._k_scale = scaling_factor
layer_self_attn.attn._v_scale = scaling_factor
else:
raise RuntimeError("Self attention has no KV cache scaling "
"factor attribute!")
class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
packed_modules_mapping = {
@ -593,9 +565,6 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
self.maybe_remap_mistral(name, loaded_weight)
for name, loaded_weight in weights)
def load_kv_cache_scales(self, quantization_param_path: str) -> None:
self.model.load_kv_cache_scales(quantization_param_path)
# This function is used to remap the mistral format as
# used by Mistral and Llama <=2
def maybe_remap_mistral(

View File

@ -831,6 +831,7 @@ class MllamaTextCrossAttention(nn.Module):
) -> torch.Tensor:
# Skip writing kv-cache for the initial profiling run.
if len(kv_cache.shape) > 1:
i = torch.ones(1, dtype=torch.float32)
if self.attn.backend in (_Backend.FLASH_ATTN,
_Backend.FLASH_ATTN_VLLM_V1):
cached_k = torch.cat([k[s:e] for s, e in kv_range_for_decode])
@ -843,8 +844,8 @@ class MllamaTextCrossAttention(nn.Module):
attn_metadata.
cross_slot_mapping, # type: ignore[union-attr]
"auto",
1.0,
1.0,
i,
i,
)
elif self.attn.backend in (_Backend.XFORMERS, _Backend.TORCH_SDPA):
key_cache, value_cache = PagedAttention.split_kv_cache(
@ -853,7 +854,7 @@ class MllamaTextCrossAttention(nn.Module):
cached_v = torch.cat([v[s:e] for s, e in kv_range_for_decode])
PagedAttention.write_to_paged_cache(
cached_k, cached_v, key_cache, value_cache,
attn_metadata.cross_slot_mapping, "auto", 1.0, 1.0)
attn_metadata.cross_slot_mapping, "auto", i, i)
else:
raise ValueError(
f"Unsupported Attention backend {self.attn.backend} "

View File

@ -30,8 +30,7 @@ from transformers import PretrainedConfig
from vllm.attention import Attention, AttentionMetadata
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size)
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
@ -44,9 +43,8 @@ from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader, kv_cache_scales_loader, maybe_remap_kv_scale_name)
default_weight_loader, maybe_remap_kv_scale_name)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors
from .interfaces import SupportsLoRA, SupportsPP
@ -535,32 +533,3 @@ class SolarForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params
# If this function is called, it should always initialize KV cache scale
# factors (or else raise an exception). Thus, handled exceptions should
# make sure to leave KV cache scale factors in a known good (dummy) state
def load_kv_cache_scales(self, quantization_param_path: str) -> None:
tp_size = get_tensor_model_parallel_world_size()
tp_rank = get_tensor_model_parallel_rank()
for layer_idx, scaling_factor in kv_cache_scales_loader(
quantization_param_path,
tp_rank,
tp_size,
self.config.num_hidden_layers,
self.config.__class__.model_type,
):
if not isinstance(self.model.layers[layer_idx], nn.Identity):
layer_self_attn = self.model.layers[layer_idx].self_attn
if current_platform.is_rocm():
# The scaling factor convention we are assuming is
# quantized_value * scaling_factor ~= true_value
# which is consistent with the practice of setting
# scaling_factor = tensor_amax / FPtype_max
scaling_factor *= 2
if hasattr(layer_self_attn.attn, "_k_scale"):
layer_self_attn.attn._k_scale = scaling_factor
layer_self_attn.attn._v_scale = scaling_factor
else:
raise RuntimeError("Self attention has no KV cache scaling "
"factor attribute!")

View File

@ -166,10 +166,6 @@ class FlashAttentionImpl(AttentionImpl):
Returns:
shape = [num_tokens, num_heads * head_size]
"""
# NOTE(woosuk): FlashAttention does not support FP8 KV cache.
assert layer._k_scale == 1.0 and layer._v_scale == 1.0, (
"key/v_scale is not supported in FlashAttention.")
assert output is not None, "Output tensor must be provided."
if attn_metadata is None:

View File

@ -903,7 +903,8 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]):
num_decode_tokens=0,
slot_mapping=slot_mapping,
multi_modal_placeholder_index_maps=
None # FIXME(kzawora): mutli-modality will not work here
None, # FIXME(kzawora): mutli-modality will not work here
enable_kv_scales_calculation=False,
)
multi_modal_kwargs = MultiModalKwargs.batch(multi_modal_kwargs_list)
@ -1057,7 +1058,9 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]):
num_prefill_tokens=0,
num_decode_tokens=num_decode_tokens,
slot_mapping=slot_mapping,
multi_modal_placeholder_index_maps=None)
multi_modal_placeholder_index_maps=None,
enable_kv_scales_calculation=False,
)
return PrepareDecodeMetadata(input_tokens=input_tokens,
input_positions=input_positions,
attn_metadata=attn_metadata,

View File

@ -3,7 +3,6 @@ import gc
import inspect
import itertools
import time
import warnings
import weakref
from contextlib import contextmanager
from dataclasses import dataclass
@ -41,7 +40,6 @@ from vllm.model_executor.models.utils import set_cpu_offload_max_bytes
from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs,
MultiModalKwargs, MultiModalPlaceholderMap,
MultiModalRegistry)
from vllm.platforms import current_platform
from vllm.prompt_adapter.layers import PromptAdapterMapping
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.prompt_adapter.worker_manager import (
@ -1151,34 +1149,6 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
self.prompt_adapter_manager.create_prompt_adapter_manager(
self.model))
if self.kv_cache_dtype == "fp8" and (current_platform.is_rocm()
or current_platform.is_cuda()):
# Currently only ROCm accepts kv-cache scaling factors
# via quantization_param_path and this will be deprecated
# in the future.
if self.model_config.quantization_param_path is not None:
if callable(getattr(self.model, "load_kv_cache_scales", None)):
warnings.warn(
"Loading kv cache scaling factor from JSON is "
"deprecated and will be removed. Please include "
"kv cache scaling factors in the model checkpoint.",
FutureWarning,
stacklevel=2)
self.model.load_kv_cache_scales(
self.model_config.quantization_param_path)
logger.info("Loaded KV cache scaling factors from %s",
self.model_config.quantization_param_path)
else:
raise RuntimeError(
"Using FP8 KV cache and scaling factors provided but "
"model %s does not support loading scaling factors.",
self.model.__class__)
else:
logger.warning(
"Using FP8 KV cache but no scaling factors "
"provided. Defaulting to scaling factors of 1.0. "
"This may lead to less accurate results!")
if self.vllm_config.compilation_config.level ==\
CompilationLevel.DYNAMO_AS_IS and supports_dynamo():
backend = self.vllm_config.compilation_config.init_backend(
@ -1366,6 +1336,10 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
dtype=self.model_config.dtype,
device=self.device)
# Disable KV Scale Calculation for dummy data during profile run
if model_input.attn_metadata is not None:
model_input.attn_metadata.enable_kv_scales_calculation = False
self.execute_model(model_input, kv_caches, intermediate_tensors)
torch.cuda.synchronize()
return
@ -1510,7 +1484,8 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
batch_size,
is_encoder_decoder_model=self.model_config.
is_encoder_decoder))
# Disable KV Scale Calculation for graph capture
attn_metadata.enable_kv_scales_calculation = False
if self.lora_config:
lora_mapping = LoRAMapping(
**dict(index_mapping=[0] * batch_size,

View File

@ -282,6 +282,7 @@ class OpenVINOModelRunner(ModelRunnerBase):
block_indices_begins=block_indices_begins_tensor,
max_context_len=max_context_len_tensor,
multi_modal_placeholder_index_maps=placeholder_index_maps,
enable_kv_scales_calculation=False,
)
multi_modal_kwargs = MultiModalKwargs.batch(multi_modal_kwargs_list)

View File

@ -190,6 +190,7 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):
num_decode_tokens=0,
slot_mapping=slot_mapping,
multi_modal_placeholder_index_maps=None,
enable_kv_scales_calculation=False,
block_tables=None,
context_lens=None,
effective_query_lens=None,
@ -208,6 +209,7 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):
num_decode_tokens=0,
slot_mapping=slot_mapping,
multi_modal_placeholder_index_maps=None,
enable_kv_scales_calculation=False,
block_tables=block_tables,
context_lens=context_lens,
effective_query_lens=effective_query_lens,
@ -239,6 +241,7 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):
num_decode_tokens=batch_size * seq_len,
slot_mapping=slot_mapping,
multi_modal_placeholder_index_maps=None,
enable_kv_scales_calculation=False,
block_tables=block_tables,
context_lens=context_lens,
)
@ -425,6 +428,7 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):
num_decode_tokens=0,
slot_mapping=slot_mapping,
multi_modal_placeholder_index_maps=None,
enable_kv_scales_calculation=False,
block_tables=block_tables,
context_lens=context_lens,
effective_query_lens=prompt_lens,
@ -496,6 +500,7 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):
num_decode_tokens=batch_size,
slot_mapping=slot_mapping,
multi_modal_placeholder_index_maps=None,
enable_kv_scales_calculation=False,
block_tables=block_tables,
context_lens=context_lens,
)

View File

@ -261,6 +261,7 @@ class ModelInputForXPUBuilder(ModelRunnerInputBuilderBase[ModelInputForXPU]):
is_prompt=True,
slot_mapping=slot_mapping,
multi_modal_placeholder_index_maps=placeholder_index_maps,
enable_kv_scales_calculation=False,
seq_lens=seq_lens,
seqlen_q=seqlen_q,
max_seqlen=max_seqlen,
@ -345,6 +346,7 @@ class ModelInputForXPUBuilder(ModelRunnerInputBuilderBase[ModelInputForXPU]):
is_prompt=False,
slot_mapping=slot_mapping,
multi_modal_placeholder_index_maps=None,
enable_kv_scales_calculation=False,
seq_lens=seq_lens,
seqlen_q=torch.tensor([]),
max_seqlen=0,