From 742369d35a95999176e1fd0391646dde3aaa3ec3 Mon Sep 17 00:00:00 2001
From: billishyahao
Date: Thu, 20 Mar 2025 15:00:33 +0800
Subject: [PATCH] [Frontend][Bugfix] support prefill decode disaggregation on deepseek (#14824)

Signed-off-by: billishyahao
Co-authored-by: Zhai Feiyue <80079571+ZhaiFeiyue@users.noreply.github.com>
---
 .../online_serving/disaggregated_prefill.sh | 13 +++-
 .../kv_connector/simple_connector.py        | 77 +++++++++++++++----
 vllm/model_executor/models/deepseek_v2.py   |  1 +
 3 files changed, 70 insertions(+), 21 deletions(-)

diff --git a/examples/online_serving/disaggregated_prefill.sh b/examples/online_serving/disaggregated_prefill.sh
index 2bb2824c..6925dc8a 100644
--- a/examples/online_serving/disaggregated_prefill.sh
+++ b/examples/online_serving/disaggregated_prefill.sh
@@ -8,6 +8,9 @@ set -xe
 echo "🚧🚧 Warning: The usage of disaggregated prefill is experimental and subject to change 🚧🚧"
 sleep 1
 
+# meta-llama/Meta-Llama-3.1-8B-Instruct or deepseek-ai/DeepSeek-V2-Lite
+MODEL_NAME=${HF_MODEL_NAME:-meta-llama/Meta-Llama-3.1-8B-Instruct}
+
 # Trap the SIGINT signal (triggered by Ctrl+C)
 trap 'cleanup' INT
 
@@ -44,18 +47,20 @@ wait_for_server() {
 # You can also adjust --kv-ip and --kv-port for distributed inference.
 
 # prefilling instance, which is the KV producer
-CUDA_VISIBLE_DEVICES=0 vllm serve meta-llama/Meta-Llama-3.1-8B-Instruct \
+CUDA_VISIBLE_DEVICES=0 vllm serve $MODEL_NAME \
     --port 8100 \
     --max-model-len 100 \
     --gpu-memory-utilization 0.8 \
+    --trust-remote-code \
     --kv-transfer-config \
     '{"kv_connector":"PyNcclConnector","kv_role":"kv_producer","kv_rank":0,"kv_parallel_size":2}' &
 
 # decoding instance, which is the KV consumer
-CUDA_VISIBLE_DEVICES=1 vllm serve meta-llama/Meta-Llama-3.1-8B-Instruct \
+CUDA_VISIBLE_DEVICES=1 vllm serve $MODEL_NAME \
    --port 8200 \
    --max-model-len 100 \
    --gpu-memory-utilization 0.8 \
+    --trust-remote-code \
    --kv-transfer-config \
    '{"kv_connector":"PyNcclConnector","kv_role":"kv_consumer","kv_rank":1,"kv_parallel_size":2}' &
 
@@ -78,7 +83,7 @@ sleep 1
 output1=$(curl -X POST -s http://localhost:8000/v1/completions \
 -H "Content-Type: application/json" \
 -d '{
-"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
+"model": "'"$MODEL_NAME"'",
 "prompt": "San Francisco is a",
 "max_tokens": 10,
 "temperature": 0
@@ -87,7 +92,7 @@
 output2=$(curl -X POST -s http://localhost:8000/v1/completions \
 -H "Content-Type: application/json" \
 -d '{
-"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
+"model": "'"$MODEL_NAME"'",
 "prompt": "Santa Clara is a",
 "max_tokens": 10,
 "temperature": 0
diff --git a/vllm/distributed/kv_transfer/kv_connector/simple_connector.py b/vllm/distributed/kv_transfer/kv_connector/simple_connector.py
index 7315a6f4..49b97d7b 100644
--- a/vllm/distributed/kv_transfer/kv_connector/simple_connector.py
+++ b/vllm/distributed/kv_transfer/kv_connector/simple_connector.py
@@ -12,6 +12,7 @@ from typing import TYPE_CHECKING, List, Optional, Tuple, Union
 
 import torch
 
+import vllm.envs as envs
 from vllm import _custom_ops as ops
 from vllm.config import VllmConfig
 from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBase
@@ -37,6 +38,8 @@ class SimpleConnector(KVConnectorBase):
 
         self.config = config.kv_transfer_config
         self.tp_size = config.parallel_config.tensor_parallel_size
+        self.is_deepseek_mla = config.model_config.is_deepseek_mla
+        self.use_mla_opt = not envs.VLLM_MLA_DISABLE
 
         if self.config.kv_connector == "PyNcclConnector":
             from vllm.distributed.kv_transfer.kv_pipe.pynccl_pipe import (
@@ -167,8 +170,26 @@ class SimpleConnector(KVConnectorBase):
         num_heads = int(model_config.num_key_value_heads / self.tp_size)
         hidden_size = model_config.hidden_size
         num_attention_heads = model_config.num_attention_heads
-        head_size = getattr(model_config, "head_dim",
-                            int(hidden_size // num_attention_heads))
+
+        # Deepseek's MLA (Multi-head Latent Attention) uses two different
+        # kv_cache shapes based on whether VLLM_MLA_DISABLE is set to 0.
+        # When VLLM_MLA_DISABLE=0 (default), forward absorb is applied,
+        # resulting in a kv_cache shape of [num_blks, blk_size, 1,
+        # kv_lora_rank + qk_rope_head_dim].
+        # When VLLM_MLA_DISABLE=1, standard FA is used instead, leading
+        # to a kv_cache shape of [2, num_blks, blk_size,
+        # num_key_value_heads / tp, qk_nope_head_dim + qk_rope_head_dim].
+        # For more details, see vllm/attention/backends/mla/common.py.
+        if self.is_deepseek_mla and self.use_mla_opt:
+            head_size = model_config.kv_lora_rank + \
+                model_config.qk_rope_head_dim
+            num_heads = 1
+        elif self.is_deepseek_mla and not self.use_mla_opt:
+            head_size = model_config.qk_nope_head_dim + \
+                model_config.qk_rope_head_dim
+        else:
+            head_size = getattr(model_config, "head_dim",
+                                int(hidden_size // num_attention_heads))
 
         # query_lens contains new KV caches that are added to vLLM.
         # so we will send them to decode instance
@@ -192,8 +213,12 @@ class SimpleConnector(KVConnectorBase):
         for layer_id in range(start_layer, end_layer):
             kv_cache = kv_caches[layer_id - start_layer]
 
-            key_cache = kv_cache[0].reshape(-1, num_heads, head_size)
-            value_cache = kv_cache[1].reshape(-1, num_heads, head_size)
+            if self.is_deepseek_mla and self.use_mla_opt:
+                key_cache = kv_cache.reshape(-1, num_heads, head_size)
+                value_cache = kv_cache.reshape(-1, num_heads, head_size)
+            else:
+                key_cache = kv_cache[0].reshape(-1, num_heads, head_size)
+                value_cache = kv_cache[1].reshape(-1, num_heads, head_size)
 
             current_slot_mapping = slot_mapping_flat[start_pos:end_pos]
 
@@ -223,6 +248,8 @@ class SimpleConnector(KVConnectorBase):
         # and hidden states.
         bypass_model_exec = True
 
+        model_config = model_executable.model.config
+
         input_tokens_tensor = model_input.input_tokens
         seq_lens = model_input.attn_metadata.seq_lens
         num_prefill_tokens = model_input.attn_metadata.num_prefill_tokens
@@ -291,19 +318,35 @@ class SimpleConnector(KVConnectorBase):
                 kv_cache = kv_caches[i - model_executable.model.start_layer]
                 layer = model_executable.model.layers[i]
 
-                key_cache, value_cache = kv_cache[0], kv_cache[1]
-                ops.reshape_and_cache_flash(
-                    keys[i - model_executable.model.start_layer].to(
-                        key_cache.device),
-                    values[i - model_executable.model.start_layer].to(
-                        value_cache.device),
-                    key_cache,
-                    value_cache,
-                    slot_mapping[start_pos:end_pos],
-                    layer.self_attn.attn.kv_cache_dtype,
-                    layer.self_attn.attn._k_scale,
-                    layer.self_attn.attn._v_scale,
-                )
+                if self.is_deepseek_mla and self.use_mla_opt:
+                    layer.self_attn.attn = layer.self_attn.mla_attn
+                    k_c_normed_k_pe = keys[
+                        i - model_executable.model.start_layer].to(
+                            kv_cache.device).squeeze(1)
+                    k_c_normed = k_c_normed_k_pe[:, :model_config.kv_lora_rank]
+                    k_pe = k_c_normed_k_pe[:, model_config.kv_lora_rank:]
+                    ops.concat_and_cache_mla(
+                        k_c_normed,
+                        k_pe,
+                        kv_cache,
+                        slot_mapping[start_pos:end_pos],
+                        layer.self_attn.attn.kv_cache_dtype,
+                        layer.self_attn.attn._k_scale,
+                    )
+                else:
+                    key_cache, value_cache = kv_cache[0], kv_cache[1]
+                    ops.reshape_and_cache_flash(
+                        keys[i - model_executable.model.start_layer].to(
+                            key_cache.device),
+                        values[i - model_executable.model.start_layer].to(
+                            value_cache.device),
+                        key_cache,
+                        value_cache,
+                        slot_mapping[start_pos:end_pos],
+                        layer.self_attn.attn.kv_cache_dtype,
+                        layer.self_attn.attn._k_scale,
+                        layer.self_attn.attn._v_scale,
+                    )
 
             hidden_or_intermediate_states_for_one_req.append(hidden)
 
diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py
index d66f61a8..fcab533e 100644
--- a/vllm/model_executor/models/deepseek_v2.py
+++ b/vllm/model_executor/models/deepseek_v2.py
@@ -589,6 +589,7 @@ class DeepseekV2Model(nn.Module):
         model_config = vllm_config.model_config
         cache_config = vllm_config.cache_config
         quant_config = vllm_config.quant_config
+        self.config = config
         self.vocab_size = config.vocab_size
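
For reference, the KV-cache head-size selection this patch adds to SimpleConnector can be exercised in isolation. The sketch below is a minimal stand-alone approximation, not vLLM code: FakeModelConfig is a hypothetical stand-in for the model config, and the DeepSeek-V2-Lite-like numbers (kv_lora_rank=512, qk_rope_head_dim=64, qk_nope_head_dim=128) are example values. It only mirrors the branching documented in the comment added above: with MLA absorption the cache holds a single latent "head" of width kv_lora_rank + qk_rope_head_dim, otherwise the usual per-head layout applies.

# Minimal stand-alone sketch (not vLLM code) of the KV-cache head-size
# selection introduced in this patch. FakeModelConfig is a hypothetical
# stand-in for vLLM's model config; the numbers are example values only.
from dataclasses import dataclass


@dataclass
class FakeModelConfig:
    hidden_size: int = 2048
    num_attention_heads: int = 16
    num_key_value_heads: int = 16
    is_deepseek_mla: bool = True
    kv_lora_rank: int = 512       # DeepSeek-V2-Lite-like example value
    qk_rope_head_dim: int = 64    # example value
    qk_nope_head_dim: int = 128   # example value


def kv_cache_head_shape(cfg: FakeModelConfig, tp_size: int,
                        use_mla_opt: bool) -> tuple:
    """Return (num_heads, head_size) for the flattened KV-cache view."""
    num_heads = cfg.num_key_value_heads // tp_size
    if cfg.is_deepseek_mla and use_mla_opt:
        # Forward-absorbed MLA: one latent vector per token,
        # width kv_lora_rank + qk_rope_head_dim.
        return 1, cfg.kv_lora_rank + cfg.qk_rope_head_dim
    if cfg.is_deepseek_mla and not use_mla_opt:
        # VLLM_MLA_DISABLE=1: standard attention layout per KV head.
        return num_heads, cfg.qk_nope_head_dim + cfg.qk_rope_head_dim
    # Non-MLA models: fall back to head_dim, or hidden_size / num_attention_heads.
    head_size = getattr(cfg, "head_dim",
                        cfg.hidden_size // cfg.num_attention_heads)
    return num_heads, head_size


if __name__ == "__main__":
    cfg = FakeModelConfig()
    print(kv_cache_head_shape(cfg, tp_size=1, use_mla_opt=True))   # (1, 576)
    print(kv_cache_head_shape(cfg, tp_size=1, use_mla_opt=False))  # (16, 192)

Running the sketch prints (1, 576) for the absorbed-MLA layout and (16, 192) for the VLLM_MLA_DISABLE=1 layout, matching the two kv_cache shapes described in the connector comment.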