[Feature][Disaggregated] Support XpYd disaggregated prefill with MooncakeStore (#12957)
Signed-off-by: Shangming Cai <caishangming@linux.alibaba.com>
Parent: 94744ba41a
Commit: 6fa7cd3dbc
examples/online_serving/disagg_examples/disagg_proxy_demo.py (new file, 450 lines)
@@ -0,0 +1,450 @@
# SPDX-License-Identifier: Apache-2.0
"""
This file provides a disaggregated prefilling proxy demo to demonstrate an
example usage of XpYd disaggregated prefilling.
We can launch multiple vllm instances (2 for prefill and 2 for decode), and
launch this proxy demo through:
  python3 examples/online_serving/disagg_examples/disagg_proxy_demo.py  \
       --model $model_name  \
       --prefill localhost:8100 localhost:8101  \
       --decode localhost:8200 localhost:8201  \
       --port 8000

Note: This demo will be removed once the PDController implemented in PR 15343
(https://github.com/vllm-project/vllm/pull/15343) supports XpYd.
"""
import argparse
import ipaddress
import itertools
import json
import logging
import os
import sys
from abc import ABC, abstractmethod
from typing import Callable, Optional

import aiohttp
import requests
import uvicorn
from fastapi import (APIRouter, Depends, FastAPI, Header, HTTPException,
                     Request, status)
from fastapi.responses import JSONResponse, StreamingResponse

AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=6 * 60 * 60)

logger = logging.getLogger()
logging.basicConfig(level=logging.INFO)


class SchedulingPolicy(ABC):

    @abstractmethod
    def schedule(self, cycler: itertools.cycle):
        raise NotImplementedError("Scheduling Proxy is not set.")


class Proxy:

    def __init__(
        self,
        prefill_instances: list[str],
        decode_instances: list[str],
        model: str,
        scheduling_policy: SchedulingPolicy,
        custom_create_completion: Optional[Callable[[Request],
                                                    StreamingResponse]] = None,
        custom_create_chat_completion: Optional[Callable[
            [Request], StreamingResponse]] = None,
    ):
        self.prefill_instances = prefill_instances
        self.decode_instances = decode_instances
        self.prefill_cycler = itertools.cycle(prefill_instances)
        self.decode_cycler = itertools.cycle(decode_instances)
        self.model = model
        self.scheduling_policy = scheduling_policy
        self.custom_create_completion = custom_create_completion
        self.custom_create_chat_completion = custom_create_chat_completion
        self.router = APIRouter()
        self.setup_routes()

    def setup_routes(self):
        self.router.post(
            "/v1/completions",
            dependencies=[
                Depends(self.validate_json_request)
            ])(self.custom_create_completion if self.
               custom_create_completion else self.create_completion)
        self.router.post(
            "/v1/chat/completions",
            dependencies=[
                Depends(self.validate_json_request)
            ])(self.custom_create_chat_completion if self.
               custom_create_chat_completion else self.create_chat_completion)
        self.router.get("/status",
                        response_class=JSONResponse)(self.get_status)
        self.router.post("/instances/add",
                         dependencies=[Depends(self.api_key_authenticate)
                                       ])(self.add_instance_endpoint)

    async def validate_json_request(self, raw_request: Request):
        content_type = raw_request.headers.get("content-type", "").lower()
        if content_type != "application/json":
            raise HTTPException(
                status_code=415,
                detail=
                "Unsupported Media Type: Only 'application/json' is allowed",
            )

    def api_key_authenticate(self, x_api_key: str = Header(...)):
        expected_api_key = os.environ.get("ADMIN_API_KEY")
        if not expected_api_key:
            logger.error("ADMIN_API_KEY is not set in the environment.")
            raise HTTPException(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                detail="Server configuration error.",
            )
        if x_api_key != expected_api_key:
            logger.warning("Unauthorized access attempt with API Key: %s",
                           x_api_key)
            raise HTTPException(
                status_code=status.HTTP_403_FORBIDDEN,
                detail="Forbidden: Invalid API Key.",
            )

    async def validate_instance(self, instance: str) -> bool:
        url = f"http://{instance}/v1/models"
        try:
            async with aiohttp.ClientSession(
                    timeout=AIOHTTP_TIMEOUT) as client:
                logger.info("Verifying %s ...", instance)
                async with client.get(url) as response:
                    if response.status == 200:
                        data = await response.json()
                        if "data" in data and len(data["data"]) > 0:
                            model_cur = data["data"][0].get("id", "")
                            if model_cur == self.model:
                                logger.info("Instance: %s could be added.",
                                            instance)
                                return True
                            else:
                                logger.warning("Mismatch model %s : %s != %s",
                                               instance, model_cur,
                                               self.model)
                                return False
                        else:
                            return False
                    else:
                        return False
        except aiohttp.ClientError as e:
            logger.error(str(e))
            return False
        except Exception as e:
            logger.error(str(e))
            return False

    async def add_instance_endpoint(self, request: Request):
        try:
            data = await request.json()
            logger.warning(str(data))
            instance_type = data.get("type")
            instance = data.get("instance")
            if instance_type not in ["prefill", "decode"]:
                raise HTTPException(status_code=400,
                                    detail="Invalid instance type.")
            if not instance or ":" not in instance:
                raise HTTPException(status_code=400,
                                    detail="Invalid instance format.")
            host, port_str = instance.split(":")
            try:
                if host != "localhost":
                    ipaddress.ip_address(host)
                port = int(port_str)
                if not (0 < port < 65536):
                    raise HTTPException(status_code=400,
                                        detail="Invalid port number.")
            except Exception as e:
                raise HTTPException(status_code=400,
                                    detail="Invalid instance address.") from e

            is_valid = await self.validate_instance(instance)
            if not is_valid:
                raise HTTPException(status_code=400,
                                    detail="Instance validation failed.")

            if instance_type == "prefill":
                if instance not in self.prefill_instances:
                    self.prefill_instances.append(instance)
                    self.prefill_cycler = itertools.cycle(
                        self.prefill_instances)
                else:
                    raise HTTPException(status_code=400,
                                        detail="Instance already exists.")
            else:
                if instance not in self.decode_instances:
                    self.decode_instances.append(instance)
                    self.decode_cycler = itertools.cycle(self.decode_instances)
                else:
                    raise HTTPException(status_code=400,
                                        detail="Instance already exists.")

            return JSONResponse(content={
                "message":
                f"Added {instance} to {instance_type}_instances."
            })
        except HTTPException as http_exc:
            raise http_exc
        except Exception as e:
            logger.error("Error in add_instance_endpoint: %s", str(e))
            raise HTTPException(status_code=500, detail=str(e)) from e

    async def forward_request(self, url, data, use_chunked=True):
        async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
            headers = {
                "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"
            }
            try:
                async with session.post(url=url, json=data,
                                        headers=headers) as response:
                    if 200 <= response.status < 300 or 400 <= response.status < 500:  # noqa: E501
                        if use_chunked:
                            async for chunk_bytes in response.content.iter_chunked(  # noqa: E501
                                    1024):
                                yield chunk_bytes
                        else:
                            content = await response.read()
                            yield content
                    else:
                        error_content = await response.text()
                        try:
                            error_content = json.loads(error_content)
                        except json.JSONDecodeError:
                            error_content = error_content
                        logger.error("Request failed with status %s: %s",
                                     response.status, error_content)
                        raise HTTPException(
                            status_code=response.status,
                            detail=
                            f"Request failed with status {response.status}: "
                            f"{error_content}",
                        )
            except aiohttp.ClientError as e:
                logger.error("ClientError occurred: %s", str(e))
                raise HTTPException(
                    status_code=502,
                    detail=
                    "Bad Gateway: Error communicating with upstream server.",
                ) from e
            except Exception as e:
                logger.error("Unexpected error: %s", str(e))
                raise HTTPException(status_code=500, detail=str(e)) from e

    def schedule(self, cycler: itertools.cycle) -> str:
        return self.scheduling_policy.schedule(cycler)

    async def get_status(self):
        status = {
            "prefill_node_count": len(self.prefill_instances),
            "decode_node_count": len(self.decode_instances),
            "prefill_nodes": self.prefill_instances,
            "decode_nodes": self.decode_instances,
        }
        return status

    async def create_completion(self, raw_request: Request):
        try:
            request = await raw_request.json()

            kv_prepare_request = request.copy()
            kv_prepare_request["max_tokens"] = 1

            prefill_instance = self.schedule(self.prefill_cycler)
            try:
                async for _ in self.forward_request(
                        f"http://{prefill_instance}/v1/completions",
                        kv_prepare_request):
                    continue
            except HTTPException as http_exc:
                self.remove_instance_endpoint("prefill", prefill_instance)
                raise http_exc

            # Perform kv recv and decoding stage
            decode_instance = self.schedule(self.decode_cycler)

            try:
                generator = self.forward_request(
                    f"http://{decode_instance}/v1/completions", request)
            except HTTPException as http_exc:
                self.remove_instance_endpoint("decode", decode_instance)
                raise http_exc
            response = StreamingResponse(generator)
            return response
        except Exception:
            import sys

            exc_info = sys.exc_info()
            print("Error occurred in disagg proxy server")
            print(exc_info)

    async def create_chat_completion(self, raw_request: Request):
        try:
            request = await raw_request.json()

            # add params to request
            kv_prepare_request = request.copy()
            kv_prepare_request["max_tokens"] = 1

            # prefill stage
            prefill_instance = self.schedule(self.prefill_cycler)
            try:
                async for _ in self.forward_request(
                        f"http://{prefill_instance}/v1/chat/completions",
                        kv_prepare_request):
                    continue
            except HTTPException as http_exc:
                self.remove_instance_endpoint("prefill", prefill_instance)
                raise http_exc
            # Perform kv recv and decoding stage
            decode_instance = self.schedule(self.decode_cycler)

            try:
                generator = self.forward_request(
                    "http://" + decode_instance + "/v1/chat/completions",
                    request)
            except HTTPException as http_exc:
                self.remove_instance_endpoint("decode", decode_instance)
                raise http_exc
            response = StreamingResponse(content=generator)
            return response
        except Exception:
            exc_info = sys.exc_info()
            error_messages = [str(e) for e in exc_info if e]
            print("Error occurred in disagg proxy server")
            print(error_messages)
            return StreamingResponse(content=iter(error_messages),
                                     media_type="text/event-stream")

    def remove_instance_endpoint(self, instance_type, instance):
        if (instance_type == "decode" and instance in self.decode_instances):
            self.decode_instances.remove(instance)
            self.decode_cycler = itertools.cycle(self.decode_instances)
        if (instance_type == "prefill"
                and instance in self.prefill_instances):
            self.prefill_instances.remove(instance)
            self.prefill_cycler = itertools.cycle(self.prefill_instances)


class RoundRobinSchedulingPolicy(SchedulingPolicy):

    def __init__(self):
        super().__init__()

    def schedule(self, cycler: itertools.cycle) -> str:
        return next(cycler)


class ProxyServer:

    def __init__(
        self,
        args: argparse.Namespace,
        scheduling_policy: Optional[SchedulingPolicy] = None,
        create_completion: Optional[Callable[[Request],
                                             StreamingResponse]] = None,
        create_chat_completion: Optional[Callable[[Request],
                                                  StreamingResponse]] = None,
    ):
        self.validate_parsed_serve_args(args)
        self.port = args.port
        self.proxy_instance = Proxy(
            prefill_instances=[] if args.prefill is None else args.prefill,
            decode_instances=[] if args.decode is None else args.decode,
            model=args.model,
            scheduling_policy=(scheduling_policy if scheduling_policy
                               is not None else RoundRobinSchedulingPolicy()),
            custom_create_completion=create_completion,
            custom_create_chat_completion=create_chat_completion,
        )

    def validate_parsed_serve_args(self, args: argparse.Namespace):
        if not args.prefill:
            raise ValueError("Please specify at least one prefill node.")
        if not args.decode:
            raise ValueError("Please specify at least one decode node.")
        self.validate_instances(args.prefill)
        self.validate_instances(args.decode)
        self.verify_model_config(args.prefill, args.model)
        self.verify_model_config(args.decode, args.model)

    def validate_instances(self, instances: list):
        for instance in instances:
            if len(instance.split(":")) != 2:
                raise ValueError(f"Invalid instance format: {instance}")
            host, port = instance.split(":")
            try:
                if host != "localhost":
                    ipaddress.ip_address(host)
                port = int(port)
                if not (0 < port < 65536):
                    raise ValueError(
                        f"Invalid port number in instance: {instance}")
            except Exception as e:
                raise ValueError(
                    f"Invalid instance {instance}: {str(e)}") from e

    def verify_model_config(self, instances: list, model: str) -> None:
        model_suffix = model.split("/")[-1]
        for instance in instances:
            try:
                response = requests.get(f"http://{instance}/v1/models")
                if response.status_code == 200:
                    model_cur = response.json()["data"][0]["id"]
                    model_cur_suffix = model_cur.split("/")[-1]
                    if model_cur_suffix != model_suffix:
                        raise ValueError(
                            f"{instance} serves a different model: "
                            f"{model_cur} != {model}")
                else:
                    raise ValueError(f"Cannot get model id from {instance}!")
            except requests.RequestException as e:
                raise ValueError(
                    f"Error communicating with {instance}: {str(e)}") from e

    def run_server(self):
        app = FastAPI()
        app.include_router(self.proxy_instance.router)
        config = uvicorn.Config(app, port=self.port, loop="uvloop")
        server = uvicorn.Server(config)
        server.run()


if __name__ == "__main__":
    # Todo: allow more config
    parser = argparse.ArgumentParser("vLLM disaggregated proxy server.")
    parser.add_argument("--model",
                        "-m",
                        type=str,
                        required=True,
                        help="Model name")

    parser.add_argument(
        "--prefill",
        "-p",
        type=str,
        nargs="+",
        help="List of prefill node URLs (host:port)",
    )

    parser.add_argument(
        "--decode",
        "-d",
        type=str,
        nargs="+",
        help="List of decode node URLs (host:port)",
    )

    parser.add_argument(
        "--port",
        type=int,
        default=8000,
        help="Server port number",
    )
    args = parser.parse_args()
    proxy_server = ProxyServer(args=args)
    proxy_server.run_server()
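For reference, once the proxy and the prefill/decode vLLM instances are running, the proxy behaves like a single OpenAI-compatible endpoint. The sketch below is illustrative and not part of this commit; the model name, ports, and new decode address are placeholders, and ADMIN_API_KEY must match the value exported for the proxy process.

import os

import requests

PROXY = "http://localhost:8000"  # the --port passed to disagg_proxy_demo.py

# Regular completion request: the proxy first runs a max_tokens=1 prefill pass
# on a prefill instance (to populate the KV store), then streams the real
# response from a decode instance.
resp = requests.post(
    f"{PROXY}/v1/completions",
    headers={"Content-Type": "application/json"},
    json={
        "model": "placeholder/model-name",
        "prompt": "San Francisco is a",
        "max_tokens": 32,
    },
)
print(resp.text)

# Current prefill/decode membership as reported by the proxy.
print(requests.get(f"{PROXY}/status").json())

# Dynamically register another decode instance (requires ADMIN_API_KEY).
resp = requests.post(
    f"{PROXY}/instances/add",
    headers={
        "Content-Type": "application/json",
        "X-API-Key": os.environ["ADMIN_API_KEY"],
    },
    json={"type": "decode", "instance": "localhost:8202"},
)
print(resp.json())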
vllm/distributed/kv_transfer/kv_connector/factory.py
@@ -53,3 +53,8 @@ KVConnectorFactory.register_connector(
    "LMCacheConnector",
    "vllm.distributed.kv_transfer.kv_connector.lmcache_connector",
    "LMCacheConnector")

KVConnectorFactory.register_connector(
    "MooncakeStoreConnector",
    "vllm.distributed.kv_transfer.kv_connector.mooncake_store_connector",
    "MooncakeStoreConnector")
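The hunk above registers the new connector by name, lazily-imported module path, and class name. A minimal sketch of how an out-of-tree connector could reuse the same hook is shown below; the factory import path is assumed from this diff, and "MyStoreConnector" plus its module path are hypothetical.

from vllm.distributed.kv_transfer.kv_connector.factory import (
    KVConnectorFactory)

KVConnectorFactory.register_connector(
    "MyStoreConnector",           # name referenced by kv_transfer_config's kv_connector
    "my_pkg.my_store_connector",  # module imported lazily when first requested
    "MyStoreConnector")           # class implementing KVConnectorBase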
vllm/distributed/kv_transfer/kv_connector/mooncake_store_connector.py (new file, 216 lines)
@@ -0,0 +1,216 @@
# SPDX-License-Identifier: Apache-2.0
"""
MooncakeStore Connector for Distributed Machine Learning Inference

The MooncakeStoreConnector transfers KV caches between prefill vLLM workers
(KV cache producer) and decode vLLM workers (KV cache consumer) using a
database-style KVStore.
"""
import hashlib
from typing import TYPE_CHECKING, List, Tuple, Union

import torch

from vllm import _custom_ops as ops
from vllm.config import VllmConfig
from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBase
from vllm.logger import init_logger
from vllm.sequence import IntermediateTensors

if TYPE_CHECKING:
    from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata

logger = init_logger(__name__)


class MooncakeStoreConnector(KVConnectorBase):

    def __init__(
        self,
        rank: int,
        local_rank: int,
        config: VllmConfig,
    ):
        self.config = config.kv_transfer_config
        self.tp_size = config.parallel_config.tensor_parallel_size

        self.local_tp_rank = local_rank

        # Init kv_store
        if self.config.kv_connector == "MooncakeStoreConnector":
            # Check if MOONCAKE_CONFIG_PATH is set
            import os
            use_mooncake_store = os.getenv('MOONCAKE_CONFIG_PATH') is not None

            if not use_mooncake_store:
                raise ValueError(
                    "To use MooncakeStoreConnector, you need to pass the ENV: "
                    "'MOONCAKE_CONFIG_PATH=/path/to/mooncake_config.json'.")
            else:
                from vllm.distributed.kv_transfer.kv_lookup_buffer.mooncake_store import (  # noqa: E501
                    MooncakeStore)
                logger.info(
                    "Initializing KVStoreConnector under kv_transfer_config %s",
                    self.config)
                self.kv_store = MooncakeStore(config)
        else:
            logger.error("Can not find %s", self.config.kv_connector)

        assert self.kv_store is not None

    def close(self) -> None:
        """Close the buffer and release resources.

        This method is responsible for cleaning up resources related to the
        connector when it is no longer needed.

        Raises:
            NotImplementedError: This method must be implemented in subclasses.
        """
        self.kv_store.close()

    def send_kv_caches_and_hidden_states(
        self,
        model_executable: torch.nn.Module,
        model_input: "ModelInputForGPUWithSamplingMetadata",
        kv_caches: List[torch.Tensor],
        hidden_or_intermediate_states: Union[torch.Tensor,
                                             IntermediateTensors],
    ) -> None:
        input_tokens_tensor = model_input.input_tokens
        seq_lens = model_input.attn_metadata.seq_lens
        slot_mapping_flat = model_input.attn_metadata.slot_mapping.flatten()
        start_layer = model_executable.model.start_layer
        end_layer = model_executable.model.end_layer

        model_config = model_executable.model.config
        num_heads = int(model_config.num_key_value_heads / self.tp_size)
        hidden_size = model_config.hidden_size
        num_attention_heads = model_config.num_attention_heads
        head_size = int(hidden_size / num_attention_heads)

        for idx, slen in enumerate(seq_lens):
            start_pos = sum(seq_lens[:idx])
            end_pos = start_pos + slen

            current_tokens = input_tokens_tensor[start_pos:end_pos]
            store_key_prefix = self.tensor_hash(current_tokens)
            keys, values = [], []

            for layer_id in range(start_layer, end_layer):
                kv_cache = kv_caches[layer_id - start_layer]

                key_cache = kv_cache[0].reshape(-1, num_heads, head_size)
                value_cache = kv_cache[1].reshape(-1, num_heads, head_size)

                current_slot_mapping = slot_mapping_flat[start_pos:end_pos]

                keys.append(key_cache[current_slot_mapping].unsqueeze(0))
                values.append(value_cache[current_slot_mapping].unsqueeze(0))

            keys = torch.cat(keys, dim=0)
            values = torch.cat(values, dim=0)
            kvcache_to_sent = torch.stack((keys, values), dim=0)
            store_kvcache_key = f"{store_key_prefix}_{self.local_tp_rank}"
            self.kv_store.put(store_kvcache_key, kvcache_to_sent)

            hidden_key = f"{store_key_prefix}_hidden_{self.local_tp_rank}"
            self.kv_store.put(hidden_key,
                              hidden_or_intermediate_states[start_pos:end_pos])

        logger.debug("[rank%d]: KV send DONE.", torch.distributed.get_rank())

    def recv_kv_caches_and_hidden_states(
        self, model_executable: torch.nn.Module,
        model_input: "ModelInputForGPUWithSamplingMetadata",
        kv_caches: List[torch.Tensor]
    ) -> Tuple[Union[torch.Tensor, IntermediateTensors], bool,
               "ModelInputForGPUWithSamplingMetadata"]:
        bypass_model_exec = True
        input_tokens_tensor = model_input.input_tokens
        seq_lens = model_input.attn_metadata.seq_lens
        num_prefill_tokens = model_input.attn_metadata.num_prefill_tokens
        slot_mapping = model_input.attn_metadata.slot_mapping.flatten()
        start_layer = model_executable.model.start_layer
        end_layer = model_executable.model.end_layer
        hidden_or_intermediate_states_for_one_req = []

        for idx, slen in enumerate(seq_lens):
            start_pos = sum(seq_lens[:idx])
            end_pos = start_pos + slen

            if start_pos >= num_prefill_tokens:
                # This can happen during inflight batching. See:
                # vllm/worker/model_runner.py::_prepare_model_input_tensors:
                # - input_tokens[:num_prefill_tokens] contains prefill tokens.
                # - input_tokens[num_prefill_tokens:] contains decode tokens.
                logger.warning(
                    "You should set --enable_chunked_prefill=False and make "
                    "--max_num_batched_tokens equal to max_seq_len_to_capture")
                bypass_model_exec = False
                assert start_pos == num_prefill_tokens
                break

            current_tokens = input_tokens_tensor[start_pos:end_pos]

            # get roi for current seq
            load_key_prefix = self.tensor_hash(current_tokens)
            load_kvcache_key = f"{load_key_prefix}_{self.local_tp_rank}"
            remote_kv = self.kv_store.get(load_kvcache_key)
            hidden_key = f"{load_key_prefix}_hidden_{self.local_tp_rank}"
            hidden = self.kv_store.get(hidden_key)

            if remote_kv is None or hidden is None:
                # didn't find any match.
                bypass_model_exec = False
                continue

            num_computed_tokens = current_tokens.shape[0]

            # update the end position based on how many tokens are cached.
            end_pos = start_pos + num_computed_tokens

            # call self.kv_store to get kv layer by layer
            for layer_id in range(start_layer, end_layer):
                layer = model_executable.model.layers[layer_id]
                # get kvcache object
                kv_cache = kv_caches[layer_id - start_layer]
                key_cache, value_cache = kv_cache[0], kv_cache[1]

                # get remote kvcache
                remote_k, remote_v = remote_kv[0][layer_id], remote_kv[1][
                    layer_id]
                # use ops.reshape_and_cache_flash to put kv into kvcache
                ops.reshape_and_cache_flash(
                    remote_k.to(key_cache.device),
                    remote_v.to(value_cache.device),
                    key_cache,
                    value_cache,
                    slot_mapping[start_pos:end_pos],
                    layer.self_attn.attn.kv_cache_dtype,
                    layer.self_attn.attn._k_scale,
                    layer.self_attn.attn._v_scale,
                )

            hidden_or_intermediate_states_for_one_req.append(hidden)

        if not bypass_model_exec:
            logger.warning(
                "[rank%d]: Failed to receive all KVs and hidden "
                "states, redo model forwarding.", torch.distributed.get_rank())
            hidden_or_intermediate_states = None

        else:
            logger.debug(
                "[rank%d]: Successfully received all KVs and hidden "
                "states, skip model forwarding.", torch.distributed.get_rank())
            hidden_or_intermediate_states = torch.cat(
                hidden_or_intermediate_states_for_one_req, dim=0)

        return hidden_or_intermediate_states, bypass_model_exec, model_input

    @staticmethod
    def tensor_hash(tensor: torch.Tensor) -> int:
        """Calculate the hash value of the tensor."""
        tensor_bytes = tensor.clone().detach().cpu().numpy().tobytes()
        hash_object = hashlib.blake2b(tensor_bytes)
        hash_hex = hash_object.hexdigest()
        return int(hash_hex[:16], 16)
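The connector addresses store entries purely by content: the prompt's token ids are hashed with BLAKE2b and combined with the worker's tensor-parallel rank, so a decode worker seeing the same prompt recomputes the same keys that the prefill worker used. A small standalone sketch of that key scheme follows; the token ids and rank are made-up illustration values.

import hashlib

import torch


def tensor_hash(tensor: torch.Tensor) -> int:
    # Same hashing scheme as MooncakeStoreConnector.tensor_hash above.
    tensor_bytes = tensor.clone().detach().cpu().numpy().tobytes()
    return int(hashlib.blake2b(tensor_bytes).hexdigest()[:16], 16)


# Token ids of one prompt, as seen by both the prefill and the decode worker.
current_tokens = torch.tensor([101, 2023, 2003, 1037, 3231, 102])
local_tp_rank = 0

prefix = tensor_hash(current_tokens)
store_kvcache_key = f"{prefix}_{local_tp_rank}"        # stacked K/V for all layers
hidden_key = f"{prefix}_hidden_{local_tp_rank}"        # final hidden states
print(store_kvcache_key, hidden_key)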
vllm/distributed/kv_transfer/kv_lookup_buffer/base.py
@@ -1,11 +1,15 @@
# SPDX-License-Identifier: Apache-2.0
"""
This file contains a new class `KVLookupBufferBase` that allows developers to
think of KV cache operations as inserting new KV cache entries (`insert`)
into the lookup buffer and querying existing KV caches (`drop_select`)
from the lookup buffer.

All distributed communications are abstracted behind this class.

This file also contains a new class `KVStoreBufferBase` that allows developers
to manage the KVCache buffer as a simple key-value storage buffer with basic
put/get operations.

These classes above are abstracted behind class `KVCacheBufferBase`.
"""

from abc import ABC, abstractmethod
@@ -14,9 +18,27 @@ from typing import List, Optional

import torch


class KVCacheBufferBase(ABC):
    """
    Abstract base class for a KVCache buffer.
    """

    @abstractmethod
    def close(self) -> None:
        """Close the buffer and release resources.

        This method is responsible for cleaning up resources related to the
        KVCache buffer when it is no longer needed.

        Raises:
            NotImplementedError: This method must be implemented in subclasses.
        """
        raise NotImplementedError


class KVLookupBufferBase(KVCacheBufferBase):
    """
    Abstract base class for a KVCache lookup buffer.

    This class provides an abstraction for a key-value (KV) cache lookup buffer.

@@ -96,12 +118,55 @@ class KVLookupBufferBase(ABC):
        """
        raise NotImplementedError


class KVStoreBufferBase(KVCacheBufferBase):
    """
    Abstract base class for a KVCache storage buffer with key-value semantics.

    This class provides a simple key-value storage buffer abstraction with
    basic put/get operations, which enables flexible, fine-grained control
    over KVCache transfer.

    The functionality is similar to a distributed key-value store, where:
    - Key: A unique string identifier for the cached entry
    - Value:
        - Tensor to be stored and retrieved
        - None (indicating deletion or empty value)
    """

    @abstractmethod
    def put(
        self,
        key: str,
        value: Optional[torch.Tensor],
    ) -> None:
        """Store a key-value pair in the buffer.

        Args:
            key (str): Unique identifier for a tensor, this tensor could be
                the key cache tensor, value cache tensor, or hidden state
                tensor generated during model forwarding.

            value (Optional[torch.Tensor]): Tensor to be stored.

        Raises:
            NotImplementedError: This method must be implemented in subclasses.
        """
        raise NotImplementedError

    @abstractmethod
    def get(
        self,
        key: str,
    ) -> Optional[torch.Tensor]:
        """Retrieve a value from the buffer by key.

        Args:
            key (str): Unique identifier for a tensor, this tensor could be
                the key cache tensor, value cache tensor, or hidden state
                tensor generated during model forwarding.

        Returns:
            Optional[torch.Tensor]: Stored tensor if exists, None otherwise.

        Raises:
            NotImplementedError: This method must be implemented in subclasses.
        """
        raise NotImplementedError
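To illustrate the new interface, a deliberately minimal in-memory implementation of `KVStoreBufferBase` might look as follows. This toy class is not part of the commit (MooncakeStore below is the real implementation); it only shows the put/get/close contract defined above.

from typing import Dict, Optional

import torch

from vllm.distributed.kv_transfer.kv_lookup_buffer.base import (
    KVStoreBufferBase)


class InMemoryKVStoreBuffer(KVStoreBufferBase):
    """Toy single-process buffer with the same put/get/close contract."""

    def __init__(self) -> None:
        self._data: Dict[str, torch.Tensor] = {}

    def put(self, key: str, value: Optional[torch.Tensor]) -> None:
        if value is None:
            # None signals deletion / empty value, per the base class docstring.
            self._data.pop(key, None)
        else:
            self._data[key] = value

    def get(self, key: str) -> Optional[torch.Tensor]:
        return self._data.get(key)

    def close(self) -> None:
        self._data.clear()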
vllm/distributed/kv_transfer/kv_lookup_buffer/mooncake_store.py (new file, 160 lines)
@@ -0,0 +1,160 @@
# SPDX-License-Identifier: Apache-2.0
"""
This file contains a new class `MooncakeStore` that allows developers to
think of KV cache transfer operations as putting new KV cache entries
into a remote KVStore-based lookup buffer and getting existing KV caches
from this remote lookup buffer.
"""
import json
import os
from dataclasses import dataclass
from typing import Optional

import torch
from safetensors.torch import load as safetensors_load
from safetensors.torch import save as safetensors_save

from vllm.config import VllmConfig
from vllm.distributed.kv_transfer.kv_lookup_buffer.base import (
    KVStoreBufferBase)
from vllm.logger import init_logger

DEFAULT_GLOBAL_SEGMENT_SIZE = 3355443200  # 3.125 GiB
DEFAULT_LOCAL_BUFFER_SIZE = 1073741824  # 1.0 GiB

logger = init_logger(__name__)


@dataclass
class MooncakeStoreConfig:
    local_hostname: str
    metadata_server: str
    global_segment_size: int
    local_buffer_size: int
    protocol: str
    device_name: str
    master_server_address: str

    @staticmethod
    def from_file(file_path: str) -> 'MooncakeStoreConfig':
        """Load the config from a JSON file."""
        with open(file_path) as fin:
            config = json.load(fin)
        return MooncakeStoreConfig(
            local_hostname=config.get("local_hostname"),
            metadata_server=config.get("metadata_server"),
            global_segment_size=config.get("global_segment_size",
                                           DEFAULT_GLOBAL_SEGMENT_SIZE),
            local_buffer_size=config.get("local_buffer_size",
                                         DEFAULT_LOCAL_BUFFER_SIZE),
            protocol=config.get("protocol", "tcp"),
            device_name=config.get("device_name", ""),
            master_server_address=config.get("master_server_address"),
        )

    @staticmethod
    def load_from_env() -> 'MooncakeStoreConfig':
        """Load config from a file specified in the environment variable."""
        config_file_path = os.getenv('MOONCAKE_CONFIG_PATH')
        if config_file_path is None:
            raise ValueError(
                "The environment variable 'MOONCAKE_CONFIG_PATH' is not set.")
        return MooncakeStoreConfig.from_file(config_file_path)


class MooncakeStore(KVStoreBufferBase):

    def __init__(
        self,
        config: VllmConfig,
    ):

        try:
            from mooncake_vllm_adaptor import MooncakeDistributedStore
        except ImportError as e:
            raise ImportError(
                "Please install mooncake by following the instructions at "
                "https://github.com/kvcache-ai/Mooncake/blob/main/doc/en/build.md "  # noqa: E501
                "to run vLLM with MooncakeConnector.") from e

        try:
            self.store = MooncakeDistributedStore()
            self.config = MooncakeStoreConfig.load_from_env()
            logger.info("Mooncake Configuration loaded successfully.")

            self.store.setup(self.config.local_hostname,
                             self.config.metadata_server,
                             self.config.global_segment_size,
                             self.config.local_buffer_size,
                             self.config.protocol, self.config.device_name,
                             self.config.master_server_address)

        except ValueError as e:
            logger.error("Configuration loading failed: %s", e)
            raise
        except Exception as exc:
            logger.error(
                "An error occurred while loading the configuration: %s", exc)
            raise

    def close(self):
        # MooncakeDistributedStore will automatically call the destructor, so
        # it is unnecessary to close it manually.
        pass

    def put(
        self,
        key: str,
        value: Optional[torch.Tensor],
    ) -> None:
        # A message queue needs to be introduced before making it asynchronous.
        if value is not None:
            self._put_impl(key, value)

    def get(
        self,
        key: str,
    ) -> Optional[torch.Tensor]:
        # A message queue needs to be introduced before making it asynchronous.
        value = self._get_impl(key)
        return value

    def _put_impl(
        self,
        key: str,
        value: torch.Tensor,
    ) -> None:
        """Put KVCache to Mooncake Store"""
        device_id = value.device.index if value.device.type == 'cuda' else -1
        device_tensor = torch.tensor(device_id, dtype=torch.int32)
        value_bytes = safetensors_save({
            "tensor": value,
            "device_id": device_tensor
        })
        try:
            self.store.put(key, value_bytes)
        except TypeError as err:
            logger.error("Failed to put value into Mooncake Store: %s", err)
            raise TypeError("Mooncake Store Put Type Error.") from err

    def _get_impl(
        self,
        key: str,
    ) -> Optional[torch.Tensor]:
        """Get KVCache from Mooncake Store"""
        try:
            data = self.store.get(key)
        except TypeError as err:
            logger.error("Failed to get value from Mooncake Store: %s", err)
            raise TypeError("Mooncake Store Get Type Error.") from err

        if data:
            loaded_tensors = safetensors_load(data)
            tensor = loaded_tensors["tensor"]
            device_id_tensor = loaded_tensors["device_id"]
            device_id = int(device_id_tensor.item())
            device = torch.device(
                'cuda', device_id) if device_id >= 0 else torch.device('cpu')
            return tensor.to(device)

        return None
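MooncakeStore reads its connection settings from the JSON file pointed to by MOONCAKE_CONFIG_PATH. The sketch below writes such a file with placeholder values; the field names and defaults are taken from MooncakeStoreConfig above, while the hosts, ports, protocol, and RDMA device name are assumptions that depend on the actual Mooncake deployment.

import json

# Placeholder endpoints; point these at the real Mooncake metadata and
# master servers for your cluster.
mooncake_config = {
    "local_hostname": "192.168.0.137",
    "metadata_server": "192.168.0.137:2379",
    "global_segment_size": 3355443200,   # optional, defaults to 3.125 GiB
    "local_buffer_size": 1073741824,     # optional, defaults to 1.0 GiB
    "protocol": "tcp",                   # "rdma" is also supported by Mooncake
    "device_name": "",                   # RDMA NIC name when protocol is "rdma"
    "master_server_address": "192.168.0.137:50001",
}

with open("mooncake_config.json", "w") as f:
    json.dump(mooncake_config, f, indent=2)

# The prefill and decode vLLM instances then need
# MOONCAKE_CONFIG_PATH=./mooncake_config.json set in their environment.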