[ROCm][V1] Update reshape_and_cache to properly work with CUDA graph padding (#13922)

This commit is contained in:
Sage Moore 2025-02-26 20:04:12 -08:00 committed by GitHub
parent c9944acbf9
commit 378b3ef6f8
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@@ -375,7 +375,7 @@ void reshape_and_cache(
torch::Tensor& slot_mapping, // [num_tokens]
const std::string& kv_cache_dtype, torch::Tensor& k_scale,
torch::Tensor& v_scale) {
-int num_tokens = key.size(0);
+int num_tokens = slot_mapping.size(0);
int num_heads = key.size(1);
int head_size = key.size(2);
int block_size = key_cache.size(3);