[Core] Implement disagg prefill by StatelessProcessGroup (#10502)

This PR provides initial support for single-node disaggregated prefill in the 1P1D scenario (one prefill instance and one decode instance).
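A minimal launch sketch of the 1P1D setup (mirroring the example script added in this PR; the model, ports, and flags below are illustrative, not a new API):

CUDA_VISIBLE_DEVICES=0 vllm serve meta-llama/Meta-Llama-3.1-8B-Instruct --port 8100 \
    --kv-transfer-config '{"kv_connector":"PyNcclConnector","kv_role":"kv_producer","kv_rank":0,"kv_parallel_size":2}' &
CUDA_VISIBLE_DEVICES=1 vllm serve meta-llama/Meta-Llama-3.1-8B-Instruct --port 8200 \
    --kv-transfer-config '{"kv_connector":"PyNcclConnector","kv_role":"kv_consumer","kv_rank":1,"kv_parallel_size":2}' &
# a small proxy (disagg_prefill_proxy_server.py, included below) then sends each
# request to the prefill instance with max_tokens=1 and forwards it to the decode instance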
Signed-off-by: KuntaiDu <kuntai@uchicago.edu>
Co-authored-by: ApostaC <yihua98@uchicago.edu>
Co-authored-by: YaoJiayi <120040070@link.cuhk.edu.cn>
This commit is contained in:
Kuntai Du 2024-12-01 19:01:00 -06:00 committed by GitHub
parent c11f172187
commit 0590ec3fd9
33 changed files with 2525 additions and 21 deletions

View File

@ -430,6 +430,9 @@ steps:
- vllm/model_executor/models/
- tests/distributed/
- vllm/compilation
- vllm/worker/worker_base.py
- vllm/worker/worker.py
- vllm/worker/model_runner.py
commands:
- pytest -v -s ./compile/test_basic_correctness.py
- pytest -v -s ./compile/test_wrapper.py
@ -443,6 +446,7 @@ steps:
- pip install -e ./plugins/vllm_add_dummy_model
- pytest -v -s distributed/test_distributed_oot.py
- CUDA_VISIBLE_DEVICES=0,1 pytest -v -s test_sharded_state_loader.py
- CUDA_VISIBLE_DEVICES=0,1 pytest -v -s kv_transfer/disagg_test.py
- label: Multi-step Tests (4 GPUs) # 36min
working_dir: "/vllm-workspace/tests"

View File

@ -0,0 +1,144 @@
#!/bin/bash
# benchmark the overhead of disaggregated prefill.
# methodology:
# - send all requests to the prefill vLLM instance, which buffers the KV caches.
# - then send all requests to the decode instance.
# - the TTFT of the decode instance is the overhead of disaggregated prefill.
set -ex
kill_gpu_processes() {
# kill all processes on GPU.
pkill -f pt_main_thread
sleep 10
# remove vllm config file
rm -rf ~/.config/vllm
# Print the GPU memory usage
# so that we know if all GPU processes are killed.
gpu_memory_usage=$(nvidia-smi --query-gpu=memory.used --format=csv,noheader,nounits -i 0)
# The memory usage should be 0 MB.
echo "GPU 0 Memory Usage: $gpu_memory_usage MB"
}
wait_for_server() {
  # wait for the vLLM server to start
  # return 1 if the server is not up within the timeout
local port=$1
timeout 1200 bash -c "
until curl -s localhost:${port}/v1/completions > /dev/null; do
sleep 1
done" && return 0 || return 1
}
benchmark() {
export VLLM_LOGGING_LEVEL=DEBUG
export VLLM_HOST_IP=$(hostname -I | awk '{print $1}')
# compare chunked prefill with disaggregated prefill
results_folder="./results"
model="meta-llama/Meta-Llama-3.1-8B-Instruct"
dataset_name="sonnet"
dataset_path="../sonnet_4x.txt"
num_prompts=10
qps=$1
prefix_len=50
input_len=2048
output_len=$2
CUDA_VISIBLE_DEVICES=0 python3 \
-m vllm.entrypoints.openai.api_server \
--model meta-llama/Meta-Llama-3.1-8B-Instruct \
--port 8100 \
--max-model-len 10000 \
--gpu-memory-utilization 0.6 \
--kv-transfer-config \
'{"kv_connector":"PyNcclConnector","kv_role":"kv_producer","kv_rank":0,"kv_parallel_size":2,"kv_buffer_size":5e9}' &
CUDA_VISIBLE_DEVICES=1 python3 \
-m vllm.entrypoints.openai.api_server \
--model meta-llama/Meta-Llama-3.1-8B-Instruct \
--port 8200 \
--max-model-len 10000 \
--gpu-memory-utilization 0.6 \
--kv-transfer-config \
'{"kv_connector":"PyNcclConnector","kv_role":"kv_consumer","kv_rank":1,"kv_parallel_size":2,"kv_buffer_size":5e9}' &
wait_for_server 8100
wait_for_server 8200
# let the prefill instance finish prefill
python3 ../benchmark_serving.py \
--backend vllm \
--model $model \
--dataset-name $dataset_name \
--dataset-path $dataset_path \
--sonnet-input-len $input_len \
--sonnet-output-len "$output_len" \
--sonnet-prefix-len $prefix_len \
--num-prompts $num_prompts \
--port 8100 \
--save-result \
--result-dir $results_folder \
--result-filename disagg_prefill_2xtp4.json \
--request-rate "inf"
  # send the requests to the decode instance.
  # The TTFT of this run is the overhead of the disaggregated prefill implementation.
python3 ../benchmark_serving.py \
--backend vllm \
--model $model \
--dataset-name $dataset_name \
--dataset-path $dataset_path \
--sonnet-input-len $input_len \
--sonnet-output-len "$output_len" \
--sonnet-prefix-len $prefix_len \
--num-prompts $num_prompts \
--port 8200 \
--save-result \
--result-dir $results_folder \
--result-filename disagg_prefill_2xtp4.json \
--request-rate "$qps"
kill_gpu_processes
}
main() {
(which wget && which curl) || (apt-get update && apt-get install -y wget curl)
(which jq) || (apt-get -y install jq)
(which socat) || (apt-get -y install socat)
pip install quart httpx
cd "$(dirname "$0")"
cd ..
  # create sonnet_4x.txt
echo "" > sonnet_4x.txt
for _ in {1..4}
do
cat sonnet.txt >> sonnet_4x.txt
done
cd disagg_benchmarks
rm -rf results
mkdir results
default_qps=1
default_output_len=1
benchmark $default_qps $default_output_len
}
main "$@"

View File

@ -0,0 +1,164 @@
#!/bin/bash
# Requirement: 8x H100 GPUs.
# Model: neuralmagic/Meta-Llama-3-70B-Instruct-FP8-KV
# Query: 2048 input tokens, 11 output tokens, QPS 4, 500 requests
# Resource: 8x H100
# Approaches:
# 1. Chunked prefill: 1 vLLM instance with tp=8
# 2. Chunked prefill: 2 vLLM instances with tp=4, equivalent to 1 tp=4 instance with QPS 4
# 3. Disaggregated prefill: 1 prefilling instance and 1 decoding instance
# Prefilling instance: max_output_token=1
# Decoding instance: force the input tokens to be the same across requests to bypass prefilling
set -ex
kill_gpu_processes() {
# kill all processes on GPU.
pgrep pt_main_thread | xargs -r kill -9
pgrep python3 | xargs -r kill -9
for port in 8000 8100 8200; do lsof -t -i:$port | xargs -r kill -9; done
sleep 1
}
wait_for_server() {
  # wait for the vLLM server to start
  # return 1 if the server is not up within the timeout
local port=$1
timeout 1200 bash -c "
until curl -s localhost:${port}/v1/completions > /dev/null; do
sleep 1
done" && return 0 || return 1
}
launch_chunked_prefill() {
model="meta-llama/Meta-Llama-3.1-8B-Instruct"
  # chunked prefill
CUDA_VISIBLE_DEVICES=0 python3 \
-m vllm.entrypoints.openai.api_server \
--model $model \
--port 8100 \
--max-model-len 10000 \
--enable-chunked-prefill \
--gpu-memory-utilization 0.6 &
CUDA_VISIBLE_DEVICES=1 python3 \
-m vllm.entrypoints.openai.api_server \
--model $model \
--port 8200 \
--max-model-len 10000 \
--enable-chunked-prefill \
--gpu-memory-utilization 0.6 &
wait_for_server 8100
wait_for_server 8200
python3 round_robin_proxy.py &
sleep 1
}
launch_disagg_prefill() {
model="meta-llama/Meta-Llama-3.1-8B-Instruct"
# disagg prefill
CUDA_VISIBLE_DEVICES=0 python3 \
-m vllm.entrypoints.openai.api_server \
--model $model \
--port 8100 \
--max-model-len 10000 \
--gpu-memory-utilization 0.6 \
--kv-transfer-config \
'{"kv_connector":"PyNcclConnector","kv_role":"kv_producer","kv_rank":0,"kv_parallel_size":2,"kv_buffer_size":5e9}' &
CUDA_VISIBLE_DEVICES=1 python3 \
-m vllm.entrypoints.openai.api_server \
--model $model \
--port 8200 \
--max-model-len 10000 \
--gpu-memory-utilization 0.6 \
--kv-transfer-config \
'{"kv_connector":"PyNcclConnector","kv_role":"kv_consumer","kv_rank":1,"kv_parallel_size":2,"kv_buffer_size":5e9}' &
wait_for_server 8100
wait_for_server 8200
python3 disagg_prefill_proxy_server.py &
sleep 1
}
benchmark() {
results_folder="./results"
model="meta-llama/Meta-Llama-3.1-8B-Instruct"
dataset_name="sonnet"
dataset_path="../sonnet_4x.txt"
num_prompts=100
qps=$1
prefix_len=50
input_len=1024
output_len=$2
tag=$3
python3 ../benchmark_serving.py \
--backend vllm \
--model $model \
--dataset-name $dataset_name \
--dataset-path $dataset_path \
--sonnet-input-len $input_len \
--sonnet-output-len "$output_len" \
--sonnet-prefix-len $prefix_len \
--num-prompts $num_prompts \
--port 8000 \
--save-result \
--result-dir $results_folder \
--result-filename "$tag"-qps-"$qps".json \
--request-rate "$qps"
sleep 2
}
main() {
(which wget && which curl) || (apt-get update && apt-get install -y wget curl)
(which jq) || (apt-get -y install jq)
(which socat) || (apt-get -y install socat)
pip install quart httpx matplotlib aiohttp
cd "$(dirname "$0")"
cd ..
  # create sonnet_4x.txt so that we can sample 2048 tokens for input
echo "" > sonnet_4x.txt
for _ in {1..4}
do
cat sonnet.txt >> sonnet_4x.txt
done
cd disagg_benchmarks
rm -rf results
mkdir results
default_output_len=6
export VLLM_HOST_IP=$(hostname -I | awk '{print $1}')
launch_chunked_prefill
for qps in 2 4 6 8; do
benchmark $qps $default_output_len chunked_prefill
done
kill_gpu_processes
launch_disagg_prefill
for qps in 2 4 6 8; do
benchmark $qps $default_output_len disagg_prefill
done
kill_gpu_processes
python3 visualize_benchmark_results.py
}
main "$@"

View File

@ -0,0 +1,61 @@
import os
import aiohttp
from quart import Quart, make_response, request
AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=6 * 60 * 60)
app = Quart(__name__)
async def forward_request(url, data):
async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
headers = {
"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"
}
async with session.post(url=url, json=data,
headers=headers) as response:
            if response.status == 200:
                # Always stream the response back in chunks (a Transfer-Encoding
                # check could be added here to fall back to a single read).
                async for chunk_bytes in response.content.iter_chunked(
                        1024):
                    yield chunk_bytes
@app.route('/v1/completions', methods=['POST'])
async def handle_request():
try:
original_request_data = await request.get_json()
prefill_request = original_request_data.copy()
# change max_tokens = 1 to let it only do prefill
prefill_request['max_tokens'] = 1
# finish prefill
async for _ in forward_request('http://localhost:8100/v1/completions',
prefill_request):
continue
# return decode
generator = forward_request('http://localhost:8200/v1/completions',
original_request_data)
response = await make_response(generator)
response.timeout = None
return response
except Exception as e:
import sys
import traceback
exc_info = sys.exc_info()
print("Error occurred in disagg prefill proxy server")
print(e)
print("".join(traceback.format_exception(*exc_info)))
if __name__ == '__main__':
app.run(port=8000)

View File

@ -0,0 +1,60 @@
import asyncio
import itertools
import aiohttp
from aiohttp import web
class RoundRobinProxy:
def __init__(self, target_ports):
self.target_ports = target_ports
self.port_cycle = itertools.cycle(self.target_ports)
async def handle_request(self, request):
target_port = next(self.port_cycle)
target_url = f"http://localhost:{target_port}{request.path_qs}"
async with aiohttp.ClientSession() as session:
try:
# Forward the request
async with session.request(
method=request.method,
url=target_url,
headers=request.headers,
data=request.content,
) as response:
# Start sending the response
resp = web.StreamResponse(status=response.status,
headers=response.headers)
await resp.prepare(request)
# Stream the response content
async for chunk in response.content.iter_any():
await resp.write(chunk)
await resp.write_eof()
return resp
except Exception as e:
return web.Response(text=f"Error: {str(e)}", status=500)
async def main():
proxy = RoundRobinProxy([8100, 8200])
app = web.Application()
app.router.add_route('*', '/{path:.*}', proxy.handle_request)
runner = web.AppRunner(app)
await runner.setup()
site = web.TCPSite(runner, 'localhost', 8000)
await site.start()
print("Proxy server started on http://localhost:8000")
# Keep the server running
await asyncio.Event().wait()
if __name__ == '__main__':
asyncio.run(main())

View File

@ -0,0 +1,46 @@
import json
import matplotlib.pyplot as plt
import pandas as pd
if __name__ == "__main__":
data = []
for name in ['disagg_prefill', 'chunked_prefill']:
for qps in [2, 4, 6, 8]:
with open(f"results/{name}-qps-{qps}.json") as f:
x = json.load(f)
x['name'] = name
x['qps'] = qps
data.append(x)
df = pd.DataFrame.from_dict(data)
dis_df = df[df['name'] == 'disagg_prefill']
chu_df = df[df['name'] == 'chunked_prefill']
plt.style.use('bmh')
plt.rcParams['font.size'] = 20
for key in [
'mean_ttft_ms', 'median_ttft_ms', 'p99_ttft_ms', 'mean_itl_ms',
'median_itl_ms', 'p99_itl_ms'
]:
fig, ax = plt.subplots(figsize=(11, 7))
plt.plot(dis_df['qps'],
dis_df[key],
label='disagg_prefill',
marker='o',
linewidth=4)
plt.plot(chu_df['qps'],
chu_df[key],
label='chunked_prefill',
marker='o',
linewidth=4)
ax.legend()
ax.set_xlabel('QPS')
ax.set_ylabel(key)
ax.set_ylim(bottom=0)
fig.savefig(f'results/{key}.png')
plt.close(fig)

View File

@ -0,0 +1,109 @@
#!/bin/bash
# This file demonstrates the example usage of disaggregated prefilling
# We will launch 2 vllm instances (1 for prefill and 1 for decode),
# and then transfer the KV cache between them.
echo "🚧🚧 Warning: The usage of disaggregated prefill is experimental and subject to change 🚧🚧"
sleep 1
# Trap the SIGINT signal (triggered by Ctrl+C)
trap 'cleanup' INT
# Cleanup function
cleanup() {
echo "Caught Ctrl+C, cleaning up..."
# Cleanup commands
pgrep python | xargs kill -9
pkill -f python
echo "Cleanup complete. Exiting."
exit 0
}
export VLLM_HOST_IP=$(hostname -I | awk '{print $1}')
# install quart first -- required for the disagg prefill proxy server
if python3 -c "import quart" &> /dev/null; then
echo "Quart is already installed."
else
echo "Quart is not installed. Installing..."
python3 -m pip install quart
fi
# a function that waits for the vLLM server to start
wait_for_server() {
local port=$1
timeout 1200 bash -c "
until curl -s localhost:${port}/v1/completions > /dev/null; do
sleep 1
done" && return 0 || return 1
}
# You can also adjust --kv-ip and --kv-port for distributed inference.
# prefilling instance, which is the KV producer
CUDA_VISIBLE_DEVICES=0 vllm serve meta-llama/Meta-Llama-3.1-8B-Instruct \
--port 8100 \
--max-model-len 100 \
--gpu-memory-utilization 0.8 \
--kv-transfer-config \
'{"kv_connector":"PyNcclConnector","kv_role":"kv_producer","kv_rank":0,"kv_parallel_size":2}' &
# decoding instance, which is the KV consumer
CUDA_VISIBLE_DEVICES=1 vllm serve meta-llama/Meta-Llama-3.1-8B-Instruct \
--port 8200 \
--max-model-len 100 \
--gpu-memory-utilization 0.8 \
--kv-transfer-config \
'{"kv_connector":"PyNcclConnector","kv_role":"kv_consumer","kv_rank":1,"kv_parallel_size":2}' &
# wait until prefill and decode instances are ready
wait_for_server 8100
wait_for_server 8200
# launch a proxy server that opens the service at port 8000
# the workflow of this proxy:
# - send the request to prefill vLLM instance (port 8100), change max_tokens
# to 1
# - after the prefill vLLM finishes prefill, send the request to decode vLLM
# instance
# NOTE: the usage of this API is subject to change --- in the future we will
# introduce "vllm connect" to connect between prefill and decode instances
python3 ../benchmarks/disagg_benchmarks/disagg_prefill_proxy_server.py &
sleep 1
# serve two example requests
output1=$(curl -X POST -s http://localhost:8000/v1/completions \
-H "Content-Type: application/json" \
-d '{
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
"prompt": "San Francisco is a",
"max_tokens": 10,
"temperature": 0
}')
output2=$(curl -X POST -s http://localhost:8000/v1/completions \
-H "Content-Type: application/json" \
-d '{
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
"prompt": "Santa Clara is a",
"max_tokens": 10,
"temperature": 0
}')
# Cleanup commands
pgrep python | xargs kill -9
pkill -f python
echo ""
sleep 1
# Print the outputs of the curl requests
echo ""
echo "Output of first request: $output1"
echo "Output of second request: $output2"
echo "🎉🎉 Successfully finished 2 test requests! 🎉🎉"
echo ""

View File

@ -0,0 +1,119 @@
import os
import subprocess
import sys
import time
from subprocess import Popen
import pytest
import requests
import torch
# Fixture to set up environment variables and teardown servers after tests
@pytest.fixture(scope="module", autouse=True)
def setup_servers():
if torch.cuda.device_count() < 4:
pytest.skip("Skipping test: fewer than 4 GPUs available")
# Set up environment variables
VLLM_HOST_IP = subprocess.check_output("hostname -I | awk '{print $1}'",
shell=True).decode().strip()
os.environ["VLLM_HOST_IP"] = VLLM_HOST_IP
# Start prefill instance
prefill_cmd = [
sys.executable,
"-m",
"vllm.entrypoints.openai.api_server",
"--model",
"meta-llama/Meta-Llama-3.1-8B-Instruct",
"--port",
"8100",
"--gpu-memory-utilization",
"0.5",
"--max-model-len",
"1000",
"--kv-transfer-config",
'{"kv_connector":"PyNcclConnector","kv_role":"kv_producer",'\
'"kv_rank":0,"kv_parallel_size":2}',
]
prefill_env = os.environ.copy()
prefill_env["CUDA_VISIBLE_DEVICES"] = "0"
prefill_proc = Popen(prefill_cmd, env=prefill_env)
# Start decode instance
decode_cmd = [
sys.executable,
"-m",
"vllm.entrypoints.openai.api_server",
"--model",
"meta-llama/Meta-Llama-3.1-8B-Instruct",
"--port",
"8200",
"--gpu-memory-utilization",
"0.5",
"--max-model-len",
"1000",
"--kv-transfer-config",
'{"kv_connector":"PyNcclConnector","kv_role":"kv_consumer",'\
'"kv_rank":1,"kv_parallel_size":2}',
]
decode_env = os.environ.copy()
decode_env["CUDA_VISIBLE_DEVICES"] = "1"
decode_proc = Popen(decode_cmd, env=decode_env)
# Wait for servers to be ready
assert wait_for_server(8100), "Prefill server did not start in time"
assert wait_for_server(8200), "Decode server did not start in time"
# Yield to the test function and handle teardown after tests
yield
# Cleanup: kill the processes
prefill_proc.terminate()
decode_proc.terminate()
# Additional cleanup if needed
prefill_proc.wait()
decode_proc.wait()
# Helper function to wait for server
def wait_for_server(port, timeout=240):
start_time = time.time()
while time.time() - start_time < timeout:
try:
response = requests.get(f"http://localhost:{port}/v1/completions")
if response.status_code in [200, 405]:
return True
except requests.ConnectionError:
time.sleep(1)
return False
# Test function to send curl requests and validate responses
@pytest.mark.parametrize("prompt", ["San Francisco is a", "Santa Clara is a"])
def test_disaggregated_prefilling(prompt):
# Send to prefill
response = requests.post("http://localhost:8100/v1/completions",
headers={"Content-Type": "application/json"},
json={
"model":
"meta-llama/Meta-Llama-3.1-8B-Instruct",
"prompt": prompt,
"max_tokens": 1,
"temperature": 0
})
assert response.status_code == 200
# Send to decode
response = requests.post("http://localhost:8200/v1/completions",
headers={"Content-Type": "application/json"},
json={
"model":
"meta-llama/Meta-Llama-3.1-8B-Instruct",
"prompt": prompt,
"max_tokens": 10,
"temperature": 0
})
assert response.status_code == 200

View File

@ -0,0 +1,64 @@
import subprocess
import sys
import pytest
import torch
def run_python_script(script_name, timeout):
script_name = f'kv_transfer/{script_name}'
try:
# Start both processes asynchronously using Popen
process0 = subprocess.Popen(
[sys.executable, script_name],
env={"RANK":
"0"}, # Set the RANK environment variable for process 0
stdout=sys.stdout, # Pipe stdout to current stdout
stderr=sys.stderr, # Pipe stderr to current stderr
)
process1 = subprocess.Popen(
[sys.executable, script_name],
env={"RANK":
"1"}, # Set the RANK environment variable for process 1
stdout=sys.stdout, # Pipe stdout to current stdout
stderr=sys.stderr, # Pipe stderr to current stderr
)
# Wait for both processes to complete, with a timeout
process0.wait(timeout=timeout)
process1.wait(timeout=timeout)
# Check the return status of both processes
if process0.returncode != 0:
pytest.fail(
f"Test {script_name} failed for RANK=0, {process0.returncode}")
if process1.returncode != 0:
pytest.fail(
f"Test {script_name} failed for RANK=1, {process1.returncode}")
except subprocess.TimeoutExpired:
# If either process times out, terminate both and fail the test
process0.terminate()
process1.terminate()
pytest.fail(f"Test {script_name} timed out")
except Exception as e:
pytest.fail(f"Test {script_name} failed with error: {str(e)}")
# Define the test cases using pytest's parametrize
@pytest.mark.parametrize(
"script_name,timeout",
[
("test_lookup_buffer.py",
60), # Second test case with a 60-second timeout
("test_send_recv.py", 120) # First test case with a 120-second timeout
])
def test_run_python_script(script_name, timeout):
# Check the number of GPUs
if torch.cuda.device_count() < 2:
pytest.skip(
f"Skipping test {script_name} because <2 GPUs are available")
# Run the test if there are at least 2 GPUs
run_python_script(script_name, timeout)

View File

@ -0,0 +1,160 @@
import os
import random
import torch
from tqdm import tqdm
from vllm.config import KVTransferConfig
from vllm.distributed.kv_transfer.kv_lookup_buffer.simple_buffer import (
SimpleBuffer)
from vllm.distributed.kv_transfer.kv_pipe.pynccl_pipe import PyNcclPipe
# TODO: the test depends on a lot of fields in the current implementation.
# We should have a standard interface instead of direct field access.
def test_run(my_rank, buffer, device):
# buffer should be empty in the beginning
if my_rank == 0:
assert buffer.buffer_size == 0
assert len(buffer.buffer) == 0
print("My rank: %d, device: %s" % (my_rank, device))
# insert
tokens = torch.tensor([1, 2, 3]).to(device)
roi = (tokens > 0)
if my_rank == 0:
key = 2.0 * torch.ones([5, 6]).to(device)
value = 3.0 * torch.ones([5, 6]).to(device)
placeholder = torch.tensor([1]).to(device)
buffer.insert(tokens, roi, key, value, placeholder)
torch.distributed.barrier()
# drop_select
if my_rank == 1:
tok, roi_, key, value, hidden = buffer.drop_select(tokens, roi)
assert torch.allclose(tokens, tok)
assert torch.allclose(roi, roi_)
assert torch.allclose(key, 2.0 * torch.ones([5, 6], device=device))
assert torch.allclose(value, 3.0 * torch.ones([5, 6], device=device))
torch.distributed.barrier()
if my_rank == 0:
assert buffer.buffer_size == 0
assert len(buffer.buffer) == 0
print("Test run passed!")
def stress_test(my_rank, buf, device):
torch.distributed.barrier()
torch.manual_seed(100)
reqs = [
(
torch.rand(100).to(device), # tokens
torch.ones(100).bool().to(device), # roi
torch.rand(100).to(device), # key
torch.rand(100).to(device), # value
torch.rand(100).to(device), # hidden
) for i in tqdm(range(200))
]
random.seed(my_rank)
random.shuffle(reqs)
torch.distributed.barrier()
n = 0
    # the buffer can only hold 100 reqs,
    # so the sender will occasionally block to wait for the receiver.
for req in tqdm(reqs):
if my_rank == 0:
buf.insert(*req)
else:
tok, roi, k, v, h = req
tok_, roi_, k_, v_, h_ = buf.drop_select(tok, roi)
if tok_ is None:
assert roi_ is None
assert k_ is None
assert v_ is None
assert h_ is None
n += 1
else:
assert torch.allclose(tok, tok_)
assert torch.allclose(roi, roi_)
assert torch.allclose(k, k_)
assert torch.allclose(v, v_)
assert torch.allclose(h, h_)
print('Rank %d done' % my_rank)
torch.distributed.barrier()
if my_rank == 0:
x = torch.tensor([0])
torch.distributed.recv(x, 1)
        # the number of Nones received equals the number of entries that were
        # never drop-selected and thus remain in the buffer
        assert x.item() == len(buf.buffer)
        # each buffered request occupies 1700 bytes (see the tensor sizes above)
        print(buf.buffer_size)
        assert buf.buffer_size == 1700 * len(buf.buffer)
else:
torch.distributed.send(torch.tensor([n]), 0)
print("Passed stress test!")
if __name__ == "__main__":
my_rank = int(os.environ['RANK'])
torch.distributed.init_process_group(
backend='gloo',
init_method='tcp://localhost:12398',
world_size=2,
rank=my_rank,
)
print("initialized! My rank is %d" % my_rank)
config = KVTransferConfig(
kv_connector='PyNcclConnector',
kv_buffer_device='cuda',
kv_buffer_size=1e9,
kv_rank=my_rank,
kv_role="kv_both", # this arg doesn't matter in this test
kv_parallel_size=2,
kv_ip="127.0.0.1",
kv_port=12345,
)
data_pipe = PyNcclPipe(
local_rank=my_rank,
config=config,
device="cuda",
port_offset=0,
)
cpu_pipe = PyNcclPipe(
local_rank=my_rank,
config=config,
device="cpu",
port_offset=1,
)
buffer = SimpleBuffer(cpu_pipe, data_pipe, 170000)
test_run(my_rank, buffer, data_pipe.device)
stress_test(my_rank, buffer, data_pipe.device)
buffer.close()
data_pipe.close()
cpu_pipe.close()
print('Done')

View File

@ -0,0 +1,3 @@
#!/bin/bash
RANK=0 python test_lookup_buffer.py &
RANK=1 python test_lookup_buffer.py &

View File

@ -0,0 +1,155 @@
import os
import time
from typing import List
import torch
from tqdm import tqdm
from vllm.config import KVTransferConfig
from vllm.distributed.kv_transfer.kv_pipe.pynccl_pipe import PyNcclPipe
def test_run(my_rank, pipe):
# test run
x = torch.tensor([1]).to(pipe.device)
y = torch.tensor([[2., 3., 4., 8.]]).to(pipe.device)
if my_rank == 0:
pipe.send_tensor(x)
print("sent tensor x")
pipe.send_tensor(y)
print("sent tensor y")
x2 = pipe.recv_tensor()
print("received x2 = ", x2)
y2 = pipe.recv_tensor()
print("received y2 = ", x2)
else:
x2 = pipe.recv_tensor()
print("received x2 = ", x2)
y2 = pipe.recv_tensor()
print("received y2 = ", x2)
pipe.send_tensor(x)
print("sent tensor x")
pipe.send_tensor(y)
print("sent tensor y")
assert torch.allclose(x, x2)
assert torch.allclose(y, y2)
def stress_test(my_rank, pipe):
torch.distributed.barrier()
tensors: List[torch.Tensor] = []
torch.manual_seed(0)
for i in tqdm(range(500)):
mean = torch.rand(1).item() * 100
std = torch.rand(1).item() * 100
size = torch.randint(900, 1000, (2, ))
x = torch.normal(mean * 1.0, std * 1.0,
size=size.tolist()).to(pipe.device)
# 5% probability of sending a None
if torch.rand(1).item() < 0.05:
tensors.append(None)
tensors.append(None)
tensors.append(None)
else:
tensors.append(x)
tensors.append(x.mean().unsqueeze(0))
tensors.append(x.std().unsqueeze(0))
torch.distributed.barrier()
for i in tqdm(range(500)):
if my_rank == int((i % 10) > 3):
pipe.send_tensor(tensors[3 * i])
pipe.send_tensor(tensors[3 * i + 1])
pipe.send_tensor(tensors[3 * i + 2])
else:
x = pipe.recv_tensor()
mean = pipe.recv_tensor()
std = pipe.recv_tensor()
if x is None:
assert mean is None
assert std is None
else:
assert torch.allclose(x, tensors[3 * i])
assert x.mean() == mean[0]
assert x.std() == std[0]
torch.distributed.barrier()
def latency_test(my_rank, pipe, nelement, ntensor):
latencies = []
torch.distributed.barrier()
for i in tqdm(range(500)):
tensors = []
if my_rank == 0:
# create tensor
tensors = [
torch.rand(nelement).to(pipe.device) for _ in range(ntensor)
]
torch.distributed.barrier()
if my_rank == 0:
t = torch.tensor([time.time()],
dtype=torch.float64).to(pipe.device)
for tensor in tensors:
pipe.send_tensor(tensor)
pipe.send_tensor(t)
else:
for _ in range(ntensor):
pipe.recv_tensor()
t = pipe.recv_tensor()
latencies.append(time.time() - t.item())
torch.distributed.barrier()
print('Latency test passed.')
print('Latency:', torch.tensor(latencies).mean().item() * 1000, 'ms')
if __name__ == "__main__":
my_rank = int(os.environ['RANK'])
torch.distributed.init_process_group(
backend='gloo',
init_method='tcp://localhost:12398',
world_size=2,
rank=my_rank,
)
config = KVTransferConfig(
kv_connector='PyNcclConnector',
kv_buffer_device='cuda',
kv_buffer_size=1e9,
kv_rank=my_rank,
kv_role="kv_both", # this arg doesn't matter in this test
kv_parallel_size=2,
kv_ip="127.0.0.1",
kv_port=12345,
)
pipe = PyNcclPipe(
local_rank=my_rank,
config=config,
)
test_run(my_rank, pipe)
stress_test(my_rank, pipe)
# Use this function if you want to test the latency of pipe impl.
# latency_test(my_rank, pipe, 1024 * 8 * 128, 80)

View File

@ -0,0 +1,3 @@
#!/bin/bash
RANK=0 python3 test_send_recv.py &
RANK=1 python3 test_send_recv.py &

View File

@ -2052,6 +2052,88 @@ class ObservabilityConfig:
f"installed. Original error:\n{otel_import_error_traceback}")
class KVTransferConfig(BaseModel):
"""Configuration for distributed KV cache transfer."""
# The KV connector for vLLM to transmit KV caches between vLLM instances.
kv_connector: Optional[str] = None
    # The device used by the kv connector to buffer the KV cache.
    # Currently only 'cuda' is supported.
    kv_buffer_device: Optional[str] = "cuda"
    # The buffer size for the KV cache transfer, measured in number of
    # bytes. Recommended value: 1e9 (about 1GB).
    kv_buffer_size: float = 1e9
    # Whether this vLLM instance produces, consumes KV cache, or both. Choices
    # are 'kv_producer', 'kv_consumer', and 'kv_both'.
kv_role: Optional[str] = None
# The rank of this vLLM instance in the KV cache transfer. Typical value:
# 0 for prefill instance, 1 for decode instance.
# Currently only 1P1D is supported.
kv_rank: Optional[int] = None
# The number of parallel instances for KV cache transfer. For
# PyNcclConnector, this should be 2.
kv_parallel_size: int = 1
    # The IP address used by the KV connector to build the distributed connection
    kv_ip: str = "127.0.0.1"
    # The port used by the KV connector to build the distributed connection
kv_port: int = 14579
@classmethod
def from_cli(cls, cli_value: str) -> "KVTransferConfig":
"""Parse the CLI value for the compilation config."""
return KVTransferConfig.model_validate_json(cli_value)
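    # Illustrative usage (a sketch, not part of the public API contract):
    # the --kv-transfer-config CLI flag carries a JSON string such as the one
    # used by the benchmark scripts in this PR, which from_cli() validates:
    #
    #   cfg = KVTransferConfig.from_cli(
    #       '{"kv_connector":"PyNcclConnector","kv_role":"kv_producer",'
    #       '"kv_rank":0,"kv_parallel_size":2,"kv_buffer_size":5e9}')
    #   assert cfg.is_kv_producer and not cfg.is_kv_consumer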
def model_post_init(self, __context: Any) -> None:
if all([
self.kv_connector is not None,
self.kv_connector != "PyNcclConnector"
]):
raise ValueError(f"Unsupported kv_connector: {self.kv_connector}. "
f"Supported connectors are "
f"`PyNcclConnector`.")
if self.kv_role is not None and self.kv_role not in [
"kv_producer", "kv_consumer", "kv_both"
]:
raise ValueError(
f"Unsupported kv_role: {self.kv_role}. "
f"Supported roles are `kv_producer`, `kv_consumer`, "
f"and `kv_both`")
if self.kv_connector is not None and self.kv_role is None:
raise ValueError("Please specify kv_disagg_role when kv_connector "
"is set, supported roles are `kv_producer`, "
"`kv_consumer`, and `kv_both`")
@property
def is_kv_transfer_instance(self) -> bool:
return self.kv_connector is not None and \
self.kv_role in ["kv_producer", "kv_consumer", "kv_both"]
@property
def need_kv_parallel_group(self) -> bool:
        # for database-based connectors, vLLM does not need to create a
        # parallel group; in that case kv_parallel_size will be 1.
return self.kv_connector is not None and self.kv_parallel_size > 1
@property
def is_kv_producer(self) -> bool:
return self.kv_connector is not None and \
self.kv_role in ["kv_producer", "kv_both"]
@property
def is_kv_consumer(self) -> bool:
return self.kv_connector is not None and \
self.kv_role in ["kv_consumer", "kv_both"]
class CompilationLevel:
# constants for the levels of the compilation process
NO_COMPILATION = 0
@ -2317,6 +2399,8 @@ class VllmConfig:
quant_config: Optional[QuantizationConfig] = None
compilation_config: CompilationConfig = field(default=None,
init=True) # type: ignore
kv_transfer_config: KVTransferConfig = field(default=None,
init=True) # type: ignore
@staticmethod
def _get_quantization_config(

View File

@ -0,0 +1,30 @@
# Distributed KV cache transfer
This folder implements distributed KV cache transfer across vLLM instances.
Currently the main use case is disaggregated prefilling.
## Abstractions
The KV cache transfer is built on three layers of abstraction:
- KV pipe: a FIFO pipe for torch.tensor transmission. Key APIs: `send_tensor` and `recv_tensor`.
- KV lookup buffer: a lookup buffer for KV caches. Key: the tokens, value: the KV caches (and/or hidden states). Key APIs: `insert` and `drop_select` (similar to SQL semantics).
- KV connector: a connector that connects the KV pipe and KV lookup buffer to vLLM. Key APIs: `send_kv_caches_and_hidden_states` and `recv_kv_caches_and_hidden_states`.
Why we need a KV lookup buffer: a FIFO pipe by itself is not enough, because the prefill vLLM worker may process requests in a different order than the decode vLLM worker. For example, when QPS is high the prefill worker may handle requests in the order A -> B -> C, while the decode worker asks for request C first. A plain FIFO pipe cannot handle this reordering, so we provide the KV lookup buffer to turn a FIFO pipe into a lookup buffer.
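To make these semantics concrete, below is a minimal sketch of how the two sides use a lookup buffer (the concrete implementation in this PR is `SimpleBuffer`; variable names are illustrative):

```python
# Prefill side (KV producer): publish the KV cache it just computed.
# `roi` is a binary mask selecting which of `tokens` the KV belongs to.
buffer.insert(tokens, roi, key, value, hidden)

# Decode side (KV consumer): query by token content, in any order.
tok, roi, key, value, hidden = buffer.drop_select(tokens, roi)
if key is None:
    # No matching entry was found; fall back to computing the prefill locally.
    ...
```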
NOTE: the KV pipe layer is bypassable: you can skip this layer if your distributed
communication service already supports key-value-based lookup (like Redis or an
RDMA-backed database).
NOTE: if you want to not only transfer KV caches but also adjust vLLM's model execution flow (for example, let vLLM receive KV caches for some tokens and run prefill on the remaining tokens), you can bypass both the KV pipe layer and the KV lookup buffer layer and implement the KV connector layer directly. Bear in mind that because vLLM's model input is constantly changing, such an implementation will likely break when vLLM is updated.
## Disaggregated prefilling
The example usage is in [this file](../../../examples/disaggregated_prefill.sh).
Here is a diagram of how we run disaggregated prefilling.
![Disaggregated prefill workflow](./disagg_prefill_workflow.jpg)

View File

Binary file not shown (new image added, 139 KiB).

View File

@ -0,0 +1,122 @@
"""
KVConnectorBase Class for Distributed KV Cache & Hidden State communication
The class provides two primary abstract methods:
1. send_kv_caches_and_hidden_states(): Send KV caches and hidden states
2. recv_kv_caches_and_hidden_states(): Recv KV caches and hidden states
"""
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, List, Tuple, Union
import torch
from vllm.sequence import IntermediateTensors
if TYPE_CHECKING:
from vllm.config import VllmConfig
from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata
class KVConnectorBase(ABC):
"""
Abstract base class for a KV connector.
The class provides two primary abstract methods:
1. send_kv_caches_and_hidden_states(): Send KV caches and hidden states
2. recv_kv_caches_and_hidden_states(): Recv KV caches and hidden states
"""
@abstractmethod
def __init__(
self,
rank: int,
local_rank: int,
config: "VllmConfig",
):
raise NotImplementedError
@abstractmethod
def close(self) -> None:
"""Close the buffer and release resources.
This method is responsible for cleaning up resources related to the
connector when it is no longer needed.
Raises:
NotImplementedError: This method must be implemented in subclasses.
"""
raise NotImplementedError
@abstractmethod
def send_kv_caches_and_hidden_states(
self,
model_executable: torch.nn.Module,
model_input: "ModelInputForGPUWithSamplingMetadata",
kv_caches: List[torch.Tensor],
hidden_or_intermediate_states: Union[torch.Tensor,
IntermediateTensors],
) -> None:
"""
Send KV caches and hidden states to the connector.
This method processes the input tokens, KV caches, and
hidden/intermediate states for a given model and sends the data to the
decode instance.
Args:
model_executable (torch.nn.Module): The model executable containing
start and end layer information.
model_input (ModelInputForGPUWithSamplingMetadata): The input
metadata from vLLM.
kv_caches (List[torch.Tensor]): List of KV caches (keys and values)
for each layer.
hidden_or_intermediate_states (Union[torch.Tensor,
IntermediateTensors]):
The hidden or intermediate states associated with the tokens.
Returns:
None
"""
raise NotImplementedError
@abstractmethod
def recv_kv_caches_and_hidden_states(
self, model_executable: torch.nn.Module,
model_input: "ModelInputForGPUWithSamplingMetadata",
kv_caches: List[torch.Tensor]
) -> Tuple[Union[torch.Tensor, IntermediateTensors], bool,
"ModelInputForGPUWithSamplingMetadata"]:
"""
Receive KV caches and hidden states from the connector.
This method attempts to retrieve KV caches and hidden states for input
tokens. If all required KV caches and hidden states are received, it
will bypass model input, else it will fall back to normal vLLM model
forwarding.
Args:
model_executable (torch.nn.Module):
                The model executable from the vLLM model runner.
            model_input (ModelInputForGPUWithSamplingMetadata):
                The model input from the vLLM model runner.
kv_caches (List[torch.Tensor]):
List of KV caches for each layer.
Returns:
- hidden_or_intermediate_states (torch.Tensor or
IntermediateTensors):
Concatenated hidden states if all required data is retrieved,
otherwise `None`.
- bypass_model_exec (bool):
Indicates whether the model execution can be skipped (True) or
needs to be redone (False).
- model_input (ModelInputForGPUWithSamplingMetadata):
Optionally adjusted input metadata for re-execution when
`bypass_model_exec=False`.
"""
raise NotImplementedError

View File

@ -0,0 +1,19 @@
from typing import TYPE_CHECKING
from .base import KVConnectorBase
if TYPE_CHECKING:
from vllm.config import VllmConfig
class KVConnectorFactory:
@staticmethod
def create_connector(rank: int, local_rank: int,
config: "VllmConfig") -> KVConnectorBase:
if config.kv_transfer_config.kv_connector == 'PyNcclConnector':
from .simple_connector import SimpleConnector
return SimpleConnector(rank, local_rank, config)
else:
raise ValueError(f"Unsupported connector type: "
f"{config.kv_connector}")

View File

@ -0,0 +1,261 @@
"""
Simple KV Cache Connector for Distributed Machine Learning Inference
The SimpleConnector transfers KV caches between prefill vLLM worker (KV cache
producer) and decode vLLM worker (KV cache consumer) using PyNcclPipe.
But the logic can be extended to support other pipe and lookup buffer.
"""
from typing import TYPE_CHECKING, List, Optional, Tuple, Union
import torch
from vllm import _custom_ops as ops
from vllm.config import VllmConfig
from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBase
from vllm.distributed.kv_transfer.kv_lookup_buffer.simple_buffer import (
SimpleBuffer)
from vllm.distributed.kv_transfer.kv_pipe.pynccl_pipe import PyNcclPipe
from vllm.logger import init_logger
from vllm.sequence import IntermediateTensors
if TYPE_CHECKING:
from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata
logger = init_logger(__name__)
class SimpleConnector(KVConnectorBase):
def __init__(
self,
rank: int,
local_rank: int,
config: VllmConfig,
):
self.config = config.kv_transfer_config
logger.info("Initializing PyNcclConfig under kv_transfer_config %s",
self.config)
self.lookup_buffer_size = self.config.kv_buffer_size
self.producer_buffer: Optional[SimpleBuffer] = None
self.consumer_buffer: Optional[SimpleBuffer] = None
# 2 pipes for every rank in the world
port_offset_base = 2 * rank
# In disaggregated prefill, the prefill vLLM only uses send pipe
# and the decode vLLM only uses recv pipe
if self.config.is_kv_producer:
self.producer_data_pipe = PyNcclPipe(
local_rank=local_rank,
config=self.config,
port_offset=port_offset_base,
)
self.producer_signal_pipe = PyNcclPipe(
local_rank=local_rank,
config=self.config,
port_offset=port_offset_base + 1,
device="cpu",
)
self.producer_buffer = SimpleBuffer(self.producer_signal_pipe,
self.producer_data_pipe,
self.config.kv_buffer_size)
else:
# the current vLLM instance is KV consumer, so it needs to connect
            # its recv pipe to the send pipe of the KV producer
self.consumer_data_pipe = PyNcclPipe(
local_rank=local_rank,
config=self.config,
port_offset=port_offset_base,
)
self.consumer_signal_pipe = PyNcclPipe(
local_rank=local_rank,
config=self.config,
port_offset=port_offset_base + 1,
device="cpu",
)
self.consumer_buffer = SimpleBuffer(
self.consumer_signal_pipe,
self.consumer_data_pipe,
self.config.kv_buffer_size,
)
def select(self, input_tokens: Optional[torch.Tensor],
roi: Optional[torch.Tensor]) -> List[Optional[torch.Tensor]]:
assert self.consumer_buffer is not None, "Please initialize the "\
"consumer buffer before calling select."
return self.consumer_buffer.drop_select(input_tokens, roi)
def insert(self, input_tokens: torch.Tensor, roi: torch.Tensor,
key: torch.Tensor, value: torch.Tensor,
hidden: torch.Tensor) -> None:
assert self.producer_buffer is not None, "Please initialize the "\
"producer buffer before calling insert."
self.producer_buffer.insert(input_tokens, roi, key, value, hidden)
def send_kv_caches_and_hidden_states(
self,
model_executable: torch.nn.Module,
model_input: "ModelInputForGPUWithSamplingMetadata",
kv_caches: List[torch.Tensor],
hidden_or_intermediate_states: Union[torch.Tensor,
IntermediateTensors],
) -> None:
input_tokens_tensor = model_input.input_tokens
seq_lens = model_input.attn_metadata.seq_lens
slot_mapping_flat = model_input.attn_metadata.slot_mapping.flatten()
start_layer = model_executable.model.start_layer
end_layer = model_executable.model.end_layer
        # query_lens contains the tokens whose KV caches were newly added to
        # vLLM, so we send them to the decode instance.
        # FIXME(Kuntai): This assumes that all requests are prefill.
for idx, slen in enumerate(seq_lens):
start_pos = sum(seq_lens[:idx])
end_pos = start_pos + slen
current_tokens = input_tokens_tensor[start_pos:end_pos]
keys, values = [], []
for layer_id in range(start_layer, end_layer):
kv_cache = kv_caches[layer_id - start_layer]
_, _, num_heads, head_size = kv_cache[0].shape
key_cache = kv_cache[0].reshape(-1, num_heads, head_size)
value_cache = kv_cache[1].reshape(-1, num_heads, head_size)
current_slot_mapping = slot_mapping_flat[start_pos:end_pos]
keys.append(key_cache[current_slot_mapping].unsqueeze(0))
values.append(value_cache[current_slot_mapping].unsqueeze(0))
keys = torch.cat(keys, dim=0)
values = torch.cat(values, dim=0)
self.insert(current_tokens,
torch.ones_like(current_tokens,
dtype=bool), keys, values,
hidden_or_intermediate_states[start_pos:end_pos])
logger.debug("[rank%d]: KV send DONE.", torch.distributed.get_rank())
def recv_kv_caches_and_hidden_states(
self, model_executable: torch.nn.Module,
model_input: "ModelInputForGPUWithSamplingMetadata",
kv_caches: List[torch.Tensor]
) -> Tuple[Union[torch.Tensor, IntermediateTensors], bool,
"ModelInputForGPUWithSamplingMetadata"]:
# When bypass_model_exec is set to False, it means that at least for one
# request its corresponding KV cache or hidden state is missing.
# In this case we need to do prefilling to recompute missing KV cache
# and hidden states.
bypass_model_exec = True
input_tokens_tensor = model_input.input_tokens
seq_lens = model_input.attn_metadata.seq_lens
slot_mapping = model_input.attn_metadata.slot_mapping.flatten()
hidden_or_intermediate_states_for_one_req = []
input_tokens_list = []
num_computed_tokens_list = []
start_pos_list = []
# enumerate different requests
# FIXME(Kuntai): This impl assumes that all requests are prefill.
for idx, slen in enumerate(seq_lens):
start_pos = sum(seq_lens[:idx])
end_pos = start_pos + slen
current_tokens = input_tokens_tensor[start_pos:end_pos]
num_tokens = slen
# collecting data for rebuilding the input
input_tokens_list.append(current_tokens)
start_pos_list.append(start_pos)
ret = self.select(current_tokens,
torch.ones_like(current_tokens, dtype=bool))
if ret[0] is None:
# didn't find any match.
bypass_model_exec = False
num_computed_tokens_list.append(0)
continue
roi: torch.Tensor = ret[1]
keys: torch.Tensor = ret[2]
values: torch.Tensor = ret[3]
hidden: torch.Tensor = ret[4]
num_computed_tokens = roi.shape[0]
num_computed_tokens_list.append(num_computed_tokens)
# check if both KV cache and the hidden states are received
# If not, need to redo the forwarding to compute missing states
if not all([(num_computed_tokens == num_tokens), hidden is not None
]):
bypass_model_exec = False
# update the end position based on how many tokens are cached.
end_pos = start_pos + num_computed_tokens
# put received KV caches into paged memory
for i in range(model_executable.model.start_layer,
model_executable.model.end_layer):
kv_cache = kv_caches[i - model_executable.model.start_layer]
layer = model_executable.model.layers[i]
key_cache, value_cache = kv_cache[0], kv_cache[1]
ops.reshape_and_cache_flash(
keys[i - model_executable.model.start_layer].to(
key_cache.device),
values[i - model_executable.model.start_layer].to(
value_cache.device),
key_cache,
value_cache,
slot_mapping[start_pos:end_pos],
layer.self_attn.attn.kv_cache_dtype,
layer.self_attn.attn._k_scale,
layer.self_attn.attn._v_scale,
)
hidden_or_intermediate_states_for_one_req.append(hidden)
if not bypass_model_exec:
# Some of the KV cache is not retrieved
# Here we will fall back to normal model forwarding
# But optionally you can adjust model_input so that you only do
# prefilling on those tokens that are missing KV caches.
logger.debug(
"[rank%d]: Failed to receive all KVs and hidden "
"states, redo model forwarding.", torch.distributed.get_rank())
hidden_or_intermediate_states = None
else:
logger.debug(
"[rank%d]: Successfully received all KVs and hidden "
"states, skip model forwarding.", torch.distributed.get_rank())
hidden_or_intermediate_states = torch.cat(
hidden_or_intermediate_states_for_one_req, dim=0)
return hidden_or_intermediate_states, bypass_model_exec, model_input
    def close(self):
        # only the pipes created in __init__ exist on this instance
        # (producer pipes on the prefill side, consumer pipes on the decode side)
        if self.config.is_kv_producer:
            self.producer_data_pipe.close()
            self.producer_signal_pipe.close()
        else:
            self.consumer_data_pipe.close()
            self.consumer_signal_pipe.close()

View File

@ -0,0 +1,108 @@
"""
This file contains a new class `KVLookupBufferBase` that allows developers to
think of KV cache operations as inserting new KV cache entries (`insert`)
into the lookup buffer and querying existing KV caches (`drop_select`)
from the lookup buffer.
All distributed communications are abstracted behind this class.
"""
from abc import ABC, abstractmethod
from typing import List, Optional
import torch
class KVLookupBufferBase(ABC):
"""
Abstract base class for a lookup buffer.
This class provides an abstraction for a key-value (KV) cache lookup buffer.
The key of the lookup buffer:
- input_tokens: token IDs of the request
- roi: a binary mask on top of input_tokens.
- Purpose of roi: Since KV cache may only be available for a subset of
tokens in the input (for example, when vLLM is connected to an external
KV cache service), roi specifies the subset of tokens that the KV cache
is associated with.
- NOTE: roi can be further extended to describe which part of KV the
current process is holding (each process may only hold a part of KV
due to TP and PP). This is not implemented for now.
The value of the lookup buffer:
- key: the key tensor in the KV cache
- value: the value tensor in the KV cache
- hidden: the final hidden state generated by model forwarding. This allows
vLLM to bypass further model forwarding by transmitting the hidden state.
"""
@abstractmethod
def insert(self, input_tokens: torch.Tensor, roi: torch.Tensor,
key: torch.Tensor, value: torch.Tensor,
hidden: torch.Tensor) -> None:
"""Insert into the lookup buffer.
The functionality is similar to the following python statement
```
buffer[input_tokens, roi] = [key, value, hidden]
```
FIXME: in the future, we should only have two arguments, key and value,
where key is a tensor dict and value is a tensor dict.
FIXME: we should transmit both sampler outputs and the hidden states.
Args:
input_tokens (torch.Tensor): token IDs.
roi (torch.Tensor): A binary mask on top of the input tokens
key (torch.Tensor): The key tensor in the KV cache.
value (torch.Tensor): The value tensor in the KV cache.
hidden (torch.Tensor): The final hidden state tensor generated
during model forwarding to bypass model
forwarding.
Raises:
NotImplementedError: This method must be implemented in subclasses.
"""
raise NotImplementedError
@abstractmethod
def drop_select(
self, input_tokens: Optional[torch.Tensor],
roi: Optional[torch.Tensor]) -> List[Optional[torch.Tensor]]:
"""Select and *drop* KV cache entries from the lookup buffer.
The functionality is similar to the following python statements
```
ret = buffer.pop(input_tokens, roi)
return ret
```
If `input_tokens` and `roi` is `None`, it means selecting any of the
KV caches in the buffer, return, and remove it from the buffer, useful
when offloading KV cache to KV cache storage service.
Args:
input_tokens (torch.Tensor): token IDs.
roi (torch.Tensor): A binary mask on top of the input tokens
Returns:
List[Optional[torch.Tensor]]: A list of tensors. Can be None.
Raises:
NotImplementedError: This method must be implemented in subclasses.
"""
raise NotImplementedError
@abstractmethod
def close(self) -> None:
"""Close the buffer and release resources.
This method is responsible for cleaning up resources related to the
lookup buffer when it is no longer needed.
Raises:
NotImplementedError: This method must be implemented in subclasses.
"""
raise NotImplementedError

View File

@ -0,0 +1,242 @@
"""
Implements a distributed key-value (KV) cache transfer mechanism.
Key Features:
- Distributed KV cache transmission using PyNccl pipes.
- Non-blocking `insert`, blocking `drop_select`.
- Uses a CPU signal pipe to avoid race conditions.
- Handles buffer size constraints and provides a backpressure mechanism to
  stop the prefill instance when the decode instance is slow.
"""
import threading
import time
from collections import deque
from typing import Deque, List, Optional, Union
import torch
from vllm.distributed.kv_transfer.kv_lookup_buffer.base import (
KVLookupBufferBase)
from vllm.distributed.kv_transfer.kv_pipe.base import KVPipeBase
from vllm.logger import init_logger
logger = init_logger(__name__)
class SimpleBuffer(KVLookupBufferBase):
def __init__(self, signal_pipe: KVPipeBase, data_pipe: KVPipeBase,
buffer_size_thresh: float):
"""
signal_pipe: on CPU
        NOTE: on-device recv will block all threads in the process, making the
        KV cache producer unable to listen for new requests while transmitting
        KV caches. Luckily, CPU recv only blocks the current thread, so we use
        CPU recv to listen for new requests.
data_pipe: on device (e.g. GPU)
"""
self.buffer: Deque[List[torch.Tensor]] = deque()
self.buffer_size = 0
self.buffer_size_threshold = buffer_size_thresh
self.buffer_lock = threading.Lock()
self.signal_pipe = signal_pipe
self.data_pipe = data_pipe
self.request_handling_thread: Optional[threading.Thread] = None
self.normal_signal = torch.tensor([0], device="cpu")
self.end_signal = None
def _matches(self, tokens_roi_sender: List[torch.Tensor],
tokens_roi_recver: List[torch.Tensor]):
# tokens_roi_sender: tokens and roi of the producer (in the buffer)
# tokens_roi_recver: tokens and roi of the consumer (query)
tokens_sender = tokens_roi_sender[0]
tokens_recver = tokens_roi_recver[0]
roi_sender = tokens_roi_sender[1]
roi_recver = tokens_roi_recver[1]
if tokens_recver is None:
# consumer sends an empty request
# semantics: DROP SELECT * LIMIT 1
# so any of the data in the buffer can be drop-selected
return True
# Assuming that roi is a binary mask on tokens
tokens_sender = tokens_sender[roi_sender]
tokens_recver = tokens_recver[roi_recver]
# simple common prefix matching
min_length = min(len(tokens_sender), len(tokens_recver))
if torch.allclose(tokens_sender[:min_length],
tokens_recver[:min_length]):
return min_length
return 0
def _send_tensor_and_dec_size(self,
tensor: Optional[torch.Tensor]) -> None:
assert tensor is not None, "Use self.data_pipe.send(None) instead"
self.buffer_size -= tensor.element_size() * tensor.numel()
if tensor.dtype == torch.bool:
tensor = tensor.float()
self.data_pipe.send_tensor(tensor)
def _get_element_size(self, data: Optional[Union[List, torch.Tensor]]):
if isinstance(data, torch.Tensor):
return data.element_size() * data.numel()
if not data:
# cannot perform `not data` on a tensor
# so this check needs to go after the check above
return 0
raise AssertionError(f"Unknown data type {type(data)}")
def _add_to_buffer(self, input_tokens: torch.Tensor, roi: torch.Tensor,
key: torch.Tensor, value: torch.Tensor,
hidden: torch.Tensor):
if isinstance(input_tokens, torch.Tensor):
input_tokens = input_tokens.clone()
if isinstance(roi, torch.Tensor):
roi = roi.clone()
if isinstance(key, torch.Tensor):
key = key.clone()
if isinstance(value, torch.Tensor):
value = value.clone()
if isinstance(hidden, torch.Tensor):
hidden = hidden.clone()
buffer_item = [input_tokens, roi, key, value, hidden]
with self.buffer_lock:
for data in buffer_item:
self.buffer_size += self._get_element_size(data)
self.buffer.append(buffer_item)
def _is_end_signal(self, signal):
return signal is None
def drop_select_handler(self):
try:
while True:
signal = self.signal_pipe.recv_tensor()
if self._is_end_signal(signal):
logger.info("Received end signal!")
break
input_tokens = self.data_pipe.recv_tensor()
roi = self.data_pipe.recv_tensor()
assert roi is not None, "Please provide the roi when sending "\
"drop-select request"
roi = (roi > 0.5)
tokens_roi_recver = [input_tokens, roi]
matched_length = 0
# perform input tokens and roi matching
                # FIXME: this matching is O(n); ideally it should be O(1),
                # but the buffer won't (and shouldn't) be too large, so the
                # fix is not urgent.
with self.buffer_lock:
for _ in range(len(self.buffer)):
temp_length = self._matches(self.buffer[0],
tokens_roi_recver)
if temp_length > 0:
matched_length = temp_length
break
# rotate the element we just accessed to the end
self.buffer.rotate(-1)
if matched_length > 0:
# need to clone the tensor
# in case the tensor is freed before sending finishes
matched_item = self.buffer.popleft()
for tensor in matched_item:
self._send_tensor_and_dec_size(tensor)
else:
# no match, just send None
for _ in range(5):
self.data_pipe.send_tensor(None)
except RuntimeError as e:
if 'Connection closed by peer' not in str(e):
raise e
logger.debug("Closing drop_select_handler")
def drop_select(
self, input_tokens: Optional[torch.Tensor],
roi: Optional[torch.Tensor]) -> List[Optional[torch.Tensor]]:
assert self.request_handling_thread is None, \
"drop_select should be called by the KV cache consumer "\
"(e.g. the decode vLLM instance)"
if isinstance(input_tokens, torch.Tensor):
input_tokens = input_tokens.clone()
if isinstance(roi, torch.Tensor):
roi = roi.clone().float()
self.signal_pipe.send_tensor(self.normal_signal)
self.data_pipe.send_tensor(input_tokens)
self.data_pipe.send_tensor(roi)
input_tokens = self.data_pipe.recv_tensor()
roi = self.data_pipe.recv_tensor()
if roi is not None:
# convert from float tensor to bool tensor
# as PyNccl does not support sending bool tensor
roi = (roi > 0.5)
key = self.data_pipe.recv_tensor()
value = self.data_pipe.recv_tensor()
hidden = self.data_pipe.recv_tensor()
return [input_tokens, roi, key, value, hidden]
def full_handler(self):
time.sleep(0.001)
def insert(self, input_tokens: torch.Tensor, roi: torch.Tensor,
key: torch.Tensor, value: torch.Tensor,
hidden: torch.Tensor) -> None:
if self.buffer_size > self.buffer_size_threshold:
# log outside the while loop to avoid this message being logged
# repeatedly.
logger.debug("KV transfer buffer is full. Handling...")
while self.buffer_size > self.buffer_size_threshold:
self.full_handler()
self._add_to_buffer(input_tokens, roi, key, value, hidden)
# when calling the insert, the current process is a sender
# need to launch the request handler and start listening to request.
if self.request_handling_thread is None:
self.request_handling_thread = threading.Thread(
target=self.drop_select_handler)
self.request_handling_thread.start()
def close(self):
if hasattr(self, "request_handling_thread"
) and self.request_handling_thread is not None:
self.request_handling_thread.join()
else:
            # TODO: add an explicit close signal and an explicit way to
            # check whether this instance is the requester
self.signal_pipe.send_tensor(self.end_signal)

View File

@ -0,0 +1,65 @@
"""
This file defines an interface `KVPipeBase`
that provides an abstraction for sending and receiving tensors, or None, via
distributed communications.
All classes instantiated from this interface are assumed to be a FIFO pipe.
If your distributed communication platform already supports key-value lookup,
you can bypass this interface and directly start from `kv_lookup_buffer`.
"""
from abc import ABC, abstractmethod
from typing import Optional
import torch
class KVPipeBase(ABC):
"""
This class provides an interface for sending and receiving tensors, or
None, by distributed communications.
"""
@abstractmethod
def send_tensor(self, tensor: Optional[torch.Tensor]) -> None:
"""Send a tensor, or None, via the pipe.
Need to support sending None -- important for error handling.
TODO: add a `key` argument so that we can use traditional
key-value database as the distributed communication mechanism behind
the pipe.
Args:
tensor (Optional[torch.Tensor]): The tensor to be sent. Can be None.
Raises:
NotImplementedError: This method must be implemented in subclasses.
"""
raise NotImplementedError
@abstractmethod
def recv_tensor(self) -> Optional[torch.Tensor]:
"""Receive a tensor (can be None) from the pipeline.
Returns:
Optional[torch.Tensor]: The tensor received from the pipeline. Can
be None.
Raises:
NotImplementedError: This method must be implemented in subclasses.
"""
raise NotImplementedError
@abstractmethod
def close(self) -> None:
"""Close the pipeline and release resources.
This method is responsible for closing the communication pipeline
and releasing any resources associated with it.
Raises:
NotImplementedError: This method must be implemented in subclasses.
"""
raise NotImplementedError

View File

@ -0,0 +1,276 @@
"""
This module implements a PyNccl pipe for sending and receiving
Optional[torch.Tensor] between distributed ranks with advanced
communication features.
Key Features:
- Supports sending and receiving tensors with metadata
- Handles both CUDA and CPU device communications
- Implements a non-blocking tensor transfer mechanism
- Manages buffer size and provides backpressure control
- Supports distributed process groups with configurable parameters
"""
import threading
import time
from concurrent.futures import ThreadPoolExecutor
from typing import Callable, Dict, Optional, Tuple
import torch
from vllm.config import KVTransferConfig
from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
from vllm.distributed.kv_transfer.kv_pipe.base import KVPipeBase
from vllm.distributed.utils import StatelessProcessGroup
from vllm.logger import init_logger
logger = init_logger(__name__)
class BrokenPipeException(Exception):
def __init__(self, message):
self.message = message
super().__init__(self.message)
Metadata = Dict[str, Optional[torch.Tensor]]
class PyNcclPipe(KVPipeBase):
METADATA_LENGTH = 16
MAX_TENSOR_DIMENSIONS = 14
METADATA_DTYPE = torch.int64
def __init__(self,
local_rank: int,
config: KVTransferConfig,
device: Optional[str] = None,
port_offset: int = 0):
self.config = config
self.local_rank = local_rank
self.kv_rank = self.config.kv_rank
self.kv_parallel_size = self.config.kv_parallel_size
if device is None:
self.device = self._select_device(self.config.kv_buffer_device)
else:
self.device = self._select_device(device)
# build distributed connection and send/recv implementation
self.group = StatelessProcessGroup.create(
host=self.config.kv_ip,
port=self.config.kv_port + port_offset,
rank=self.kv_rank,
world_size=self.kv_parallel_size,
)
# add a barrier to make sure the connection is initiated properly
self.group.barrier()
impl = self._get_device_send_recv_impl(self.group)
self.device_send_func, self.device_recv_func = impl
# set target rank
self.target_rank_for_send = (self.kv_rank + 1) % self.kv_parallel_size
self.target_rank_for_recv = (self.kv_rank - 1) % self.kv_parallel_size
# transportation-related variables
self.transport_thread: Optional[ThreadPoolExecutor] = None
self.buffer_size = 0
self.buffer_size_lock = threading.Lock()
self.buffer_size_thresh = self.config.kv_buffer_size
def _get_device_send_recv_impl(
self, group: StatelessProcessGroup
) -> Tuple[Callable[[torch.Tensor, int], None], Callable[
[torch.Tensor, int], None]]:
send: Callable[[torch.Tensor, int], None]
recv: Callable[[torch.Tensor, int], None]
if self.device.type == "cuda":
# use PyNCCL for send / recv
comm = PyNcclCommunicator(group, device=self.local_rank)
comm.disabled = False
send, recv = comm.send, comm.recv # type: ignore
else:
# This send / recv implementation is NOT intended to transfer
# KV caches (and should NOT be repurposed for that).
# Currently it is only used to transmit control-plane messages
# for PyNcclBuffer.
send = group.send_obj
def my_recv(x, src):
x[...] = group.recv_obj(src)
recv = my_recv
return send, recv
def _select_device(self, device: str):
logger.info("Selecting device: %s", device)
if device == "cuda":
return torch.device(f"cuda:{self.local_rank}")
else:
return torch.device("cpu")
def _make_metadata(self, tensor: Optional[torch.Tensor]) -> Metadata:
"""
Create the metadata as a dictionary based on the input tensor.
Parameters:
- tensor: The input tensor or None if no tensor is provided.
Returns:
- metadata: A dictionary with the following keys:
- "dtype": The data type of the tensor or None.
- "shape": The shape of the tensor or None.
"""
if tensor is None:
return {"dtype": None, "shape": None}
else:
return {"dtype": tensor.dtype, "shape": tensor.shape}
def _prepare_recv_buffer(self, metadata: Metadata) -> torch.Tensor:
"""
Create a buffer to receive the tensor based on the provided metadata.
Parameters:
- metadata: A dictionary with keys "dtype" and "shape", describing
the tensor's data type and shape.
Returns:
- buffer: A tensor of the specified type and shape, allocated on
self.device.
"""
return torch.empty(metadata["shape"],
dtype=metadata["dtype"],
device=self.device)
def _send_metadata(self, metadata: Metadata):
"""
Send the metadata dictionary to the target rank.
Parameters:
- metadata: A dictionary with keys "dtype" and "shape".
"""
self.group.send_obj(metadata, self.target_rank_for_send)
def _recv_metadata(self) -> Metadata:
"""
Receive the metadata dictionary from the target rank.
Returns:
- metadata: A dictionary with keys "dtype" and "shape" describing
the tensor.
"""
return self.group.recv_obj(self.target_rank_for_recv)
def _send_impl(self, tensor: Optional[torch.Tensor]) -> None:
"""
The actual implementation of sending the tensor and its metadata to the
target rank.
Parameters:
- tensor: The input tensor to be sent, or None if no tensor is
being sent.
"""
metadata = self._make_metadata(tensor)
self._send_metadata(metadata)
if tensor is not None:
self.device_send_func(tensor.to(self.device),
self.target_rank_for_send)
def _recv_impl(self) -> Optional[torch.Tensor]:
"""
The actual implementation of receiving a tensor and its metadata from
the target rank.
Returns:
- buffer: The received tensor, or None if no tensor is received.
"""
metadata = self._recv_metadata()
if metadata["dtype"] is None:
return None
buffer = self._prepare_recv_buffer(metadata)
self.device_recv_func(buffer, self.target_rank_for_recv)
return buffer
def send_tensor_wrapper(self, tensor: Optional[torch.Tensor],
tensor_size: int) -> None:
"""
Wrapper for _send_impl to handle exceptions and update buffer size.
"""
try:
self._send_impl(tensor)
with self.buffer_size_lock:
self.buffer_size -= tensor_size
except Exception as e:
logger.error("[rank%d]: Exception when trying to send %s, msg: %s",
torch.distributed.get_rank(), str(tensor), str(e))
import traceback
traceback.print_exc()
def block_if_full(self):
"""
Block the current thread if the buffer size is larger than the
threshold.
"""
while self.buffer_size > self.buffer_size_thresh:
logger.debug("KV cache transfer pipe is full. Waiting...")
time.sleep(0.05)
def send_tensor(self, tensor: Optional[torch.Tensor]) -> None:
"""
Sends a tensor and its metadata to the destination rank in a
non-blocking way.
Parameters:
- tensor: The tensor to send, or None if no tensor is being sent.
"""
if self.transport_thread is None:
self.transport_thread = ThreadPoolExecutor(max_workers=1)
if tensor is not None:
tensor_size = tensor.element_size() * tensor.numel()
else:
tensor_size = 0
self.block_if_full()
with self.buffer_size_lock:
self.buffer_size += tensor_size
self.transport_thread.submit(self.send_tensor_wrapper, tensor,
tensor_size)
def recv_tensor(self) -> Optional[torch.Tensor]:
"""
Receives a tensor and its metadata from the source rank. Blocking call.
Returns:
- tensor: The received tensor, or None if no tensor is received.
"""
if self.transport_thread is None:
self.transport_thread = ThreadPoolExecutor(max_workers=1)
future = self.transport_thread.submit(self._recv_impl)
try:
tensor = future.result()
except Exception as e:
logger.error("Encountering exception in KV receiving thread")
logger.error("%s", e)
logger.error("My device: %s", self.device)
import traceback
traceback.print_exc()
raise e
return tensor
def close(self):
"""
Close the pipe and release associated resources.
"""
if hasattr(self,
"transport_thread") and self.transport_thread is not None:
self.transport_thread.shutdown()
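A minimal end-to-end sketch (not part of the commit): run this in two processes on a 2-GPU node, once with `kv_rank=0` and once with `kv_rank=1`. The module path of `PyNcclPipe` and the defaults for `kv_ip`, `kv_port`, and `kv_buffer_device` in `KVTransferConfig` are assumptions here; `send_tensor`, `recv_tensor`, and `close` mirror the class above.
import json

import torch

from vllm.config import KVTransferConfig
from vllm.distributed.kv_transfer.kv_pipe.pynccl_pipe import PyNcclPipe


def run(kv_rank: int) -> None:
    config = KVTransferConfig.from_cli(
        json.dumps({
            "kv_connector": "PyNcclConnector",
            "kv_role": "kv_producer" if kv_rank == 0 else "kv_consumer",
            "kv_rank": kv_rank,
            "kv_parallel_size": 2,
            "kv_buffer_size": 1e9,
        }))
    # local_rank doubles as the CUDA device index used by PyNccl send / recv.
    pipe = PyNcclPipe(local_rank=kv_rank, config=config)
    if kv_rank == 0:
        # Rank 0 sends to rank 1; send_tensor is non-blocking.
        pipe.send_tensor(torch.ones(8, device=f"cuda:{kv_rank}"))
        pipe.send_tensor(None)  # None is a legal payload
    else:
        # Rank 1 receives from rank 0; recv_tensor blocks until data arrives.
        print(pipe.recv_tensor())  # tensor of ones
        print(pipe.recv_tensor())  # None
    pipe.close()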

View File

@ -0,0 +1,75 @@
"""A centralized entrypoint to perform distributed KV cache transfer.
This implementation is a shim wrapper around two APIs exposed by `kv_connector`:
1. `send_kv_caches_and_hidden_states`
2. `recv_kv_caches_and_hidden_states`
"""
from typing import TYPE_CHECKING, List, Tuple, Union
if TYPE_CHECKING:
from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata
from vllm.config import VllmConfig
import torch
from vllm.distributed.kv_transfer.kv_connector.factory import (
KVConnectorFactory)
from vllm.logger import init_logger
from vllm.sequence import IntermediateTensors
logger = init_logger(__name__)
class KVTransferAgent:
"""
A class dedicated to distributed KV cache transfer
Target use cases:
1. Disaggregated prefill
2. Remote KV cache storage
"""
def __init__(
self,
rank: int,
local_rank: int,
config: "VllmConfig",
):
self.config = config
if config.kv_transfer_config is None:
raise ValueError("KVTransferConfig is not set in the VllmConfig,"
" cannot initialize KVConnector.")
assert self.config.kv_transfer_config.is_kv_transfer_instance, "KV"\
"TransferAgent should only be used when kv_connector is set."
self.connector = KVConnectorFactory.create_connector(
rank, local_rank, config)
def send_kv_caches_and_hidden_states(
self,
model_executable: torch.nn.Module,
model_input: "ModelInputForGPUWithSamplingMetadata",
kv_caches: List[torch.Tensor],
hidden_or_intermediate_states: Union[torch.Tensor,
IntermediateTensors],
) -> None:
self.connector.send_kv_caches_and_hidden_states(
model_executable, model_input, kv_caches,
hidden_or_intermediate_states)
def close(self) -> None:
self.connector.close()
def recv_kv_caches_and_hidden_states(
self, model_executable: torch.nn.Module,
model_input: "ModelInputForGPUWithSamplingMetadata",
kv_caches: List[torch.Tensor]
) -> Tuple[Union[torch.Tensor, IntermediateTensors], bool,
"ModelInputForGPUWithSamplingMetadata"]:
return self.connector.recv_kv_caches_and_hidden_states(
model_executable, model_input, kv_caches)
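A small sketch (not part of the commit) of how the two halves of this shim are meant to be used. `is_producer` would come from `kv_transfer_config.is_kv_producer`, and the argument objects are placeholders; the producer branch returns a 3-tuple only to keep the same shape as `recv_kv_caches_and_hidden_states`.
import vllm.distributed.kv_transfer.kv_transfer_agent as kv_transfer


def transfer_step(agent: "kv_transfer.KVTransferAgent", is_producer: bool,
                  model, model_input, kv_caches, hidden_states):
    if is_producer:
        # Prefill instance: the model has already run; push KV + hidden states.
        agent.send_kv_caches_and_hidden_states(model, model_input, kv_caches,
                                               hidden_states)
        return hidden_states, False, model_input
    # Decode instance: pull KV + hidden states; the returned flag tells the
    # caller whether the model forward pass can be bypassed.
    return agent.recv_kv_caches_and_hidden_states(model, model_input, kv_caches)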

View File

@ -27,18 +27,23 @@ from collections import namedtuple
from contextlib import contextmanager, nullcontext
from dataclasses import dataclass
from multiprocessing import shared_memory
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple,
Union)
from unittest.mock import patch
import torch
import torch.distributed
from torch.distributed import Backend, ProcessGroup
import vllm.distributed.kv_transfer.kv_transfer_agent as kv_transfer
import vllm.envs as envs
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.utils import direct_register_custom_op, supports_custom_op
if TYPE_CHECKING:
from vllm.config import VllmConfig
@dataclass
class GraphCaptureContext:
@ -904,6 +909,14 @@ def get_pp_group() -> GroupCoordinator:
# kept for backward compatibility
get_pipeline_model_parallel_group = get_pp_group
_KV_TRANSFER: Optional[kv_transfer.KVTransferAgent] = None
def get_kv_transfer_group() -> kv_transfer.KVTransferAgent:
assert _KV_TRANSFER is not None, (
"disaggregated KV cache transfer parallel group is not initialized")
return _KV_TRANSFER
@contextmanager
def graph_capture():
@ -1052,6 +1065,26 @@ def initialize_model_parallel(
group_name="pp")
def ensure_kv_transfer_initialized(vllm_config: "VllmConfig") -> None:
"""
Initialize KV cache transfer parallel group.
"""
global _KV_TRANSFER
if vllm_config.kv_transfer_config is None:
return
if all([
vllm_config.kv_transfer_config.need_kv_parallel_group,
_KV_TRANSFER is None
]):
_KV_TRANSFER = kv_transfer.KVTransferAgent(
rank=get_world_group().rank,
local_rank=get_world_group().local_rank,
config=vllm_config)
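A short sketch (not part of the commit) of how a worker is expected to use these hooks; `vllm_config` is assumed to already carry a `kv_transfer_config`, and both functions are hypothetical wrappers.
from vllm.distributed import (ensure_kv_transfer_initialized,
                              get_kv_transfer_group)


def setup_kv_transfer(vllm_config) -> None:
    # No-op when kv_transfer_config is None, when no KV parallel group is
    # needed, or when the agent has already been created.
    ensure_kv_transfer_initialized(vllm_config)


def shutdown_kv_transfer() -> None:
    # get_kv_transfer_group() asserts if the agent was never initialized.
    get_kv_transfer_group().close()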
def ensure_model_parallel_initialized(
tensor_model_parallel_size: int,
pipeline_model_parallel_size: int,

View File

@ -9,10 +9,10 @@ import torch
import vllm.envs as envs
from vllm.config import (CacheConfig, CompilationConfig, ConfigFormat,
DecodingConfig, DeviceConfig, HfOverrides, LoadConfig,
LoadFormat, LoRAConfig, ModelConfig,
ObservabilityConfig, ParallelConfig, PoolerConfig,
PromptAdapterConfig, SchedulerConfig,
DecodingConfig, DeviceConfig, HfOverrides,
KVTransferConfig, LoadConfig, LoadFormat, LoRAConfig,
ModelConfig, ObservabilityConfig, ParallelConfig,
PoolerConfig, PromptAdapterConfig, SchedulerConfig,
SpeculativeConfig, TaskOption, TokenizerPoolConfig,
VllmConfig)
from vllm.executor.executor_base import ExecutorBase
@ -108,6 +108,7 @@ class EngineArgs:
# notice.
distributed_executor_backend: Optional[Union[str,
Type[ExecutorBase]]] = None
# number of P/D disaggregation (or other disaggregation) workers
pipeline_parallel_size: int = 1
tensor_parallel_size: int = 1
max_parallel_loading_workers: Optional[int] = None
@ -194,6 +195,8 @@ class EngineArgs:
compilation_config: Optional[CompilationConfig] = None
worker_cls: str = "auto"
kv_transfer_config: Optional[KVTransferConfig] = None
def __post_init__(self):
if not self.tokenizer:
self.tokenizer = self.model
@ -908,6 +911,12 @@ class EngineArgs:
'compilers, using -O without space is also '
'supported. -O3 is equivalent to -O 3.')
parser.add_argument('--kv-transfer-config',
type=KVTransferConfig.from_cli,
default=None,
help='The configurations for distributed KV cache '
'transfer. Should be a JSON string.')
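A minimal sketch (not part of the commit) of the programmatic equivalent of this flag; the model name is a placeholder and the JSON fields show a consumer-side configuration. The resulting config is propagated into `VllmConfig` via `kv_transfer_config=self.kv_transfer_config`, as the next hunk shows.
from vllm.config import KVTransferConfig
from vllm.engine.arg_utils import EngineArgs

engine_args = EngineArgs(
    model="facebook/opt-125m",  # placeholder model
    kv_transfer_config=KVTransferConfig.from_cli(
        '{"kv_connector":"PyNcclConnector","kv_role":"kv_consumer",'
        '"kv_rank":1,"kv_parallel_size":2,"kv_buffer_size":1e9}'),
)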
parser.add_argument(
'--worker-cls',
type=str,
@ -1201,6 +1210,7 @@ class EngineArgs:
observability_config=observability_config,
prompt_adapter_config=prompt_adapter_config,
compilation_config=self.compilation_config,
kv_transfer_config=self.kv_transfer_config,
)
if envs.VLLM_USE_V1:

View File

@ -21,7 +21,7 @@ from vllm.attention.backends.utils import CommonAttentionState
from vllm.compilation.compile_context import set_compile_context
from vllm.config import CompilationLevel, VllmConfig
from vllm.core.scheduler import SchedulerOutputs
from vllm.distributed import get_pp_group
from vllm.distributed import get_kv_transfer_group, get_pp_group
from vllm.distributed.parallel_state import graph_capture
from vllm.forward_context import set_forward_context
from vllm.inputs import INPUT_REGISTRY, InputRegistry
@ -1666,6 +1666,24 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
else:
model_executable = self.model
# Receive KV cache in distributed KV cache transfer setting
# In the disagg prefill setting, it will also receive hidden states and
# bypass model forwarding.
# In the KV cache database setting, it will change the model input so
# that we can skip prefilling on tokens whose KV caches were received
# successfully.
# NOTE: The receive operation is blocking
bypass_model_exec = False
if self.need_recv_kv(model_input, kv_caches):
hidden_or_intermediate_states, bypass_model_exec, model_input = \
get_kv_transfer_group().recv_kv_caches_and_hidden_states(
# model is used to know which layer the current worker
# is working on, so that we can receive KV for only those
# layers.
model_executable,
model_input,
kv_caches=kv_caches
)
multi_modal_kwargs = model_input.multi_modal_kwargs or {}
seqlen_agnostic_kwargs = {
"finished_requests_ids": model_input.finished_requests_ids,
@ -1677,21 +1695,36 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
model_forward_end = torch.cuda.Event(enable_timing=True)
model_forward_start.record()
with set_forward_context(model_input.attn_metadata, self.vllm_config):
hidden_or_intermediate_states = model_executable(
input_ids=model_input.input_tokens,
positions=model_input.input_positions,
kv_caches=kv_caches,
attn_metadata=model_input.attn_metadata,
intermediate_tensors=intermediate_tensors,
**MultiModalKwargs.as_kwargs(multi_modal_kwargs,
device=self.device),
**seqlen_agnostic_kwargs)
if not bypass_model_exec:
with set_forward_context(model_input.attn_metadata,
self.vllm_config):
hidden_or_intermediate_states = model_executable(
input_ids=model_input.input_tokens,
positions=model_input.input_positions,
kv_caches=kv_caches,
attn_metadata=model_input.attn_metadata,
intermediate_tensors=intermediate_tensors,
**MultiModalKwargs.as_kwargs(multi_modal_kwargs,
device=self.device),
**seqlen_agnostic_kwargs)
if (self.observability_config is not None
and self.observability_config.collect_model_forward_time):
model_forward_end.record()
# Sending KV cache in distributed KV cache transfer setting
# NOTE: the send operation is non-blocking
if self.need_send_kv(model_input, kv_caches):
get_kv_transfer_group().send_kv_caches_and_hidden_states(
# model_executable is used to know which layer the current
# worker is working on, so that we can send KV for only those
# layers.
model_executable,
model_input,
kv_caches,
hidden_or_intermediate_states,
)
# Compute the logits in the last pipeline stage.
if not get_pp_group().is_last_rank:
if (self.is_driver_worker
@ -1759,6 +1792,56 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
return [output]
def need_recv_kv(self, model_input, kv_caches) -> bool:
"""Check if we need to receive kv-cache from the other worker.
We need to receive KV when
1. current vLLM instance is KV cache consumer/decode vLLM instance
2. this batch is not a profiling run
3. this batch is a prefill run
Args:
model_input: input to the model executable
kv_caches: vLLM's paged memory
"""
prefill_meta = model_input.attn_metadata.prefill_metadata
# check if the current run is profiling
is_profile_run = (kv_caches[0].numel() == 0)
# check if the current run is prefill
is_prefill_run = prefill_meta is not None
if self.vllm_config.kv_transfer_config is None:
return False
return self.vllm_config.kv_transfer_config.is_kv_consumer and (
not is_profile_run) and is_prefill_run
def need_send_kv(self, model_input, kv_caches) -> bool:
"""Check if we need to send kv-cache to the other worker.
We need to send KV when
1. current vLLM instance is KV cache producer/prefill vLLM instance
2. this batch is not a profiling run
3. this batch is a prefill run
Args:
model_input: input to the model executable
kv_caches: vLLM's paged memory
"""
prefill_meta = model_input.attn_metadata.prefill_metadata
# check if the current run is profiling
is_profile_run = (kv_caches[0].numel() == 0)
# check if the current run is prefill
is_prefill_run = prefill_meta is not None
if self.vllm_config.kv_transfer_config is None:
return False
return self.vllm_config.kv_transfer_config.is_kv_producer and (
not is_profile_run) and is_prefill_run
# NOTE: this is nn.Module so the profiler can properly capture/group
# kernels calls made within the graph

View File

@ -8,8 +8,9 @@ import torch
import torch.distributed
import vllm.envs as envs
from vllm.config import ParallelConfig, VllmConfig
from vllm.distributed import (ensure_model_parallel_initialized,
from vllm.config import VllmConfig
from vllm.distributed import (ensure_kv_transfer_initialized,
ensure_model_parallel_initialized,
init_distributed_environment,
set_custom_all_reduce)
from vllm.logger import init_logger
@ -144,7 +145,7 @@ class Worker(LocalOrDistributedWorkerBase):
raise RuntimeError(
f"Not support device type: {self.device_config.device}")
# Initialize the distributed environment.
init_worker_distributed_environment(self.parallel_config, self.rank,
init_worker_distributed_environment(self.vllm_config, self.rank,
self.distributed_init_method,
self.local_rank)
# Set random seed.
@ -457,20 +458,22 @@ class Worker(LocalOrDistributedWorkerBase):
def init_worker_distributed_environment(
parallel_config: ParallelConfig,
vllm_config: VllmConfig,
rank: int,
distributed_init_method: Optional[str] = None,
local_rank: int = -1,
) -> None:
"""Initialize the distributed environment."""
parallel_config = vllm_config.parallel_config
set_custom_all_reduce(not parallel_config.disable_custom_all_reduce)
init_distributed_environment(parallel_config.world_size, rank,
distributed_init_method, local_rank)
ensure_model_parallel_initialized(parallel_config.tensor_parallel_size,
parallel_config.pipeline_parallel_size)
ensure_kv_transfer_initialized(vllm_config)
def _check_if_gpu_supports_dtype(torch_dtype: torch.dtype):
# Check if the GPU supports the dtype.

View File

@ -43,6 +43,7 @@ class WorkerBase(ABC):
self.speculative_config = vllm_config.speculative_config
self.prompt_adapter_config = vllm_config.prompt_adapter_config
self.observability_config = vllm_config.observability_config
self.kv_transfer_config = vllm_config.kv_transfer_config
@abstractmethod
def init_device(self) -> None: