[Feature] add model aware kv ops helper (#16020)
Signed-off-by: billishyahao <bill.he@amd.com>
parent 966c742ed2
commit 3ac98edcb1
@@ -1,7 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 """
 MooncakeStore Connector for Distributed Machine Learning Inference

 The MooncakeStoreConnector transfers KV caches between prefill vLLM workers
 (KV cache producer) and decode vLLM workers (KV cache consumer) using a
 database-style KVStore.
@@ -11,9 +10,10 @@ from typing import TYPE_CHECKING, List, Tuple, Union

 import torch

-from vllm import _custom_ops as ops
 from vllm.config import VllmConfig
 from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBase
+from vllm.distributed.kv_transfer.kv_connector.utils import (
+    model_aware_kv_ops_helper as kv_helper)
 from vllm.logger import init_logger
 from vllm.sequence import IntermediateTensors

@@ -32,8 +32,7 @@ class MooncakeStoreConnector(KVConnectorBase):
         config: VllmConfig,
     ):
         self.config = config.kv_transfer_config
-        self.tp_size = config.parallel_config.tensor_parallel_size
+        self.kv_helper = kv_helper(config)

         self.local_tp_rank = local_rank

         # Init kv_store
@@ -80,12 +79,7 @@ class MooncakeStoreConnector(KVConnectorBase):
         slot_mapping_flat = model_input.attn_metadata.slot_mapping.flatten()
         start_layer = model_executable.model.start_layer
         end_layer = model_executable.model.end_layer
-        model_config = model_executable.model.config
-        num_heads = int(model_config.num_key_value_heads / self.tp_size)
-        hidden_size = model_config.hidden_size
-        num_attention_heads = model_config.num_attention_heads
-        head_size = int(hidden_size / num_attention_heads)
+        num_heads, head_size = self.kv_helper.get_model_args(model_executable)

        for idx, slen in enumerate(seq_lens):
            start_pos = sum(seq_lens[:idx])
@@ -97,10 +91,8 @@ class MooncakeStoreConnector(KVConnectorBase):

            for layer_id in range(start_layer, end_layer):
                kv_cache = kv_caches[layer_id - start_layer]
-                key_cache = kv_cache[0].reshape(-1, num_heads, head_size)
-                value_cache = kv_cache[1].reshape(-1, num_heads, head_size)
+                key_cache, value_cache = self.kv_helper.get_kv_from_cache(
+                    kv_cache, num_heads, head_size)

                current_slot_mapping = slot_mapping_flat[start_pos:end_pos]

                keys.append(key_cache[current_slot_mapping].unsqueeze(0))
@@ -173,22 +165,15 @@ class MooncakeStoreConnector(KVConnectorBase):
            layer = model_executable.model.layers[layer_id]
            # get kvcache object
            kv_cache = kv_caches[layer_id - start_layer]
-            key_cache, value_cache = kv_cache[0], kv_cache[1]
-            # get remote kvcache

+            # get remote kvcache
            remote_k, remote_v = remote_kv[0][layer_id], remote_kv[1][
                layer_id]
-            # use ops.reshape_and_cache_flash to put kv into kvcache
-            ops.reshape_and_cache_flash(
-                remote_k.to(key_cache.device),
-                remote_v.to(value_cache.device),
-                key_cache,
-                value_cache,
-                slot_mapping[start_pos:end_pos],
-                layer.self_attn.attn.kv_cache_dtype,
-                layer.self_attn.attn._k_scale,
-                layer.self_attn.attn._v_scale,
-            )
+            self.kv_helper.put_kv_to_cache(model_executable, remote_k,
+                                           remote_v, layer, kv_cache,
+                                           slot_mapping, start_pos,
+                                           end_pos)

        hidden_or_intermediate_states_for_one_req.append(hidden)
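The MooncakeStoreConnector diff above replaces the hand-rolled head-geometry and reshape code on the send path with two helper calls; the SimpleConnector diff that follows applies the same substitution. As a standalone illustration of what get_kv_from_cache hands back in the standard (non-MLA) layout, here is a minimal sketch with toy tensors; the shapes, block counts, and the random slot mapping are invented for the example and are not taken from the commit:

import torch

# Toy stand-ins for one layer's paged KV cache in the non-MLA layout:
# kv_cache[0] -> key blocks, kv_cache[1] -> value blocks.
num_blocks, block_size, num_heads, head_size = 4, 16, 8, 64
kv_cache = torch.randn(2, num_blocks, block_size, num_heads, head_size)

# Equivalent of kv_helper.get_kv_from_cache(kv_cache, num_heads, head_size)
# in the standard case: flatten blocks into one token axis.
key_cache = kv_cache[0].reshape(-1, num_heads, head_size)
value_cache = kv_cache[1].reshape(-1, num_heads, head_size)

# slot_mapping holds, per prefill token, its slot index in the paged cache.
seq_len = 5
current_slot_mapping = torch.randint(0, num_blocks * block_size, (seq_len,))

# What the connector collects for this request and layer before sending.
keys = key_cache[current_slot_mapping].unsqueeze(0)      # [1, seq_len, num_heads, head_size]
values = value_cache[current_slot_mapping].unsqueeze(0)  # [1, seq_len, num_heads, head_size]
print(keys.shape, values.shape)

The connector itself no longer needs to know which layout the cache uses; it only indexes the flattened view the helper returns.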
@@ -12,10 +12,10 @@ from typing import TYPE_CHECKING, List, Optional, Tuple, Union

 import torch

-import vllm.envs as envs
-from vllm import _custom_ops as ops
 from vllm.config import VllmConfig
 from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBase
+from vllm.distributed.kv_transfer.kv_connector.utils import (
+    model_aware_kv_ops_helper as kv_helper)
 from vllm.distributed.kv_transfer.kv_lookup_buffer.simple_buffer import (
     SimpleBuffer)
 from vllm.logger import init_logger
@@ -37,9 +37,7 @@ class SimpleConnector(KVConnectorBase):
     ):

         self.config = config.kv_transfer_config
-        self.tp_size = config.parallel_config.tensor_parallel_size
-        self.is_deepseek_mla = config.model_config.is_deepseek_mla
-        self.use_mla_opt = not envs.VLLM_MLA_DISABLE
+        self.kv_helper = kv_helper(config)

         if self.config.kv_connector == "PyNcclConnector":
             from vllm.distributed.kv_transfer.kv_pipe.pynccl_pipe import (
@@ -165,31 +163,7 @@ class SimpleConnector(KVConnectorBase):
         num_prefill_tokens = model_input.attn_metadata.num_prefill_tokens
         start_layer = model_executable.model.start_layer
         end_layer = model_executable.model.end_layer
-        model_config = model_executable.model.config
-        num_heads = int(model_config.num_key_value_heads / self.tp_size)
-        hidden_size = model_config.hidden_size
-        num_attention_heads = model_config.num_attention_heads
-
-        # Deepseek's MLA (Multi-head Latent Attention) uses two different
-        # kv_cache shapes based on whether VLLM_MLA_DISABLE is set to 0.
-        # When VLLM_MLA_DISABLE=0 (default), forward absorb is applied,
-        # resulting in a kv_cache shape of [num_blks, blk_size, 1,
-        # kv_lora_rank + qk_rope_head_dim].
-        # When VLLM_MLA_DISABLE=1, standard FA is used instead, leading
-        # to a kv_cache shape of [2, num_blks, blk_size,
-        # num_key_value_heads / tp, qk_nope_head_dim + qk_rope_head_dim].
-        # For more details, see vllm/attention/backends/mla/common.py.
-        if self.is_deepseek_mla and self.use_mla_opt:
-            head_size = model_config.kv_lora_rank + \
-                model_config.qk_rope_head_dim
-            num_heads = 1
-        elif self.is_deepseek_mla and not self.use_mla_opt:
-            head_size = model_config.qk_nope_head_dim + \
-                model_config.qk_rope_head_dim
-        else:
-            head_size = getattr(model_config, "head_dim",
-                                int(hidden_size // num_attention_heads))
+        num_heads, head_size = self.kv_helper.get_model_args(model_executable)

         # query_lens contains new KV caches that are added to vLLM.
         # so we will send them to decode instance
@@ -212,13 +186,8 @@ class SimpleConnector(KVConnectorBase):

            for layer_id in range(start_layer, end_layer):
                kv_cache = kv_caches[layer_id - start_layer]
-                if self.is_deepseek_mla and self.use_mla_opt:
-                    key_cache = kv_cache.reshape(-1, num_heads, head_size)
-                    value_cache = kv_cache.reshape(-1, num_heads, head_size)
-                else:
-                    key_cache = kv_cache[0].reshape(-1, num_heads, head_size)
-                    value_cache = kv_cache[1].reshape(-1, num_heads, head_size)
+                key_cache, value_cache = self.kv_helper.get_kv_from_cache(
+                    kv_cache, num_heads, head_size)

                current_slot_mapping = slot_mapping_flat[start_pos:end_pos]

@@ -248,12 +217,12 @@ class SimpleConnector(KVConnectorBase):
        # and hidden states.
        bypass_model_exec = True

-        model_config = model_executable.model.config
-
        input_tokens_tensor = model_input.input_tokens
        seq_lens = model_input.attn_metadata.seq_lens
        num_prefill_tokens = model_input.attn_metadata.num_prefill_tokens
        slot_mapping = model_input.attn_metadata.slot_mapping.flatten()
+        start_layer = model_executable.model.start_layer
+        end_layer = model_executable.model.end_layer

        hidden_or_intermediate_states_for_one_req = []

@@ -312,41 +281,19 @@ class SimpleConnector(KVConnectorBase):
            end_pos = start_pos + num_computed_tokens

            # put received KV caches into paged memory
-            for i in range(model_executable.model.start_layer,
-                           model_executable.model.end_layer):
+            for cur_layer in range(start_layer, end_layer):

-                kv_cache = kv_caches[i - model_executable.model.start_layer]
-                layer = model_executable.model.layers[i]
+                layer_id = cur_layer - start_layer
+                kv_cache = kv_caches[layer_id]
+                layer = model_executable.model.layers[cur_layer]

-                if self.is_deepseek_mla and self.use_mla_opt:
-                    layer.self_attn.attn = layer.self_attn.mla_attn
-                    k_c_normed_k_pe = keys[
-                        i - model_executable.model.start_layer].to(
-                            kv_cache.device).squeeze(1)
-                    k_c_normed = k_c_normed_k_pe[:, :model_config.kv_lora_rank]
-                    k_pe = k_c_normed_k_pe[:, model_config.kv_lora_rank:]
-                    ops.concat_and_cache_mla(
-                        k_c_normed,
-                        k_pe,
-                        kv_cache,
-                        slot_mapping[start_pos:end_pos],
-                        layer.self_attn.attn.kv_cache_dtype,
-                        layer.self_attn.attn._k_scale,
-                    )
-                else:
-                    key_cache, value_cache = kv_cache[0], kv_cache[1]
-                    ops.reshape_and_cache_flash(
-                        keys[i - model_executable.model.start_layer].to(
-                            key_cache.device),
-                        values[i - model_executable.model.start_layer].to(
-                            value_cache.device),
-                        key_cache,
-                        value_cache,
-                        slot_mapping[start_pos:end_pos],
-                        layer.self_attn.attn.kv_cache_dtype,
-                        layer.self_attn.attn._k_scale,
-                        layer.self_attn.attn._v_scale,
-                    )
+                # get remote kvcache
+                remote_k, remote_v = keys[layer_id], values[layer_id]
+
+                self.kv_helper.put_kv_to_cache(model_executable, remote_k,
+                                               remote_v, layer, kv_cache,
+                                               slot_mapping, start_pos,
+                                               end_pos)

            hidden_or_intermediate_states_for_one_req.append(hidden)

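The block removed above is the DeepSeek-MLA special-casing that the commit centralizes in the new helper (shown next). A small numeric sketch of the (num_heads, head_size) geometry it selects, using assumed DeepSeek-style dimensions (kv_lora_rank=512, qk_rope_head_dim=64, qk_nope_head_dim=128, hidden_size=7168, 128 KV heads, TP=8; these numbers are illustrative, not read from the commit):

# Illustrative model values (assumed for this sketch, not taken from the commit).
hidden_size, num_attention_heads = 7168, 128
num_key_value_heads, tp_size = 128, 8
kv_lora_rank, qk_rope_head_dim, qk_nope_head_dim = 512, 64, 128

def kv_geometry(is_deepseek_mla: bool, use_mla_opt: bool):
    """Mirror the helper's (num_heads, head_size) selection."""
    num_heads = num_key_value_heads // tp_size
    if is_deepseek_mla and use_mla_opt:
        # Forward absorb: a single latent "head" of kv_lora_rank + rope dims.
        return 1, kv_lora_rank + qk_rope_head_dim
    if is_deepseek_mla and not use_mla_opt:
        # MLA disabled: per-head nope + rope dims, KV heads split across TP.
        return num_heads, qk_nope_head_dim + qk_rope_head_dim
    # Standard attention: head_dim falls back to hidden_size // num_attention_heads.
    return num_heads, hidden_size // num_attention_heads

print(kv_geometry(True, True))    # (1, 576)
print(kv_geometry(True, False))   # (16, 192)
print(kv_geometry(False, False))  # (16, 56)

With forward absorb enabled, the cache stores a single latent vector per token, which is why num_heads collapses to 1 in that mode.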
vllm/distributed/kv_transfer/kv_connector/utils.py (new file, 90 lines)
@@ -0,0 +1,90 @@
+# SPDX-License-Identifier: Apache-2.0
+"""
+KV cache helper for store.
+"""
+import torch
+
+import vllm.envs as envs
+from vllm import _custom_ops as ops
+from vllm.config import VllmConfig
+from vllm.logger import init_logger
+
+logger = init_logger(__name__)
+
+
+class model_aware_kv_ops_helper:
+
+    def __init__(self, config: VllmConfig):
+        self.is_deepseek_mla = config.model_config.is_deepseek_mla
+        self.use_mla_opt = not envs.VLLM_MLA_DISABLE
+        self.tp_size = config.parallel_config.tensor_parallel_size
+
+    def get_model_args(self, model_executable: torch.nn.Module):
+
+        model_config = model_executable.model.config
+        self.model_executable = model_executable
+        num_heads = int(model_config.num_key_value_heads / self.tp_size)
+        hidden_size = model_config.hidden_size
+        num_attention_heads = model_config.num_attention_heads
+
+        # Deepseek's MLA (Multi-head Latent Attention) uses two different
+        # kv_cache shapes based on whether VLLM_MLA_DISABLE is set to 0.
+        # When VLLM_MLA_DISABLE=0 (default), forward absorb is applied,
+        # resulting in a kv_cache shape of [num_blks, blk_size, 1,
+        # kv_lora_rank + qk_rope_head_dim].
+        # When VLLM_MLA_DISABLE=1, standard FA is used instead, leading
+        # to a kv_cache shape of [2, num_blks, blk_size,
+        # num_key_value_heads / tp, qk_nope_head_dim + qk_rope_head_dim].
+        # For more details, see vllm/attention/backends/mla/common.py.
+        if self.is_deepseek_mla and self.use_mla_opt:
+            head_size = model_config.kv_lora_rank + \
+                model_config.qk_rope_head_dim
+            num_heads = 1
+        elif self.is_deepseek_mla and not self.use_mla_opt:
+            head_size = model_config.qk_nope_head_dim + \
+                model_config.qk_rope_head_dim
+        else:
+            head_size = getattr(model_config, "head_dim",
+                                int(hidden_size // num_attention_heads))
+
+        return num_heads, head_size
+
+    def get_kv_from_cache(self, kv_cache, num_heads, head_size):
+        if self.is_deepseek_mla and self.use_mla_opt:
+            key_cache = kv_cache.reshape(-1, num_heads, head_size)
+            value_cache = kv_cache.reshape(-1, num_heads, head_size)
+        else:
+            key_cache = kv_cache[0].reshape(-1, num_heads, head_size)
+            value_cache = kv_cache[1].reshape(-1, num_heads, head_size)
+        return key_cache, value_cache
+
+    def put_kv_to_cache(self, model_executable: torch.nn.Module, keys, values,
+                        layer, kv_cache, slot_mapping, start_pos, end_pos):
+
+        model_config = model_executable.model.config
+
+        if self.is_deepseek_mla and self.use_mla_opt:
+            layer.self_attn.attn = layer.self_attn.mla_attn
+            k_c_normed_k_pe = keys.squeeze(1)
+            k_c_normed = k_c_normed_k_pe[:, :model_config.kv_lora_rank]
+            k_pe = k_c_normed_k_pe[:, model_config.kv_lora_rank:]
+            ops.concat_and_cache_mla(
+                k_c_normed.to(kv_cache.device),
+                k_pe.to(kv_cache.device),
+                kv_cache,
+                slot_mapping[start_pos:end_pos],
+                layer.self_attn.attn.kv_cache_dtype,
+                layer.self_attn.attn._k_scale,
+            )
+        else:
+            key_cache, value_cache = kv_cache[0], kv_cache[1]
+            ops.reshape_and_cache_flash(
+                keys.to(key_cache.device),
+                values.to(value_cache.device),
+                key_cache,
+                value_cache,
+                slot_mapping[start_pos:end_pos],
+                layer.self_attn.attn.kv_cache_dtype,
+                layer.self_attn.attn._k_scale,
+                layer.self_attn.attn._v_scale,
+            )
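Putting the pieces together, a connector is expected to construct the helper once and route all cache reads and writes through it. A minimal sketch of that call sequence follows; the class name MyStoreConnector and the send_kv/recv_kv method names are hypothetical, and config, model_executable, kv_caches, slot_mapping, keys, and values stand in for objects a real connector already holds:

from vllm.distributed.kv_transfer.kv_connector.utils import (
    model_aware_kv_ops_helper as kv_helper)


class MyStoreConnector:
    """Sketch of the call sequence a connector follows with the helper."""

    def __init__(self, config):
        # One helper per connector; it captures MLA mode and TP size.
        self.kv_helper = kv_helper(config)

    def send_kv(self, model_executable, kv_caches, slot_mapping,
                start_pos, end_pos):
        num_heads, head_size = self.kv_helper.get_model_args(model_executable)
        sent = []
        for kv_cache in kv_caches:
            # Model-agnostic flat view of the paged cache (MLA or standard).
            key_cache, value_cache = self.kv_helper.get_kv_from_cache(
                kv_cache, num_heads, head_size)
            slots = slot_mapping[start_pos:end_pos]
            sent.append((key_cache[slots], value_cache[slots]))
        return sent

    def recv_kv(self, model_executable, kv_caches, keys, values,
                slot_mapping, start_pos, end_pos):
        for layer_id, kv_cache in enumerate(kv_caches):
            layer = model_executable.model.layers[layer_id]
            # The helper chooses concat_and_cache_mla vs reshape_and_cache_flash.
            self.kv_helper.put_kv_to_cache(model_executable, keys[layer_id],
                                           values[layer_id], layer, kv_cache,
                                           slot_mapping, start_pos, end_pos)

This keeps the VLLM_MLA_DISABLE and DeepSeek-MLA branching in one place, so individual connectors no longer need to read the model config or environment themselves.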