diff --git a/examples/online_serving/disagg_examples/disagg_proxy_demo.py b/examples/online_serving/disagg_examples/disagg_proxy_demo.py new file mode 100644 index 00000000..a701636f --- /dev/null +++ b/examples/online_serving/disagg_examples/disagg_proxy_demo.py @@ -0,0 +1,450 @@ +# SPDX-License-Identifier: Apache-2.0 +""" +This file provides a disaggregated prefilling proxy demo to demonstrate an +example usage of XpYd disaggregated prefilling. +We can launch multiple vllm instances (2 for prefill and 2 for decode), and +launch this proxy demo through: + python3 examples/online_serving/disagg_examples/disagg_proxy_demo.py \ + --model $model_name \ + --prefill localhost:8100 localhost:8101 \ + --decode localhost:8200 localhost:8201 \ + --port 8000 + +Note: This demo will be removed once the PDController implemented in PR 15343 +(https://github.com/vllm-project/vllm/pull/15343) supports XpYd. +""" +import argparse +import ipaddress +import itertools +import json +import logging +import os +import sys +from abc import ABC, abstractmethod +from typing import Callable, Optional + +import aiohttp +import requests +import uvicorn +from fastapi import (APIRouter, Depends, FastAPI, Header, HTTPException, + Request, status) +from fastapi.responses import JSONResponse, StreamingResponse + +AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=6 * 60 * 60) +logger = logging.getLogger() +logging.basicConfig(level=logging.INFO) + + +class SchedulingPolicy(ABC): + + @abstractmethod + def schedule(self, cycler: itertools.cycle): + raise NotImplementedError("Scheduling Proxy is not set.") + + +class Proxy: + + def __init__( + self, + prefill_instances: list[str], + decode_instances: list[str], + model: str, + scheduling_policy: SchedulingPolicy, + custom_create_completion: Optional[Callable[[Request], + StreamingResponse]] = None, + custom_create_chat_completion: Optional[Callable[ + [Request], StreamingResponse]] = None, + ): + self.prefill_instances = prefill_instances + self.decode_instances = decode_instances + self.prefill_cycler = itertools.cycle(prefill_instances) + self.decode_cycler = itertools.cycle(decode_instances) + self.model = model + self.scheduling_policy = scheduling_policy + self.custom_create_completion = custom_create_completion + self.custom_create_chat_completion = custom_create_chat_completion + self.router = APIRouter() + self.setup_routes() + + def setup_routes(self): + self.router.post( + "/v1/completions", + dependencies=[ + Depends(self.validate_json_request) + ])(self.custom_create_completion if self. + custom_create_completion else self.create_completion) + self.router.post( + "/v1/chat/completions", + dependencies=[ + Depends(self.validate_json_request) + ])(self.custom_create_chat_completion if self. 
+ custom_create_chat_completion else self.create_chat_completion) + self.router.get("/status", + response_class=JSONResponse)(self.get_status) + self.router.post("/instances/add", + dependencies=[Depends(self.api_key_authenticate) + ])(self.add_instance_endpoint) + + async def validate_json_request(self, raw_request: Request): + content_type = raw_request.headers.get("content-type", "").lower() + if content_type != "application/json": + raise HTTPException( + status_code=415, + detail= + "Unsupported Media Type: Only 'application/json' is allowed", + ) + + def api_key_authenticate(self, x_api_key: str = Header(...)): + expected_api_key = os.environ.get("ADMIN_API_KEY") + if not expected_api_key: + logger.error("ADMIN_API_KEY is not set in the environment.") + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="Server configuration error.", + ) + if x_api_key != expected_api_key: + logger.warning("Unauthorized access attempt with API Key: %s", + x_api_key) + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="Forbidden: Invalid API Key.", + ) + + async def validate_instance(self, instance: str) -> bool: + url = f"http://{instance}/v1/models" + try: + async with aiohttp.ClientSession( + timeout=AIOHTTP_TIMEOUT) as client: + logger.info("Verifying %s ...", instance) + async with client.get(url) as response: + if response.status == 200: + data = await response.json() + if "data" in data and len(data["data"]) > 0: + model_cur = data["data"][0].get("id", "") + if model_cur == self.model: + logger.info("Instance: %s could be added.", + instance) + return True + else: + logger.warning("Mismatch model %s : %s != %s", + instance, model_cur, self.model) + return False + else: + return False + else: + return False + except aiohttp.ClientError as e: + logger.error(str(e)) + return False + except Exception as e: + logger.error(str(e)) + return False + + async def add_instance_endpoint(self, request: Request): + try: + data = await request.json() + logger.warning(str(data)) + instance_type = data.get("type") + instance = data.get("instance") + if instance_type not in ["prefill", "decode"]: + raise HTTPException(status_code=400, + detail="Invalid instance type.") + if not instance or ":" not in instance: + raise HTTPException(status_code=400, + detail="Invalid instance format.") + host, port_str = instance.split(":") + try: + if host != "localhost": + ipaddress.ip_address(host) + port = int(port_str) + if not (0 < port < 65536): + raise HTTPException(status_code=400, + detail="Invalid port number.") + except Exception as e: + raise HTTPException(status_code=400, + detail="Invalid instance address.") from e + + is_valid = await self.validate_instance(instance) + if not is_valid: + raise HTTPException(status_code=400, + detail="Instance validation failed.") + + if instance_type == "prefill": + if instance not in self.prefill_instances: + self.prefill_instances.append(instance) + self.prefill_cycler = itertools.cycle( + self.prefill_instances) + else: + raise HTTPException(status_code=400, + detail="Instance already exists.") + else: + if instance not in self.decode_instances: + self.decode_instances.append(instance) + self.decode_cycler = itertools.cycle(self.decode_instances) + else: + raise HTTPException(status_code=400, + detail="Instance already exists.") + + return JSONResponse(content={ + "message": + f"Added {instance} to {instance_type}_instances." 
+ }) + except HTTPException as http_exc: + raise http_exc + except Exception as e: + logger.error("Error in add_instance_endpoint: %s", str(e)) + raise HTTPException(status_code=500, detail=str(e)) from e + + async def forward_request(self, url, data, use_chunked=True): + async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session: + headers = { + "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}" + } + try: + async with session.post(url=url, json=data, + headers=headers) as response: + if 200 <= response.status < 300 or 400 <= response.status < 500: # noqa: E501 + if use_chunked: + async for chunk_bytes in response.content.iter_chunked( # noqa: E501 + 1024): + yield chunk_bytes + else: + content = await response.read() + yield content + else: + error_content = await response.text() + try: + error_content = json.loads(error_content) + except json.JSONDecodeError: + error_content = error_content + logger.error("Request failed with status %s: %s", + response.status, error_content) + raise HTTPException( + status_code=response.status, + detail= + f"Request failed with status {response.status}: " + f"{error_content}", + ) + except aiohttp.ClientError as e: + logger.error("ClientError occurred: %s", str(e)) + raise HTTPException( + status_code=502, + detail= + "Bad Gateway: Error communicating with upstream server.", + ) from e + except Exception as e: + logger.error("Unexpected error: %s", str(e)) + raise HTTPException(status_code=500, detail=str(e)) from e + + def schedule(self, cycler: itertools.cycle) -> str: + return self.scheduling_policy.schedule(cycler) + + async def get_status(self): + status = { + "prefill_node_count": len(self.prefill_instances), + "decode_node_count": len(self.decode_instances), + "prefill_nodes": self.prefill_instances, + "decode_nodes": self.decode_instances, + } + return status + + async def create_completion(self, raw_request: Request): + try: + request = await raw_request.json() + + kv_prepare_request = request.copy() + kv_prepare_request["max_tokens"] = 1 + + prefill_instance = self.schedule(self.prefill_cycler) + try: + async for _ in self.forward_request( + f"http://{prefill_instance}/v1/completions", + kv_prepare_request): + continue + except HTTPException as http_exc: + self.remove_instance_endpoint("prefill", prefill_instance) + raise http_exc + + # Perform kv recv and decoding stage + decode_instance = self.schedule(self.decode_cycler) + + try: + generator = self.forward_request( + f"http://{decode_instance}/v1/completions", request) + except HTTPException as http_exc: + self.remove_instance_endpoint("decode", decode_instance) + raise http_exc + response = StreamingResponse(generator) + return response + except Exception: + import sys + + exc_info = sys.exc_info() + print("Error occurred in disagg proxy server") + print(exc_info) + + async def create_chat_completion(self, raw_request: Request): + try: + request = await raw_request.json() + + # add params to request + kv_prepare_request = request.copy() + kv_prepare_request["max_tokens"] = 1 + + # prefill stage + prefill_instance = self.schedule(self.prefill_cycler) + try: + async for _ in self.forward_request( + f"http://{prefill_instance}/v1/chat/completions", + kv_prepare_request): + continue + except HTTPException as http_exc: + self.remove_instance_endpoint("prefill", prefill_instance) + raise http_exc + # Perform kv recv and decoding stage + decode_instance = self.schedule(self.decode_cycler) + + try: + generator = self.forward_request( + "http://" + decode_instance + 
"/v1/chat/completions", + request) + except HTTPException as http_exc: + self.remove_instance_endpoint("decode", decode_instance) + raise http_exc + response = StreamingResponse(content=generator) + return response + except Exception: + exc_info = sys.exc_info() + error_messages = [str(e) for e in exc_info if e] + print("Error occurred in disagg proxy server") + print(error_messages) + return StreamingResponse(content=iter(error_messages), + media_type="text/event-stream") + + def remove_instance_endpoint(self, instance_type, instance): + if (instance_type == "decode" and instance in self.decode_instances): + self.decode_instances.remove(instance) + self.decode_cycler = itertools.cycle(self.decode_instances) + if (instance_type == "prefill" and instance in self.decode_instances): + self.prefill_instances.remove(instance) + self.prefill_cycler = itertools.cycle(self.decode_instances) + + +class RoundRobinSchedulingPolicy(SchedulingPolicy): + + def __init__(self): + super().__init__() + + def schedule(self, cycler: itertools.cycle) -> str: + return next(cycler) + + +class ProxyServer: + + def __init__( + self, + args: argparse.Namespace, + scheduling_policy: Optional[SchedulingPolicy] = None, + create_completion: Optional[Callable[[Request], + StreamingResponse]] = None, + create_chat_completion: Optional[Callable[[Request], + StreamingResponse]] = None, + ): + self.validate_parsed_serve_args(args) + self.port = args.port + self.proxy_instance = Proxy( + prefill_instances=[] if args.prefill is None else args.prefill, + decode_instances=[] if args.decode is None else args.decode, + model=args.model, + scheduling_policy=(scheduling_policy if scheduling_policy + is not None else RoundRobinSchedulingPolicy()), + custom_create_completion=create_completion, + custom_create_chat_completion=create_chat_completion, + ) + + def validate_parsed_serve_args(self, args: argparse.Namespace): + if not args.prefill: + raise ValueError("Please specify at least one prefill node.") + if not args.decode: + raise ValueError("Please specify at least one decode node.") + self.validate_instances(args.prefill) + self.validate_instances(args.decode) + self.verify_model_config(args.prefill, args.model) + self.verify_model_config(args.decode, args.model) + + def validate_instances(self, instances: list): + for instance in instances: + if len(instance.split(":")) != 2: + raise ValueError(f"Invalid instance format: {instance}") + host, port = instance.split(":") + try: + if host != "localhost": + ipaddress.ip_address(host) + port = int(port) + if not (0 < port < 65536): + raise ValueError( + f"Invalid port number in instance: {instance}") + except Exception as e: + raise ValueError( + f"Invalid instance {instance}: {str(e)}") from e + + def verify_model_config(self, instances: list, model: str) -> None: + model_suffix = model.split("/")[-1] + for instance in instances: + try: + response = requests.get(f"http://{instance}/v1/models") + if response.status_code == 200: + model_cur = response.json()["data"][0]["id"] + model_cur_suffix = model_cur.split("/")[-1] + if model_cur_suffix != model_suffix: + raise ValueError( + f"{instance} serves a different model: " + f"{model_cur} != {model}") + else: + raise ValueError(f"Cannot get model id from {instance}!") + except requests.RequestException as e: + raise ValueError( + f"Error communicating with {instance}: {str(e)}") from e + + def run_server(self): + app = FastAPI() + app.include_router(self.proxy_instance.router) + config = uvicorn.Config(app, port=self.port, loop="uvloop") 
+ server = uvicorn.Server(config) + server.run() + + +if __name__ == "__main__": + # Todo: allow more config + parser = argparse.ArgumentParser("vLLM disaggregated proxy server.") + parser.add_argument("--model", + "-m", + type=str, + required=True, + help="Model name") + + parser.add_argument( + "--prefill", + "-p", + type=str, + nargs="+", + help="List of prefill node URLs (host:port)", + ) + + parser.add_argument( + "--decode", + "-d", + type=str, + nargs="+", + help="List of decode node URLs (host:port)", + ) + + parser.add_argument( + "--port", + type=int, + default=8000, + help="Server port number", + ) + args = parser.parse_args() + proxy_server = ProxyServer(args=args) + proxy_server.run_server() diff --git a/vllm/distributed/kv_transfer/kv_connector/factory.py b/vllm/distributed/kv_transfer/kv_connector/factory.py index 7336c54e..e37ce6dc 100644 --- a/vllm/distributed/kv_transfer/kv_connector/factory.py +++ b/vllm/distributed/kv_transfer/kv_connector/factory.py @@ -53,3 +53,8 @@ KVConnectorFactory.register_connector( "LMCacheConnector", "vllm.distributed.kv_transfer.kv_connector.lmcache_connector", "LMCacheConnector") + +KVConnectorFactory.register_connector( + "MooncakeStoreConnector", + "vllm.distributed.kv_transfer.kv_connector.mooncake_store_connector", + "MooncakeStoreConnector") \ No newline at end of file diff --git a/vllm/distributed/kv_transfer/kv_connector/mooncake_store_connector.py b/vllm/distributed/kv_transfer/kv_connector/mooncake_store_connector.py new file mode 100644 index 00000000..c5135dab --- /dev/null +++ b/vllm/distributed/kv_transfer/kv_connector/mooncake_store_connector.py @@ -0,0 +1,216 @@ +# SPDX-License-Identifier: Apache-2.0 +""" +MooncakeStore Connector for Distributed Machine Learning Inference + +The MooncakeStoreConnector transfers KV caches between prefill vLLM workers +(KV cache producer) and decode vLLM workers (KV cache consumer) using a +database-style KVStore. +""" +import hashlib +from typing import TYPE_CHECKING, List, Tuple, Union + +import torch + +from vllm import _custom_ops as ops +from vllm.config import VllmConfig +from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBase +from vllm.logger import init_logger +from vllm.sequence import IntermediateTensors + +if TYPE_CHECKING: + from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata + +logger = init_logger(__name__) + + +class MooncakeStoreConnector(KVConnectorBase): + + def __init__( + self, + rank: int, + local_rank: int, + config: VllmConfig, + ): + self.config = config.kv_transfer_config + self.tp_size = config.parallel_config.tensor_parallel_size + + self.local_tp_rank = local_rank + + # Init kv_store + if self.config.kv_connector == "MooncakeStoreConnector": + # Check if MOONCAKE_CONFIG_PATH is set + import os + use_mooncake_store = os.getenv('MOONCAKE_CONFIG_PATH') is not None + + if not use_mooncake_store: + raise ValueError( + "To use MooncakeStoreConnector, you need to pass the ENV: " + "'MOONCAKE_CONFIG_PATH=/path/to/mooncake_config.json'.") + else: + from vllm.distributed.kv_transfer.kv_lookup_buffer.mooncake_store import ( # noqa: E501 + MooncakeStore) + logger.info( + "Initializing KVStoreConnector under kv_transfer_config %s", + self.config) + self.kv_store = MooncakeStore(config) + else: + logger.error("Can not find %s", self.config.kv_connector) + + assert self.kv_store is not None + + def close(self) -> None: + """Close the buffer and release resources. 
+ This method is responsible for cleaning up resources related to the + connector when it is no longer needed. + Raises: + NotImplementedError: This method must be implemented in subclasses. + """ + self.kv_store.close() + + def send_kv_caches_and_hidden_states( + self, + model_executable: torch.nn.Module, + model_input: "ModelInputForGPUWithSamplingMetadata", + kv_caches: List[torch.Tensor], + hidden_or_intermediate_states: Union[torch.Tensor, + IntermediateTensors], + ) -> None: + input_tokens_tensor = model_input.input_tokens + seq_lens = model_input.attn_metadata.seq_lens + slot_mapping_flat = model_input.attn_metadata.slot_mapping.flatten() + start_layer = model_executable.model.start_layer + end_layer = model_executable.model.end_layer + + model_config = model_executable.model.config + num_heads = int(model_config.num_key_value_heads / self.tp_size) + hidden_size = model_config.hidden_size + num_attention_heads = model_config.num_attention_heads + head_size = int(hidden_size / num_attention_heads) + + for idx, slen in enumerate(seq_lens): + start_pos = sum(seq_lens[:idx]) + end_pos = start_pos + slen + + current_tokens = input_tokens_tensor[start_pos:end_pos] + store_key_prefix = self.tensor_hash(current_tokens) + keys, values = [], [] + + for layer_id in range(start_layer, end_layer): + kv_cache = kv_caches[layer_id - start_layer] + + key_cache = kv_cache[0].reshape(-1, num_heads, head_size) + value_cache = kv_cache[1].reshape(-1, num_heads, head_size) + + current_slot_mapping = slot_mapping_flat[start_pos:end_pos] + + keys.append(key_cache[current_slot_mapping].unsqueeze(0)) + values.append(value_cache[current_slot_mapping].unsqueeze(0)) + + keys = torch.cat(keys, dim=0) + values = torch.cat(values, dim=0) + kvcache_to_sent = torch.stack((keys, values), dim=0) + store_kvcache_key = f"{store_key_prefix}_{self.local_tp_rank}" + self.kv_store.put(store_kvcache_key, kvcache_to_sent) + + hidden_key = f"{store_key_prefix}_hidden_{self.local_tp_rank}" + self.kv_store.put(hidden_key, + hidden_or_intermediate_states[start_pos:end_pos]) + + logger.debug("[rank%d]: KV send DONE.", torch.distributed.get_rank()) + + def recv_kv_caches_and_hidden_states( + self, model_executable: torch.nn.Module, + model_input: "ModelInputForGPUWithSamplingMetadata", + kv_caches: List[torch.Tensor] + ) -> Tuple[Union[torch.Tensor, IntermediateTensors], bool, + "ModelInputForGPUWithSamplingMetadata"]: + bypass_model_exec = True + input_tokens_tensor = model_input.input_tokens + seq_lens = model_input.attn_metadata.seq_lens + num_prefill_tokens = model_input.attn_metadata.num_prefill_tokens + slot_mapping = model_input.attn_metadata.slot_mapping.flatten() + start_layer = model_executable.model.start_layer + end_layer = model_executable.model.end_layer + hidden_or_intermediate_states_for_one_req = [] + + for idx, slen in enumerate(seq_lens): + start_pos = sum(seq_lens[:idx]) + end_pos = start_pos + slen + + if start_pos >= num_prefill_tokens: + # This can happen during inflight batching. See: + # vllm/worker/model_runner.py::_prepare_model_input_tensors: + # - input_tokens[:num_prefill_tokens] contains prefill tokens. + # - input_tokens[num_prefill_tokens:] contains decode tokens. 
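+                # Hitting decode tokens here means the prefill-only assumption
+                # no longer holds, so give up on KV reuse and fall back to a
+                # full forward pass below.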
+ logger.warning("You should set --enable_chunked_prefill=False " + "and --max_num_batched_tokens " + "should be equal to max_seq_len_to_capture") + bypass_model_exec = False + assert start_pos == num_prefill_tokens + break + + current_tokens = input_tokens_tensor[start_pos:end_pos] + + # get roi for current seq + load_key_prefix = self.tensor_hash(current_tokens) + load_kvcache_key = f"{load_key_prefix}_{self.local_tp_rank}" + remote_kv = self.kv_store.get(load_kvcache_key) + hidden_key = f"{load_key_prefix}_hidden_{self.local_tp_rank}" + hidden = self.kv_store.get(hidden_key) + + if remote_kv is None or hidden is None: + # didn't find any match. + bypass_model_exec = False + continue + + num_computed_tokens = current_tokens.shape[0] + + # update the end position based on how many tokens are cached. + end_pos = start_pos + num_computed_tokens + + # call self.kv_store to get kv layer by layer + for layer_id in range(start_layer, end_layer): + layer = model_executable.model.layers[layer_id] + # get kvcache object + kv_cache = kv_caches[layer_id - start_layer] + key_cache, value_cache = kv_cache[0], kv_cache[1] + # get remote kvcache + + remote_k, remote_v = remote_kv[0][layer_id], remote_kv[1][ + layer_id] + # use ops.reshape_and_cache_flash to put kv into kvcache + ops.reshape_and_cache_flash( + remote_k.to(key_cache.device), + remote_v.to(value_cache.device), + key_cache, + value_cache, + slot_mapping[start_pos:end_pos], + layer.self_attn.attn.kv_cache_dtype, + layer.self_attn.attn._k_scale, + layer.self_attn.attn._v_scale, + ) + + hidden_or_intermediate_states_for_one_req.append(hidden) + + if not bypass_model_exec: + logger.warning( + "[rank%d]: Failed to receive all KVs and hidden " + "states, redo model forwarding.", torch.distributed.get_rank()) + hidden_or_intermediate_states = None + + else: + logger.debug( + "[rank%d]: Successfully received all KVs and hidden " + "states, skip model forwarding.", torch.distributed.get_rank()) + hidden_or_intermediate_states = torch.cat( + hidden_or_intermediate_states_for_one_req, dim=0) + + return hidden_or_intermediate_states, bypass_model_exec, model_input + + @staticmethod + def tensor_hash(tensor: torch.Tensor) -> int: + """Calculate the hash value of the tensor.""" + tensor_bytes = tensor.clone().detach().cpu().numpy().tobytes() + hash_object = hashlib.blake2b(tensor_bytes) + hash_hex = hash_object.hexdigest() + return int(hash_hex[:16], 16) diff --git a/vllm/distributed/kv_transfer/kv_lookup_buffer/base.py b/vllm/distributed/kv_transfer/kv_lookup_buffer/base.py index 845da7c5..bea42846 100644 --- a/vllm/distributed/kv_transfer/kv_lookup_buffer/base.py +++ b/vllm/distributed/kv_transfer/kv_lookup_buffer/base.py @@ -1,11 +1,15 @@ # SPDX-License-Identifier: Apache-2.0 """ -This file contains a new class `KVLookupBufferBase` that allows developers to -think of KV cache operations as inserting new KV cache entries (`insert`) -into the lookup buffer and querying existing KV caches (`drop_select`) +This file contains a new class `KVLookupBufferBase` that allows developers to +think of KV cache operations as inserting new KV cache entries (`insert`) +into the lookup buffer and querying existing KV caches (`drop_select`) from the lookup buffer. -All distributed communications are abstracted behind this class. +This file also contains a new class `KVStoreBufferBase` that allows developers +to manage the KVCache buffer as a simple key-value storage buffer with basic +put/get operations. 
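+A concrete implementation of this interface, `MooncakeStore`, is added in
+kv_lookup_buffer/mooncake_store.py as part of this change.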
+ +These classes above are abstracted behind class `KVCacheBufferBase`. """ from abc import ABC, abstractmethod @@ -14,9 +18,27 @@ from typing import List, Optional import torch -class KVLookupBufferBase(ABC): +class KVCacheBufferBase(ABC): """ - Abstract base class for a lookup buffer. + Abstract base class for a KVCache buffer. + """ + + @abstractmethod + def close(self) -> None: + """Close the buffer and release resources. + + This method is responsible for cleaning up resources related to the + KVCache buffer when it is no longer needed. + + Raises: + NotImplementedError: This method must be implemented in subclasses. + """ + raise NotImplementedError + + +class KVLookupBufferBase(KVCacheBufferBase): + """ + Abstract base class for a KVCache lookup buffer. This class provides an abstraction for a key-value (KV) cache lookup buffer. @@ -96,12 +118,55 @@ class KVLookupBufferBase(ABC): """ raise NotImplementedError - @abstractmethod - def close(self) -> None: - """Close the buffer and release resources. - This method is responsible for cleaning up resources related to the - lookup buffer when it is no longer needed. +class KVStoreBufferBase(KVCacheBufferBase): + """ + Abstract base class for a KVCache storage buffer with key-value semantics. + This class provides a simple key-value storage buffer abstract with basic + put/get operations, which enables flexible KVCache transfer granular + control. + + The functionality is similar to a distributed key-value store, where: + - Key: A unique string identifier for the cached entry + - Value: + - Tensor to be stored and retrieved + - None (indicating deletion or empty value) + """ + + @abstractmethod + def put( + self, + key: str, + value: Optional[torch.Tensor], + ) -> None: + """Store a key-value pair in the buffer. + + Args: + key (str): Unique identifier for a tensor, this tensor could be the + key cache tensor, value cache tensor, or hidden state tensor + generated during model forwarding. + + value (Optional[torch.Tensor]): Tensor to be stored. + + Raises: + NotImplementedError: This method must be implemented in subclasses. + """ + raise NotImplementedError + + @abstractmethod + def get( + self, + key: str, + ) -> Optional[torch.Tensor]: + """Retrieve a value from the buffer by key. + + Args: + key (str): Unique identifier for a tensor, this tensor could be the + key cache tensor, value cache tensor, or hidden state tensor + generated during model forwarding. + + Returns: + Optional[torch.Tensor]: Stored tensor if exists, None otherwise. Raises: NotImplementedError: This method must be implemented in subclasses. diff --git a/vllm/distributed/kv_transfer/kv_lookup_buffer/mooncake_store.py b/vllm/distributed/kv_transfer/kv_lookup_buffer/mooncake_store.py new file mode 100644 index 00000000..7fd59672 --- /dev/null +++ b/vllm/distributed/kv_transfer/kv_lookup_buffer/mooncake_store.py @@ -0,0 +1,160 @@ +# SPDX-License-Identifier: Apache-2.0 +""" +This file contains a new class `MooncakeStore` that allows developers to +think of KV cache transfer operations as putting new KV cache entries +into a remote KVStore-based lookup buffer and getting existing KV caches +from this remote lookup buffer. 
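+
+Values are serialized with safetensors on put and moved back to their original
+device when read back.
+
+The store is configured through a JSON file whose path is supplied via the
+MOONCAKE_CONFIG_PATH environment variable. A minimal sketch of such a file
+(the hosts, ports, and sizes below are placeholders, not recommended values):
+
+    {
+        "local_hostname": "192.168.0.137",
+        "metadata_server": "192.168.0.137:2379",
+        "global_segment_size": 3355443200,
+        "local_buffer_size": 1073741824,
+        "protocol": "tcp",
+        "device_name": "",
+        "master_server_address": "192.168.0.137:50001"
+    }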
+""" +import json +import os +from dataclasses import dataclass +from typing import Optional + +import torch +from safetensors.torch import load as safetensors_load +from safetensors.torch import save as safetensors_save + +from vllm.config import VllmConfig +from vllm.distributed.kv_transfer.kv_lookup_buffer.base import ( + KVStoreBufferBase) +from vllm.logger import init_logger + +DEFAULT_GLOBAL_SEGMENT_SIZE = 3355443200 # 3.125 GiB +DEFAULT_LOCAL_BUFFER_SIZE = 1073741824 # 1.0 GiB + +logger = init_logger(__name__) + + +@dataclass +class MooncakeStoreConfig: + local_hostname: str + metadata_server: str + global_segment_size: int + local_buffer_size: int + protocol: str + device_name: str + master_server_address: str + + @staticmethod + def from_file(file_path: str) -> 'MooncakeStoreConfig': + """Load the config from a JSON file.""" + with open(file_path) as fin: + config = json.load(fin) + return MooncakeStoreConfig( + local_hostname=config.get("local_hostname"), + metadata_server=config.get("metadata_server"), + global_segment_size=config.get("global_segment_size", + DEFAULT_GLOBAL_SEGMENT_SIZE), + local_buffer_size=config.get("local_buffer_size", + DEFAULT_LOCAL_BUFFER_SIZE), + protocol=config.get("protocol", "tcp"), + device_name=config.get("device_name", ""), + master_server_address=config.get("master_server_address"), + ) + + @staticmethod + def load_from_env() -> 'MooncakeStoreConfig': + """Load config from a file specified in the environment variable.""" + config_file_path = os.getenv('MOONCAKE_CONFIG_PATH') + if config_file_path is None: + raise ValueError( + "The environment variable 'MOONCAKE_CONFIG_PATH' is not set.") + return MooncakeStoreConfig.from_file(config_file_path) + + +class MooncakeStore(KVStoreBufferBase): + + def __init__( + self, + config: VllmConfig, + ): + + try: + from mooncake_vllm_adaptor import MooncakeDistributedStore + except ImportError as e: + raise ImportError( + "Please install mooncake by following the instructions at " + "https://github.com/kvcache-ai/Mooncake/blob/main/doc/en/build.md " # noqa: E501 + "to run vLLM with MooncakeConnector.") from e + + try: + self.store = MooncakeDistributedStore() + self.config = MooncakeStoreConfig.load_from_env() + logger.info("Mooncake Configuration loaded successfully.") + + self.store.setup(self.config.local_hostname, + self.config.metadata_server, + self.config.global_segment_size, + self.config.local_buffer_size, + self.config.protocol, self.config.device_name, + self.config.master_server_address) + + except ValueError as e: + logger.error("Configuration loading failed: %s", e) + raise + except Exception as exc: + logger.error( + "An error occurred while loading the configuration: %s", exc) + raise + + def close(self): + # MooncakeDistributedStore will automatically call the destructor, so + # it is unnecessary to close it manually. + pass + + def put( + self, + key: str, + value: Optional[torch.Tensor], + ) -> None: + # A message queue needs to be introduced before making it asynchronous. + if value is not None: + self._put_impl(key, value) + + def get( + self, + key: str, + ) -> Optional[torch.Tensor]: + # A message queue needs to be introduced before making it asynchronous. 
+ value = self._get_impl(key) + return value + + def _put_impl( + self, + key: str, + value: torch.Tensor, + ) -> None: + """Put KVCache to Mooncake Store""" + device_id = value.device.index if value.device.type == 'cuda' else -1 + device_tensor = torch.tensor(device_id, dtype=torch.int32) + value_bytes = safetensors_save({ + "tensor": value, + "device_id": device_tensor + }) + try: + self.store.put(key, value_bytes) + except TypeError as err: + logger.error("Failed to put value into Mooncake Store: %s", err) + raise TypeError("Mooncake Store Put Type Error.") from err + + def _get_impl( + self, + key: str, + ) -> Optional[torch.Tensor]: + """Get KVCache from Mooncake Store""" + try: + data = self.store.get(key) + except TypeError as err: + logger.error("Failed to get value from Mooncake Store: %s", err) + raise TypeError("Mooncake Store Get Type Error.") from err + + if data: + loaded_tensors = safetensors_load(data) + tensor = loaded_tensors["tensor"] + device_id_tensor = loaded_tensors["device_id"] + device_id = int(device_id_tensor.item()) + device = torch.device( + 'cuda', device_id) if device_id >= 0 else torch.device('cpu') + return tensor.to(device) + + return None
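+
+# Illustrative usage sketch (the connector, not user code, normally drives the
+# store; assumes MOONCAKE_CONFIG_PATH points at a valid config and a Mooncake
+# deployment is reachable; names below are hypothetical):
+#
+#     store = MooncakeStore(vllm_config)
+#     store.put("prefix_123_0", kv_tensor)  # synchronous write
+#     cached = store.get("prefix_123_0")    # tensor on its original device, or None
+#     store.close()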