Fix a bug in attention kernel (#68)

Woosuk Kwon 2023-05-04 02:56:09 -07:00 committed by GitHub
parent e070829ae8
commit 130d5fd8c7


@@ -345,7 +345,7 @@ void single_query_cached_kv_attention_launcher(
   constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
   int padded_max_context_len = ((max_context_len + BLOCK_SIZE - 1) / BLOCK_SIZE) * BLOCK_SIZE;
-  int logits_size = padded_max_context_len * sizeof(T);
+  int logits_size = padded_max_context_len * sizeof(float);
   int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float);
   int shared_mem_size = std::max(logits_size, outputs_size);
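The fix itself confirms the nature of the bug: the kernel's logits scratch buffer in shared memory holds float values regardless of the template type T, so sizing it with sizeof(T) under-allocates whenever T is a 16-bit type such as half, and the kernel can then write past the end of its shared memory region. Sizing with sizeof(float) matches what the kernel actually stores.

Below is a minimal standalone sketch of the sizing arithmetic to make the under-allocation concrete. It is not the launcher's actual configuration: BLOCK_SIZE, NUM_THREADS, max_context_len, and head_size are illustrative values, and short is used as a 2-byte stand-in for half so the example compiles without CUDA headers.

```cpp
#include <algorithm>
#include <cstdio>

int main() {
  constexpr int BLOCK_SIZE = 16;     // illustrative KV-cache block size
  constexpr int NUM_THREADS = 128;   // illustrative thread count
  constexpr int WARP_SIZE = 32;
  constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
  int max_context_len = 1000;        // hypothetical sequence length
  int head_size = 128;               // hypothetical head dimension

  // Round the context length up to a multiple of BLOCK_SIZE,
  // as in the launcher.
  int padded_max_context_len =
      ((max_context_len + BLOCK_SIZE - 1) / BLOCK_SIZE) * BLOCK_SIZE;

  // Buggy: sizes the buffer for a 16-bit T (short stands in for half),
  // half the bytes the kernel's float logits actually occupy.
  int buggy_logits_size = padded_max_context_len * sizeof(short);
  // Fixed: the logits buffer always holds floats.
  int fixed_logits_size = padded_max_context_len * sizeof(float);

  int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float);

  printf("buggy shared mem: %d bytes\n",
         std::max(buggy_logits_size, outputs_size));   // 2016
  printf("fixed shared mem: %d bytes\n",
         std::max(fixed_logits_size, outputs_size));   // 4032
}
```

With these values the buggy computation requests 2016 bytes of dynamic shared memory while the kernel writes 4032, so the fix doubles the request for half-precision inputs and removes the out-of-bounds access.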