Fix a bug in attention kernel (#68)
This commit is contained in:
parent
e070829ae8
commit
130d5fd8c7
@ -345,7 +345,7 @@ void single_query_cached_kv_attention_launcher(
|
||||
|
||||
constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
|
||||
int padded_max_context_len = ((max_context_len + BLOCK_SIZE - 1) / BLOCK_SIZE) * BLOCK_SIZE;
|
||||
int logits_size = padded_max_context_len * sizeof(T);
|
||||
int logits_size = padded_max_context_len * sizeof(float);
|
||||
int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float);
|
||||
int shared_mem_size = std::max(logits_size, outputs_size);
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user