[Bugfix] Add kv_scale input parameter to CPU backend (#3840)
parent 537ee25f43
commit 498eb5cfa3
@@ -419,7 +419,8 @@ void paged_attention_v1(torch::Tensor &out, torch::Tensor &query,
                         torch::Tensor &context_lens, int block_size,
                         int max_context_len,
                         const c10::optional<torch::Tensor> &alibi_slopes,
-                        const std::string &kv_cache_dtype) {
+                        const std::string &kv_cache_dtype, float kv_scale) {
+  TORCH_CHECK(kv_scale == 1.0f);
   VLLM_DISPATCH_FLOATING_TYPES(query.scalar_type(), "paged_attention_v1_impl",
                                [&] {
                                  CPU_KERNEL_GUARD_IN(paged_attention_v1_impl)
@@ -734,7 +735,8 @@ void paged_attention_v2(torch::Tensor &out, torch::Tensor &exp_sums,
                         torch::Tensor &context_lens, int block_size,
                         int max_context_len,
                         const c10::optional<torch::Tensor> &alibi_slopes,
-                        const std::string &kv_cache_dtype) {
+                        const std::string &kv_cache_dtype, float kv_scale) {
+  TORCH_CHECK(kv_scale == 1.0f);
   VLLM_DISPATCH_FLOATING_TYPES(query.scalar_type(), "paged_attention_v2_impl",
                                [&] {
                                  CPU_KERNEL_GUARD_IN(paged_attention_v2_impl)
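Both CPU kernels accept the new kv_scale argument only for signature parity with the GPU ops: the CPU backend has no scaled (e.g. fp8) KV-cache support yet, so any value other than 1.0 is rejected up front by TORCH_CHECK rather than being silently ignored. A minimal Python sketch of that contract (the dequantize_kv helper and the fp8 framing are illustrative assumptions, not part of this diff):

import torch

def dequantize_kv(kv_cache: torch.Tensor, kv_scale: float) -> torch.Tensor:
    # Illustrative: on backends with a quantized KV cache, kv_scale rescales
    # cached values back to the compute dtype when they are read.
    return kv_cache.to(torch.float32) * kv_scale

def cpu_paged_attention(query: torch.Tensor, kv_cache: torch.Tensor,
                        kv_scale: float = 1.0) -> torch.Tensor:
    # Mirrors TORCH_CHECK(kv_scale == 1.0f): the CPU kernels only handle an
    # unscaled cache, so a non-unit scale is an error rather than a no-op.
    assert kv_scale == 1.0, "scaled KV cache is not supported on CPU"
    return query  # the actual attention computation is omitted here

q = torch.randn(1, 8, 64)
kv = torch.randn(16, 8, 64)
cpu_paged_attention(q, kv)                   # fine
# cpu_paged_attention(q, kv, kv_scale=0.5)   # fails, like the TORCH_CHECK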
@@ -111,7 +111,9 @@ void copy_blocks(std::vector<torch::Tensor> &key_caches,
 void reshape_and_cache(torch::Tensor &key, torch::Tensor &value,
                        torch::Tensor &key_cache, torch::Tensor &value_cache,
                        torch::Tensor &slot_mapping,
-                       const std::string &kv_cache_dtype) {
+                       const std::string &kv_cache_dtype, float kv_scale) {
+  TORCH_CHECK(kv_scale == 1.0f);
+
   int num_tokens = key.size(0);
   int num_heads = key.size(1);
   int head_size = key.size(2);
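The cache-write op gets the same treatment: it now takes kv_scale for parity, but on CPU the incoming key/value tensors are still stored as-is, so only kv_scale == 1.0 is accepted. A rough sketch of the CPU behaviour, using a flat slot-indexed cache instead of the real blocked layout (the simplified layout and toy shapes are assumptions for illustration):

import torch

def reshape_and_cache_cpu(key: torch.Tensor, value: torch.Tensor,
                          key_cache: torch.Tensor, value_cache: torch.Tensor,
                          slot_mapping: torch.Tensor, kv_cache_dtype: str,
                          kv_scale: float) -> None:
    # After this change the CPU op accepts kv_scale, but it must be 1.0 and
    # the key/value entries are copied into the cache without any rescaling.
    assert kv_scale == 1.0
    key_cache[slot_mapping] = key
    value_cache[slot_mapping] = value

# Toy shapes: 4 tokens, 2 heads, head_size 8, a cache with 16 slots.
key = torch.randn(4, 2, 8)
value = torch.randn(4, 2, 8)
key_cache = torch.zeros(16, 2, 8)
value_cache = torch.zeros(16, 2, 8)
slot_mapping = torch.tensor([3, 7, 8, 9])
reshape_and_cache_cpu(key, value, key_cache, value_cache, slot_mapping,
                      "auto", 1.0)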
@@ -114,6 +114,7 @@ class TorchSDPABackendImpl(AttentionImpl):
         value: torch.Tensor,
         kv_cache: Optional[torch.Tensor],
         attn_metadata: TorchSDPAMetadata,
+        kv_scale: float,
     ) -> torch.Tensor:
         """Forward pass with torch SDPA and PagedAttention.
@@ -138,7 +139,8 @@ class TorchSDPABackendImpl(AttentionImpl):
             PagedAttention.write_to_paged_cache(key, value, key_cache,
                                                 value_cache,
                                                 attn_metadata.slot_mapping,
-                                                attn_metadata.kv_cache_dtype)
+                                                attn_metadata.kv_cache_dtype,
+                                                kv_scale)

         if attn_metadata.is_prompt:
             if (kv_cache is None or attn_metadata.block_tables.numel() == 0):
@@ -199,6 +201,7 @@ class TorchSDPABackendImpl(AttentionImpl):
                 self.num_kv_heads,
                 self.scale,
                 self.alibi_slopes,
+                kv_scale,
             )

         # Reshape the output tensor.
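On the Python side the change is pure plumbing: TorchSDPABackendImpl.forward now takes kv_scale and forwards it, unmodified, to the paged-cache write above and to the decode call here. A condensed, self-contained sketch of that flow, with stand-in functions in place of the real PagedAttention ops (their reduced signatures are assumptions for illustration):

from typing import Optional, Tuple
import torch

def write_to_paged_cache(key, value, key_cache, value_cache, slot_mapping,
                         kv_cache_dtype: str, kv_scale: float) -> None:
    """Stand-in for PagedAttention.write_to_paged_cache."""

def forward_decode(query, key_cache, value_cache,
                   kv_scale: float) -> torch.Tensor:
    """Stand-in for the paged-attention decode path."""
    return torch.empty_like(query)

def forward(query, key, value,
            kv_cache: Optional[Tuple[torch.Tensor, torch.Tensor]],
            slot_mapping, kv_cache_dtype: str,
            kv_scale: float) -> torch.Tensor:
    # The backend never interprets kv_scale itself; it only threads the value
    # through to the cache write and then to the decode kernel.
    if kv_cache is not None:
        key_cache, value_cache = kv_cache
        write_to_paged_cache(key, value, key_cache, value_cache,
                             slot_mapping, kv_cache_dtype, kv_scale)
        return forward_decode(query, key_cache, value_cache, kv_scale)
    return query  # prompt / no-cache path omitted in this sketch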
@@ -97,7 +97,7 @@ class PagedAttention:
         num_kv_heads: int,
         scale: float,
         alibi_slopes: Optional[torch.Tensor],
-        kv_scale,
+        kv_scale: float,
     ) -> torch.Tensor:
         output = torch.empty_like(query)