[Feature] add model aware kv ops helper (#16020)

Signed-off-by: billishyahao <bill.he@amd.com>
Authored by billishyahao on 2025-04-16 14:00:43 +08:00, committed by GitHub
parent 966c742ed2
commit 3ac98edcb1
3 changed files with 121 additions and 99 deletions

vllm/distributed/kv_transfer/kv_connector/mooncake_store_connector.py

@@ -1,7 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 """
 MooncakeStore Connector for Distributed Machine Learning Inference
-
 The MooncakeStoreConnector transfers KV caches between prefill vLLM workers
 (KV cache producer) and decode vLLM workers (KV cache consumer) using a
 database-style KVStore.
@@ -11,9 +10,10 @@ from typing import TYPE_CHECKING, List, Tuple, Union

 import torch

-from vllm import _custom_ops as ops
 from vllm.config import VllmConfig
 from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBase
+from vllm.distributed.kv_transfer.kv_connector.utils import (
+    model_aware_kv_ops_helper as kv_helper)
 from vllm.logger import init_logger
 from vllm.sequence import IntermediateTensors
@@ -32,8 +32,7 @@ class MooncakeStoreConnector(KVConnectorBase):
         config: VllmConfig,
     ):
         self.config = config.kv_transfer_config
-        self.tp_size = config.parallel_config.tensor_parallel_size
+        self.kv_helper = kv_helper(config)
         self.local_tp_rank = local_rank

         # Init kv_store
@@ -80,12 +79,7 @@ class MooncakeStoreConnector(KVConnectorBase):
         slot_mapping_flat = model_input.attn_metadata.slot_mapping.flatten()
         start_layer = model_executable.model.start_layer
         end_layer = model_executable.model.end_layer
-
-        model_config = model_executable.model.config
-        num_heads = int(model_config.num_key_value_heads / self.tp_size)
-        hidden_size = model_config.hidden_size
-        num_attention_heads = model_config.num_attention_heads
-        head_size = int(hidden_size / num_attention_heads)
+        num_heads, head_size = self.kv_helper.get_model_args(model_executable)

         for idx, slen in enumerate(seq_lens):
             start_pos = sum(seq_lens[:idx])
@@ -97,10 +91,8 @@
             for layer_id in range(start_layer, end_layer):
                 kv_cache = kv_caches[layer_id - start_layer]
-
-                key_cache = kv_cache[0].reshape(-1, num_heads, head_size)
-                value_cache = kv_cache[1].reshape(-1, num_heads, head_size)
+                key_cache, value_cache = self.kv_helper.get_kv_from_cache(
+                    kv_cache, num_heads, head_size)

                 current_slot_mapping = slot_mapping_flat[start_pos:end_pos]
                 keys.append(key_cache[current_slot_mapping].unsqueeze(0))
@@ -173,22 +165,15 @@
                 layer = model_executable.model.layers[layer_id]
                 # get kvcache object
                 kv_cache = kv_caches[layer_id - start_layer]
-                key_cache, value_cache = kv_cache[0], kv_cache[1]

                 # get remote kvcache
                 remote_k, remote_v = remote_kv[0][layer_id], remote_kv[1][
                     layer_id]
-                # use ops.reshape_and_cache_flash to put kv into kvcache
-                ops.reshape_and_cache_flash(
-                    remote_k.to(key_cache.device),
-                    remote_v.to(value_cache.device),
-                    key_cache,
-                    value_cache,
-                    slot_mapping[start_pos:end_pos],
-                    layer.self_attn.attn.kv_cache_dtype,
-                    layer.self_attn.attn._k_scale,
-                    layer.self_attn.attn._v_scale,
-                )
+
+                self.kv_helper.put_kv_to_cache(model_executable, remote_k,
+                                               remote_v, layer, kv_cache,
+                                               slot_mapping, start_pos,
+                                               end_pos)

             hidden_or_intermediate_states_for_one_req.append(hidden)

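Note: the net effect of this change on the connector side is that the tensor-parallel head math and the kv_cache reshaping no longer live in MooncakeStoreConnector itself; the connector builds one helper from the VllmConfig and delegates to it. Below is a minimal sketch of the resulting sender-side pattern. The wrapper function and its argument names are illustrative only and not part of this commit; the helper calls are the ones introduced above.

import torch
from typing import List

from vllm.config import VllmConfig
from vllm.distributed.kv_transfer.kv_connector.utils import (
    model_aware_kv_ops_helper as kv_helper)


def extract_kv_for_send(config: VllmConfig,
                        model_executable: torch.nn.Module,
                        kv_caches: List[torch.Tensor],
                        slot_mapping_flat: torch.Tensor,
                        start_pos: int, end_pos: int):
    """Sender-side pattern after this refactor (sketch only)."""
    helper = kv_helper(config)  # the real connectors build this in __init__
    num_heads, head_size = helper.get_model_args(model_executable)

    start_layer = model_executable.model.start_layer
    end_layer = model_executable.model.end_layer

    keys, values = [], []
    for layer_id in range(start_layer, end_layer):
        kv_cache = kv_caches[layer_id - start_layer]
        # One call covers both the standard [2, ...] cache layout and the
        # single-tensor DeepSeek-MLA layout.
        key_cache, value_cache = helper.get_kv_from_cache(
            kv_cache, num_heads, head_size)
        current_slot_mapping = slot_mapping_flat[start_pos:end_pos]
        keys.append(key_cache[current_slot_mapping].unsqueeze(0))
        values.append(value_cache[current_slot_mapping].unsqueeze(0))

    return torch.cat(keys, dim=0), torch.cat(values, dim=0)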
vllm/distributed/kv_transfer/kv_connector/simple_connector.py

@@ -12,10 +12,10 @@ from typing import TYPE_CHECKING, List, Optional, Tuple, Union

 import torch

-import vllm.envs as envs
-from vllm import _custom_ops as ops
 from vllm.config import VllmConfig
 from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBase
+from vllm.distributed.kv_transfer.kv_connector.utils import (
+    model_aware_kv_ops_helper as kv_helper)
 from vllm.distributed.kv_transfer.kv_lookup_buffer.simple_buffer import (
     SimpleBuffer)
 from vllm.logger import init_logger
@@ -37,9 +37,7 @@ class SimpleConnector(KVConnectorBase):
     ):
         self.config = config.kv_transfer_config
-        self.tp_size = config.parallel_config.tensor_parallel_size
-        self.is_deepseek_mla = config.model_config.is_deepseek_mla
-        self.use_mla_opt = not envs.VLLM_MLA_DISABLE
+        self.kv_helper = kv_helper(config)

         if self.config.kv_connector == "PyNcclConnector":
             from vllm.distributed.kv_transfer.kv_pipe.pynccl_pipe import (
@@ -165,31 +163,7 @@
         num_prefill_tokens = model_input.attn_metadata.num_prefill_tokens
         start_layer = model_executable.model.start_layer
         end_layer = model_executable.model.end_layer
-
-        model_config = model_executable.model.config
-        num_heads = int(model_config.num_key_value_heads / self.tp_size)
-        hidden_size = model_config.hidden_size
-        num_attention_heads = model_config.num_attention_heads
-
-        # Deepseek's MLA (Multi-head Latent Attention) uses two different
-        # kv_cache shapes based on whether VLLM_MLA_DISABLE is set to 0.
-        # When VLLM_MLA_DISABLE=0 (default), forward absorb is applied,
-        # resulting in a kv_cache shape of [num_blks, blk_size, 1,
-        # kv_lora_rank + qk_rope_head_dim].
-        # When VLLM_MLA_DISABLE=1, standard FA is used instead, leading
-        # to a kv_cache shape of [2, num_blks, blk_size,
-        # num_key_value_heads / tp, qk_nope_head_dim + qk_rope_head_dim].
-        # For more details, see vllm/attention/backends/mla/common.py.
-        if self.is_deepseek_mla and self.use_mla_opt:
-            head_size = model_config.kv_lora_rank + \
-                model_config.qk_rope_head_dim
-            num_heads = 1
-        elif self.is_deepseek_mla and not self.use_mla_opt:
-            head_size = model_config.qk_nope_head_dim + \
-                model_config.qk_rope_head_dim
-        else:
-            head_size = getattr(model_config, "head_dim",
-                                int(hidden_size // num_attention_heads))
+        num_heads, head_size = self.kv_helper.get_model_args(model_executable)

         # query_lens contains new KV caches that are added to vLLM.
         # so we will send them to decode instance
@@ -212,13 +186,8 @@
             for layer_id in range(start_layer, end_layer):
                 kv_cache = kv_caches[layer_id - start_layer]
-
-                if self.is_deepseek_mla and self.use_mla_opt:
-                    key_cache = kv_cache.reshape(-1, num_heads, head_size)
-                    value_cache = kv_cache.reshape(-1, num_heads, head_size)
-                else:
-                    key_cache = kv_cache[0].reshape(-1, num_heads, head_size)
-                    value_cache = kv_cache[1].reshape(-1, num_heads, head_size)
+                key_cache, value_cache = self.kv_helper.get_kv_from_cache(
+                    kv_cache, num_heads, head_size)

                 current_slot_mapping = slot_mapping_flat[start_pos:end_pos]
@@ -248,12 +217,12 @@
         # and hidden states.
         bypass_model_exec = True

-        model_config = model_executable.model.config
-
         input_tokens_tensor = model_input.input_tokens
         seq_lens = model_input.attn_metadata.seq_lens
         num_prefill_tokens = model_input.attn_metadata.num_prefill_tokens
         slot_mapping = model_input.attn_metadata.slot_mapping.flatten()
+        start_layer = model_executable.model.start_layer
+        end_layer = model_executable.model.end_layer

         hidden_or_intermediate_states_for_one_req = []
@@ -312,41 +281,19 @@
             end_pos = start_pos + num_computed_tokens

             # put received KV caches into paged memory
-            for i in range(model_executable.model.start_layer,
-                           model_executable.model.end_layer):
-
-                kv_cache = kv_caches[i - model_executable.model.start_layer]
-                layer = model_executable.model.layers[i]
-
-                if self.is_deepseek_mla and self.use_mla_opt:
-                    layer.self_attn.attn = layer.self_attn.mla_attn
-                    k_c_normed_k_pe = keys[
-                        i - model_executable.model.start_layer].to(
-                            kv_cache.device).squeeze(1)
-                    k_c_normed = k_c_normed_k_pe[:, :model_config.kv_lora_rank]
-                    k_pe = k_c_normed_k_pe[:, model_config.kv_lora_rank:]
-                    ops.concat_and_cache_mla(
-                        k_c_normed,
-                        k_pe,
-                        kv_cache,
-                        slot_mapping[start_pos:end_pos],
-                        layer.self_attn.attn.kv_cache_dtype,
-                        layer.self_attn.attn._k_scale,
-                    )
-                else:
-                    key_cache, value_cache = kv_cache[0], kv_cache[1]
-                    ops.reshape_and_cache_flash(
-                        keys[i - model_executable.model.start_layer].to(
-                            key_cache.device),
-                        values[i - model_executable.model.start_layer].to(
-                            value_cache.device),
-                        key_cache,
-                        value_cache,
-                        slot_mapping[start_pos:end_pos],
-                        layer.self_attn.attn.kv_cache_dtype,
-                        layer.self_attn.attn._k_scale,
-                        layer.self_attn.attn._v_scale,
-                    )
+            for cur_layer in range(start_layer, end_layer):
+
+                layer_id = cur_layer - start_layer
+                kv_cache = kv_caches[layer_id]
+                layer = model_executable.model.layers[cur_layer]
+
+                # get remote kvcache
+                remote_k, remote_v = keys[layer_id], values[layer_id]
+
+                self.kv_helper.put_kv_to_cache(model_executable, remote_k,
+                                               remote_v, layer, kv_cache,
+                                               slot_mapping, start_pos,
+                                               end_pos)

             hidden_or_intermediate_states_for_one_req.append(hidden)

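Note: all of the DeepSeek-MLA special-casing removed here (previously duplicated across connectors) is now answered by a single get_model_args call on the helper. A rough illustration of its three branches follows; the numbers are invented purely for the arithmetic, real values come from the model config.

# Illustration of how num_heads / head_size are derived in the three cases.
tp_size = 8

# 1) Regular attention: KV heads are split across tensor-parallel ranks.
num_key_value_heads, hidden_size, num_attention_heads = 32, 8192, 64
num_heads = num_key_value_heads // tp_size           # 4
head_size = hidden_size // num_attention_heads       # 128

# 2) DeepSeek MLA with VLLM_MLA_DISABLE=0 (forward absorb): one latent "head"
#    per token, sized kv_lora_rank + qk_rope_head_dim.
kv_lora_rank, qk_rope_head_dim = 512, 64
num_heads, head_size = 1, kv_lora_rank + qk_rope_head_dim    # 1, 576

# 3) DeepSeek MLA with VLLM_MLA_DISABLE=1 (standard attention): num_heads keeps
#    the TP-split value; per-head size is qk_nope_head_dim + qk_rope_head_dim.
qk_nope_head_dim = 128
head_size = qk_nope_head_dim + qk_rope_head_dim              # 192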
vllm/distributed/kv_transfer/kv_connector/utils.py (new file)

@@ -0,0 +1,90 @@
+# SPDX-License-Identifier: Apache-2.0
+"""
+KV cache helper for store.
+"""
+import torch
+
+import vllm.envs as envs
+from vllm import _custom_ops as ops
+from vllm.config import VllmConfig
+from vllm.logger import init_logger
+
+logger = init_logger(__name__)
+
+
+class model_aware_kv_ops_helper:
+
+    def __init__(self, config: VllmConfig):
+        self.is_deepseek_mla = config.model_config.is_deepseek_mla
+        self.use_mla_opt = not envs.VLLM_MLA_DISABLE
+        self.tp_size = config.parallel_config.tensor_parallel_size
+
+    def get_model_args(self, model_executable: torch.nn.Module):
+
+        model_config = model_executable.model.config
+        self.model_executable = model_executable
+        num_heads = int(model_config.num_key_value_heads / self.tp_size)
+        hidden_size = model_config.hidden_size
+        num_attention_heads = model_config.num_attention_heads
+
+        # Deepseek's MLA (Multi-head Latent Attention) uses two different
+        # kv_cache shapes based on whether VLLM_MLA_DISABLE is set to 0.
+        # When VLLM_MLA_DISABLE=0 (default), forward absorb is applied,
+        # resulting in a kv_cache shape of [num_blks, blk_size, 1,
+        # kv_lora_rank + qk_rope_head_dim].
+        # When VLLM_MLA_DISABLE=1, standard FA is used instead, leading
+        # to a kv_cache shape of [2, num_blks, blk_size,
+        # num_key_value_heads / tp, qk_nope_head_dim + qk_rope_head_dim].
+        # For more details, see vllm/attention/backends/mla/common.py.
+        if self.is_deepseek_mla and self.use_mla_opt:
+            head_size = model_config.kv_lora_rank + \
+                model_config.qk_rope_head_dim
+            num_heads = 1
+        elif self.is_deepseek_mla and not self.use_mla_opt:
+            head_size = model_config.qk_nope_head_dim + \
+                model_config.qk_rope_head_dim
+        else:
+            head_size = getattr(model_config, "head_dim",
+                                int(hidden_size // num_attention_heads))
+
+        return num_heads, head_size
+
+    def get_kv_from_cache(self, kv_cache, num_heads, head_size):
+        if self.is_deepseek_mla and self.use_mla_opt:
+            key_cache = kv_cache.reshape(-1, num_heads, head_size)
+            value_cache = kv_cache.reshape(-1, num_heads, head_size)
+        else:
+            key_cache = kv_cache[0].reshape(-1, num_heads, head_size)
+            value_cache = kv_cache[1].reshape(-1, num_heads, head_size)
+        return key_cache, value_cache
+
+    def put_kv_to_cache(self, model_executable: torch.nn.Module, keys, values,
+                        layer, kv_cache, slot_mapping, start_pos, end_pos):
+
+        model_config = model_executable.model.config
+
+        if self.is_deepseek_mla and self.use_mla_opt:
+            layer.self_attn.attn = layer.self_attn.mla_attn
+            k_c_normed_k_pe = keys.squeeze(1)
+            k_c_normed = k_c_normed_k_pe[:, :model_config.kv_lora_rank]
+            k_pe = k_c_normed_k_pe[:, model_config.kv_lora_rank:]
+            ops.concat_and_cache_mla(
+                k_c_normed.to(kv_cache.device),
+                k_pe.to(kv_cache.device),
+                kv_cache,
+                slot_mapping[start_pos:end_pos],
+                layer.self_attn.attn.kv_cache_dtype,
+                layer.self_attn.attn._k_scale,
+            )
+        else:
+            key_cache, value_cache = kv_cache[0], kv_cache[1]
+            ops.reshape_and_cache_flash(
+                keys.to(key_cache.device),
+                values.to(value_cache.device),
+                key_cache,
+                value_cache,
+                slot_mapping[start_pos:end_pos],
+                layer.self_attn.attn.kv_cache_dtype,
+                layer.self_attn.attn._k_scale,
+                layer.self_attn.attn._v_scale,
+            )
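
Note: a sketch of how a connector is expected to drive the receive path with this helper. The wrapper function below is hypothetical; the put_kv_to_cache call and the surrounding loop mirror the SimpleConnector change above. The helper dispatches to ops.concat_and_cache_mla on the MLA forward-absorb path and to ops.reshape_and_cache_flash otherwise.

import torch
from typing import List

from vllm.distributed.kv_transfer.kv_connector.utils import (
    model_aware_kv_ops_helper as kv_helper)


def write_received_kv(helper: kv_helper,
                      model_executable: torch.nn.Module,
                      keys: torch.Tensor, values: torch.Tensor,
                      kv_caches: List[torch.Tensor],
                      slot_mapping: torch.Tensor,
                      start_pos: int, end_pos: int) -> None:
    """Decode-side pattern (sketch): scatter remote KV into paged memory."""
    start_layer = model_executable.model.start_layer
    end_layer = model_executable.model.end_layer

    for cur_layer in range(start_layer, end_layer):
        layer_id = cur_layer - start_layer
        kv_cache = kv_caches[layer_id]
        layer = model_executable.model.layers[cur_layer]
        remote_k, remote_v = keys[layer_id], values[layer_id]
        # The helper picks the MLA or standard cache-write kernel based on
        # the flags captured from VllmConfig at construction time.
        helper.put_kv_to_cache(model_executable, remote_k, remote_v, layer,
                               kv_cache, slot_mapping, start_pos, end_pos)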