[Frontend][Bugfix] support prefill decode disaggregation on deepseek (#14824)
Signed-off-by: billishyahao <bill.he@amd.com>
Co-authored-by: Zhai Feiyue <80079571+ZhaiFeiyue@users.noreply.github.com>
commit 742369d35a (parent bfe2fe0af4)
@@ -8,6 +8,9 @@ set -xe
 echo "🚧🚧 Warning: The usage of disaggregated prefill is experimental and subject to change 🚧🚧"
 sleep 1

+# meta-llama/Meta-Llama-3.1-8B-Instruct or deepseek-ai/DeepSeek-V2-Lite
+MODEL_NAME=${HF_MODEL_NAME:-meta-llama/Meta-Llama-3.1-8B-Instruct}
+
 # Trap the SIGINT signal (triggered by Ctrl+C)
 trap 'cleanup' INT
@@ -44,18 +47,20 @@ wait_for_server() {
 # You can also adjust --kv-ip and --kv-port for distributed inference.

 # prefilling instance, which is the KV producer
-CUDA_VISIBLE_DEVICES=0 vllm serve meta-llama/Meta-Llama-3.1-8B-Instruct \
+CUDA_VISIBLE_DEVICES=0 vllm serve $MODEL_NAME \
     --port 8100 \
     --max-model-len 100 \
     --gpu-memory-utilization 0.8 \
+    --trust-remote-code \
     --kv-transfer-config \
     '{"kv_connector":"PyNcclConnector","kv_role":"kv_producer","kv_rank":0,"kv_parallel_size":2}' &

 # decoding instance, which is the KV consumer
-CUDA_VISIBLE_DEVICES=1 vllm serve meta-llama/Meta-Llama-3.1-8B-Instruct \
+CUDA_VISIBLE_DEVICES=1 vllm serve $MODEL_NAME \
     --port 8200 \
     --max-model-len 100 \
     --gpu-memory-utilization 0.8 \
+    --trust-remote-code \
     --kv-transfer-config \
     '{"kv_connector":"PyNcclConnector","kv_role":"kv_consumer","kv_rank":1,"kv_parallel_size":2}' &
@@ -78,7 +83,7 @@ sleep 1
 output1=$(curl -X POST -s http://localhost:8000/v1/completions \
 -H "Content-Type: application/json" \
 -d '{
-"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
+"model": "'"$MODEL_NAME"'",
 "prompt": "San Francisco is a",
 "max_tokens": 10,
 "temperature": 0
@@ -87,7 +92,7 @@ output1=$(curl -X POST -s http://localhost:8000/v1/completions \
 output2=$(curl -X POST -s http://localhost:8000/v1/completions \
 -H "Content-Type: application/json" \
 -d '{
-"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
+"model": "'"$MODEL_NAME"'",
 "prompt": "Santa Clara is a",
 "max_tokens": 10,
 "temperature": 0
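With the model parameterized through HF_MODEL_NAME, the same example can exercise the DeepSeek MLA path end to end. A minimal invocation sketch; the script file name below is illustrative, not taken from this diff:

# Launch the disaggregated prefill example against DeepSeek-V2-Lite.
# HF_MODEL_NAME feeds MODEL_NAME=${HF_MODEL_NAME:-...} added above;
# the script name is an assumption for illustration.
HF_MODEL_NAME=deepseek-ai/DeepSeek-V2-Lite bash disagg_prefill_example.sh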
@@ -12,6 +12,7 @@ from typing import TYPE_CHECKING, List, Optional, Tuple, Union

 import torch

+import vllm.envs as envs
 from vllm import _custom_ops as ops
 from vllm.config import VllmConfig
 from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBase
@@ -37,6 +38,8 @@ class SimpleConnector(KVConnectorBase):

         self.config = config.kv_transfer_config
         self.tp_size = config.parallel_config.tensor_parallel_size
+        self.is_deepseek_mla = config.model_config.is_deepseek_mla
+        self.use_mla_opt = not envs.VLLM_MLA_DISABLE

         if self.config.kv_connector == "PyNcclConnector":
             from vllm.distributed.kv_transfer.kv_pipe.pynccl_pipe import (
@@ -167,8 +170,26 @@ class SimpleConnector(KVConnectorBase):
         num_heads = int(model_config.num_key_value_heads / self.tp_size)
         hidden_size = model_config.hidden_size
         num_attention_heads = model_config.num_attention_heads
-        head_size = getattr(model_config, "head_dim",
-                            int(hidden_size // num_attention_heads))
+
+        # Deepseek's MLA (Multi-head Latent Attention) uses two different
+        # kv_cache shapes based on whether VLLM_MLA_DISABLE is set to 0.
+        # When VLLM_MLA_DISABLE=0 (default), forward absorb is applied,
+        # resulting in a kv_cache shape of [num_blks, blk_size, 1,
+        # kv_lora_rank + qk_rope_head_dim].
+        # When VLLM_MLA_DISABLE=1, standard FA is used instead, leading
+        # to a kv_cache shape of [2, num_blks, blk_size,
+        # num_key_value_heads / tp, qk_nope_head_dim + qk_rope_head_dim].
+        # For more details, see vllm/attention/backends/mla/common.py.
+        if self.is_deepseek_mla and self.use_mla_opt:
+            head_size = model_config.kv_lora_rank + \
+                model_config.qk_rope_head_dim
+            num_heads = 1
+        elif self.is_deepseek_mla and not self.use_mla_opt:
+            head_size = model_config.qk_nope_head_dim + \
+                model_config.qk_rope_head_dim
+        else:
+            head_size = getattr(model_config, "head_dim",
+                                int(hidden_size // num_attention_heads))

         # query_lens contains new KV caches that are added to vLLM.
         # so we will send them to decode instance
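To make the two MLA branches concrete, here is a standalone sketch of the head_size/num_heads selection above, using assumed DeepSeek-V2-Lite-style dimensions (kv_lora_rank=512, qk_rope_head_dim=64, qk_nope_head_dim=128, 16 heads); the numbers are illustrative, not read from the real HF config:

# Standalone sketch of the head_size/num_heads selection, with assumed
# DeepSeek-V2-Lite-style dimensions (illustrative values only).
from dataclasses import dataclass


@dataclass
class FakeModelConfig:
    kv_lora_rank: int = 512
    qk_rope_head_dim: int = 64
    qk_nope_head_dim: int = 128
    num_key_value_heads: int = 16
    hidden_size: int = 2048
    num_attention_heads: int = 16


def select_kv_shape(cfg: FakeModelConfig, is_deepseek_mla: bool,
                    use_mla_opt: bool, tp_size: int = 1):
    num_heads = cfg.num_key_value_heads // tp_size
    if is_deepseek_mla and use_mla_opt:
        # Compressed latent + rope part, cached as a single "head".
        return 1, cfg.kv_lora_rank + cfg.qk_rope_head_dim
    if is_deepseek_mla and not use_mla_opt:
        # Standard attention path: full per-head K width.
        return num_heads, cfg.qk_nope_head_dim + cfg.qk_rope_head_dim
    return num_heads, cfg.hidden_size // cfg.num_attention_heads


print(select_kv_shape(FakeModelConfig(), True, True))    # (1, 576)
print(select_kv_shape(FakeModelConfig(), True, False))   # (16, 192)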
@@ -192,8 +213,12 @@ class SimpleConnector(KVConnectorBase):
             for layer_id in range(start_layer, end_layer):
                 kv_cache = kv_caches[layer_id - start_layer]

-                key_cache = kv_cache[0].reshape(-1, num_heads, head_size)
-                value_cache = kv_cache[1].reshape(-1, num_heads, head_size)
+                if self.is_deepseek_mla and self.use_mla_opt:
+                    key_cache = kv_cache.reshape(-1, num_heads, head_size)
+                    value_cache = kv_cache.reshape(-1, num_heads, head_size)
+                else:
+                    key_cache = kv_cache[0].reshape(-1, num_heads, head_size)
+                    value_cache = kv_cache[1].reshape(-1, num_heads, head_size)

                 current_slot_mapping = slot_mapping_flat[start_pos:end_pos]
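The reshape branch exists because the MLA cache has no separate K/V planes: there is one joint latent cache per layer. A small sketch with dummy tensors, using illustrative sizes (2 blocks, block size 16, latent width 576, and 16 heads of width 192 for the non-MLA case):

import torch

num_blks, blk_size = 2, 16

# MLA layout (VLLM_MLA_DISABLE=0): one joint latent cache per layer,
# shape [num_blks, blk_size, 1, kv_lora_rank + qk_rope_head_dim].
mla_cache = torch.zeros(num_blks, blk_size, 1, 576)
key_cache = mla_cache.reshape(-1, 1, 576)      # same storage viewed as "key"...
value_cache = mla_cache.reshape(-1, 1, 576)    # ...and as "value"

# Standard layout: separate K and V planes,
# shape [2, num_blks, blk_size, num_heads, head_size].
std_cache = torch.zeros(2, num_blks, blk_size, 16, 192)
k = std_cache[0].reshape(-1, 16, 192)
v = std_cache[1].reshape(-1, 16, 192)

print(key_cache.shape, k.shape)  # torch.Size([32, 1, 576]) torch.Size([32, 16, 192])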
@@ -223,6 +248,8 @@ class SimpleConnector(KVConnectorBase):
         # and hidden states.
         bypass_model_exec = True

+        model_config = model_executable.model.config
+
         input_tokens_tensor = model_input.input_tokens
         seq_lens = model_input.attn_metadata.seq_lens
         num_prefill_tokens = model_input.attn_metadata.num_prefill_tokens
@@ -291,19 +318,35 @@ class SimpleConnector(KVConnectorBase):
             kv_cache = kv_caches[i - model_executable.model.start_layer]
             layer = model_executable.model.layers[i]

-            key_cache, value_cache = kv_cache[0], kv_cache[1]
-            ops.reshape_and_cache_flash(
-                keys[i - model_executable.model.start_layer].to(
-                    key_cache.device),
-                values[i - model_executable.model.start_layer].to(
-                    value_cache.device),
-                key_cache,
-                value_cache,
-                slot_mapping[start_pos:end_pos],
-                layer.self_attn.attn.kv_cache_dtype,
-                layer.self_attn.attn._k_scale,
-                layer.self_attn.attn._v_scale,
-            )
+            if self.is_deepseek_mla and self.use_mla_opt:
+                layer.self_attn.attn = layer.self_attn.mla_attn
+                k_c_normed_k_pe = keys[
+                    i - model_executable.model.start_layer].to(
+                        kv_cache.device).squeeze(1)
+                k_c_normed = k_c_normed_k_pe[:, :model_config.kv_lora_rank]
+                k_pe = k_c_normed_k_pe[:, model_config.kv_lora_rank:]
+                ops.concat_and_cache_mla(
+                    k_c_normed,
+                    k_pe,
+                    kv_cache,
+                    slot_mapping[start_pos:end_pos],
+                    layer.self_attn.attn.kv_cache_dtype,
+                    layer.self_attn.attn._k_scale,
+                )
+            else:
+                key_cache, value_cache = kv_cache[0], kv_cache[1]
+                ops.reshape_and_cache_flash(
+                    keys[i - model_executable.model.start_layer].to(
+                        key_cache.device),
+                    values[i - model_executable.model.start_layer].to(
+                        value_cache.device),
+                    key_cache,
+                    value_cache,
+                    slot_mapping[start_pos:end_pos],
+                    layer.self_attn.attn.kv_cache_dtype,
+                    layer.self_attn.attn._k_scale,
+                    layer.self_attn.attn._v_scale,
+                )

             hidden_or_intermediate_states_for_one_req.append(hidden)
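On the receive side, each per-token MLA entry arrives as the concatenation of the compressed latent (k_c_normed) and the decoupled rotary part (k_pe), so it is split at kv_lora_rank before being written back with concat_and_cache_mla. A minimal sketch of that split, assuming kv_lora_rank=512 and qk_rope_head_dim=64 (illustrative DeepSeek-V2-style values):

import torch

kv_lora_rank, qk_rope_head_dim = 512, 64   # assumed per-model values
num_tokens = 4

# Received keys for one request/layer:
# [num_tokens, 1, kv_lora_rank + qk_rope_head_dim]
received = torch.randn(num_tokens, 1, kv_lora_rank + qk_rope_head_dim)

k_c_normed_k_pe = received.squeeze(1)            # drop the singleton "head" dim
k_c_normed = k_c_normed_k_pe[:, :kv_lora_rank]   # compressed KV latent
k_pe = k_c_normed_k_pe[:, kv_lora_rank:]         # decoupled rotary positional part

print(k_c_normed.shape, k_pe.shape)  # torch.Size([4, 512]) torch.Size([4, 64])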
@@ -589,6 +589,7 @@ class DeepseekV2Model(nn.Module):
         model_config = vllm_config.model_config
         cache_config = vllm_config.cache_config
         quant_config = vllm_config.quant_config
+        self.config = config

         self.vocab_size = config.vocab_size
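Storing the HF config on DeepseekV2Model is what lets the connector's receive path above reach the DeepSeek dimensions through the model executable. A tiny mock of that lookup chain (objects and values other than the attribute added in this diff are assumptions for illustration):

from types import SimpleNamespace

hf_config = SimpleNamespace(kv_lora_rank=512, qk_rope_head_dim=64)  # assumed values
model = SimpleNamespace(config=hf_config)           # DeepseekV2Model now exposes .config
model_executable = SimpleNamespace(model=model)

# Access pattern used by SimpleConnector above:
model_config = model_executable.model.config
print(model_config.kv_lora_rank + model_config.qk_rope_head_dim)    # 576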