[Bugfix] Add kv_scale input parameter to CPU backend (#3840)
This commit is contained in:
parent
537ee25f43
commit
498eb5cfa3
@ -419,7 +419,8 @@ void paged_attention_v1(torch::Tensor &out, torch::Tensor &query,
|
|||||||
torch::Tensor &context_lens, int block_size,
|
torch::Tensor &context_lens, int block_size,
|
||||||
int max_context_len,
|
int max_context_len,
|
||||||
const c10::optional<torch::Tensor> &alibi_slopes,
|
const c10::optional<torch::Tensor> &alibi_slopes,
|
||||||
const std::string &kv_cache_dtype) {
|
const std::string &kv_cache_dtype, float kv_scale) {
|
||||||
|
TORCH_CHECK(kv_scale == 1.0f);
|
||||||
VLLM_DISPATCH_FLOATING_TYPES(query.scalar_type(), "paged_attention_v1_impl",
|
VLLM_DISPATCH_FLOATING_TYPES(query.scalar_type(), "paged_attention_v1_impl",
|
||||||
[&] {
|
[&] {
|
||||||
CPU_KERNEL_GUARD_IN(paged_attention_v1_impl)
|
CPU_KERNEL_GUARD_IN(paged_attention_v1_impl)
|
||||||
@ -734,7 +735,8 @@ void paged_attention_v2(torch::Tensor &out, torch::Tensor &exp_sums,
|
|||||||
torch::Tensor &context_lens, int block_size,
|
torch::Tensor &context_lens, int block_size,
|
||||||
int max_context_len,
|
int max_context_len,
|
||||||
const c10::optional<torch::Tensor> &alibi_slopes,
|
const c10::optional<torch::Tensor> &alibi_slopes,
|
||||||
const std::string &kv_cache_dtype) {
|
const std::string &kv_cache_dtype, float kv_scale) {
|
||||||
|
TORCH_CHECK(kv_scale == 1.0f);
|
||||||
VLLM_DISPATCH_FLOATING_TYPES(query.scalar_type(), "paged_attention_v2_impl",
|
VLLM_DISPATCH_FLOATING_TYPES(query.scalar_type(), "paged_attention_v2_impl",
|
||||||
[&] {
|
[&] {
|
||||||
CPU_KERNEL_GUARD_IN(paged_attention_v2_impl)
|
CPU_KERNEL_GUARD_IN(paged_attention_v2_impl)
|
||||||
|
@ -111,7 +111,9 @@ void copy_blocks(std::vector<torch::Tensor> &key_caches,
|
|||||||
void reshape_and_cache(torch::Tensor &key, torch::Tensor &value,
|
void reshape_and_cache(torch::Tensor &key, torch::Tensor &value,
|
||||||
torch::Tensor &key_cache, torch::Tensor &value_cache,
|
torch::Tensor &key_cache, torch::Tensor &value_cache,
|
||||||
torch::Tensor &slot_mapping,
|
torch::Tensor &slot_mapping,
|
||||||
const std::string &kv_cache_dtype) {
|
const std::string &kv_cache_dtype, float kv_scale) {
|
||||||
|
TORCH_CHECK(kv_scale == 1.0f);
|
||||||
|
|
||||||
int num_tokens = key.size(0);
|
int num_tokens = key.size(0);
|
||||||
int num_heads = key.size(1);
|
int num_heads = key.size(1);
|
||||||
int head_size = key.size(2);
|
int head_size = key.size(2);
|
||||||
|
@ -114,6 +114,7 @@ class TorchSDPABackendImpl(AttentionImpl):
|
|||||||
value: torch.Tensor,
|
value: torch.Tensor,
|
||||||
kv_cache: Optional[torch.Tensor],
|
kv_cache: Optional[torch.Tensor],
|
||||||
attn_metadata: TorchSDPAMetadata,
|
attn_metadata: TorchSDPAMetadata,
|
||||||
|
kv_scale: float,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
"""Forward pass with torch SDPA and PagedAttention.
|
"""Forward pass with torch SDPA and PagedAttention.
|
||||||
|
|
||||||
@ -138,7 +139,8 @@ class TorchSDPABackendImpl(AttentionImpl):
|
|||||||
PagedAttention.write_to_paged_cache(key, value, key_cache,
|
PagedAttention.write_to_paged_cache(key, value, key_cache,
|
||||||
value_cache,
|
value_cache,
|
||||||
attn_metadata.slot_mapping,
|
attn_metadata.slot_mapping,
|
||||||
attn_metadata.kv_cache_dtype)
|
attn_metadata.kv_cache_dtype,
|
||||||
|
kv_scale)
|
||||||
|
|
||||||
if attn_metadata.is_prompt:
|
if attn_metadata.is_prompt:
|
||||||
if (kv_cache is None or attn_metadata.block_tables.numel() == 0):
|
if (kv_cache is None or attn_metadata.block_tables.numel() == 0):
|
||||||
@ -199,6 +201,7 @@ class TorchSDPABackendImpl(AttentionImpl):
|
|||||||
self.num_kv_heads,
|
self.num_kv_heads,
|
||||||
self.scale,
|
self.scale,
|
||||||
self.alibi_slopes,
|
self.alibi_slopes,
|
||||||
|
kv_scale,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Reshape the output tensor.
|
# Reshape the output tensor.
|
||||||
|
@ -97,7 +97,7 @@ class PagedAttention:
|
|||||||
num_kv_heads: int,
|
num_kv_heads: int,
|
||||||
scale: float,
|
scale: float,
|
||||||
alibi_slopes: Optional[torch.Tensor],
|
alibi_slopes: Optional[torch.Tensor],
|
||||||
kv_scale,
|
kv_scale: float,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
output = torch.empty_like(query)
|
output = torch.empty_like(query)
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user