[Bugfix] Add kv_scale input parameter to CPU backend (#3840)

Woosuk Kwon 2024-04-03 21:33:08 -07:00 committed by GitHub
parent 537ee25f43
commit 498eb5cfa3
4 changed files with 12 additions and 5 deletions
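
In the GPU backends, kv_scale is the scaling factor used to dequantize a quantized (e.g. FP8) KV cache. The CPU backend keeps the cache in its original dtype, so this change only threads the argument through for signature parity and requires it to stay at the neutral default of 1.0. A minimal sketch of what the scale means, in plain PyTorch rather than vLLM code:

import torch

# kv_scale rescales cache entries back to the compute dtype on read.
# With the default kv_scale == 1.0 the rescale is a no-op, which is the only
# case the CPU kernels in this commit accept.
key = torch.randn(4, 8)
kv_scale = 1.0
assert torch.equal(key * kv_scale, key)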

View File

@@ -419,7 +419,8 @@ void paged_attention_v1(torch::Tensor &out, torch::Tensor &query,
                         torch::Tensor &context_lens, int block_size,
                         int max_context_len,
                         const c10::optional<torch::Tensor> &alibi_slopes,
-                        const std::string &kv_cache_dtype) {
+                        const std::string &kv_cache_dtype, float kv_scale) {
+  TORCH_CHECK(kv_scale == 1.0f);
   VLLM_DISPATCH_FLOATING_TYPES(query.scalar_type(), "paged_attention_v1_impl",
                                [&] {
                                  CPU_KERNEL_GUARD_IN(paged_attention_v1_impl)

@@ -734,7 +735,8 @@ void paged_attention_v2(torch::Tensor &out, torch::Tensor &exp_sums,
                         torch::Tensor &context_lens, int block_size,
                         int max_context_len,
                         const c10::optional<torch::Tensor> &alibi_slopes,
-                        const std::string &kv_cache_dtype) {
+                        const std::string &kv_cache_dtype, float kv_scale) {
+  TORCH_CHECK(kv_scale == 1.0f);
   VLLM_DISPATCH_FLOATING_TYPES(query.scalar_type(), "paged_attention_v2_impl",
                                [&] {
                                  CPU_KERNEL_GUARD_IN(paged_attention_v2_impl)
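
Both CPU kernels accept the new argument only to match the CUDA kernel signatures; the TORCH_CHECK makes any non-default scale fail loudly (it surfaces as a RuntimeError on the Python side) instead of being silently ignored. A Python analogue of that guard, illustrative only and not a vLLM API:

def check_cpu_kv_scale(kv_scale: float) -> None:
    # Mirrors TORCH_CHECK(kv_scale == 1.0f) in paged_attention_v1/v2 above:
    # the CPU path has no scaled KV-cache support.
    if kv_scale != 1.0:
        raise RuntimeError("CPU paged attention only supports kv_scale == 1.0")

check_cpu_kv_scale(1.0)    # default, unscaled path: passes
# check_cpu_kv_scale(0.5)  # would raise, just as the C++ check does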

View File

@@ -111,7 +111,9 @@ void copy_blocks(std::vector<torch::Tensor> &key_caches,
 void reshape_and_cache(torch::Tensor &key, torch::Tensor &value,
                        torch::Tensor &key_cache, torch::Tensor &value_cache,
                        torch::Tensor &slot_mapping,
-                       const std::string &kv_cache_dtype) {
+                       const std::string &kv_cache_dtype, float kv_scale) {
+  TORCH_CHECK(kv_scale == 1.0f);
   int num_tokens = key.size(0);
   int num_heads = key.size(1);
   int head_size = key.size(2);
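
reshape_and_cache is the write path into the paged KV cache, so it receives the same guard: a scaled write (quantizing before storing) is just as unsupported on CPU as a scaled read. A simplified stand-alone sketch of an unscaled slot write; the helper name is hypothetical, not the actual kernel:

import torch

def write_slot_unscaled(token_kv: torch.Tensor,
                        cache_slot: torch.Tensor,
                        kv_scale: float = 1.0) -> None:
    # The CPU cache keeps keys/values in their original dtype, so only the
    # neutral scale is meaningful; anything else is rejected up front.
    if kv_scale != 1.0:
        raise RuntimeError("CPU cache write supports only kv_scale == 1.0")
    cache_slot.copy_(token_kv)

slot = torch.empty(8)
write_slot_unscaled(torch.randn(8), slot)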

View File

@@ -114,6 +114,7 @@ class TorchSDPABackendImpl(AttentionImpl):
         value: torch.Tensor,
         kv_cache: Optional[torch.Tensor],
         attn_metadata: TorchSDPAMetadata,
+        kv_scale: float,
     ) -> torch.Tensor:
         """Forward pass with torch SDPA and PagedAttention.

@@ -138,7 +139,8 @@ class TorchSDPABackendImpl(AttentionImpl):
             PagedAttention.write_to_paged_cache(key, value, key_cache,
                                                 value_cache,
                                                 attn_metadata.slot_mapping,
-                                                attn_metadata.kv_cache_dtype)
+                                                attn_metadata.kv_cache_dtype,
+                                                kv_scale)
         if attn_metadata.is_prompt:
             if (kv_cache is None or attn_metadata.block_tables.numel() == 0):

@@ -199,6 +201,7 @@ class TorchSDPABackendImpl(AttentionImpl):
                 self.num_kv_heads,
                 self.scale,
                 self.alibi_slopes,
+                kv_scale,
             )
         # Reshape the output tensor.
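
On the Python side the parameter is simply threaded through: forward() now takes kv_scale and hands it both to the cache write (PagedAttention.write_to_paged_cache) and to the paged-attention decode call. A compact mock of that flow; the stub functions below stand in for the real vLLM calls and are not its API:

def write_to_paged_cache_stub(key, value, slot_mapping, kv_cache_dtype, kv_scale):
    # Stand-in for PagedAttention.write_to_paged_cache; the CPU op behind it
    # asserts kv_scale == 1.0 (see the reshape_and_cache diff above).
    assert kv_scale == 1.0

def forward_decode_stub(query, kv_scale):
    # Stand-in for the paged-attention decode call; same constraint on CPU.
    assert kv_scale == 1.0
    return query

def forward(query, key, value, kv_cache, attn_metadata, kv_scale: float):
    # Mirrors the diff: the scale rides along with every call that touches
    # the KV cache.
    write_to_paged_cache_stub(key, value, None, "auto", kv_scale)
    return forward_decode_stub(query, kv_scale)

forward(query=[1.0], key=[2.0], value=[3.0], kv_cache=None,
        attn_metadata=None, kv_scale=1.0)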

View File

@@ -97,7 +97,7 @@ class PagedAttention:
         num_kv_heads: int,
         scale: float,
         alibi_slopes: Optional[torch.Tensor],
-        kv_scale,
+        kv_scale: float,
     ) -> torch.Tensor:
         output = torch.empty_like(query)