diff --git a/examples/offline_inference/disaggregated-prefill-v1/decode_example.py b/examples/offline_inference/disaggregated-prefill-v1/decode_example.py new file mode 100644 index 00000000..66efbc0c --- /dev/null +++ b/examples/offline_inference/disaggregated-prefill-v1/decode_example.py @@ -0,0 +1,36 @@ +# SPDX-License-Identifier: Apache-2.0 + +from vllm import LLM, SamplingParams +from vllm.config import KVTransferConfig + +# Read prompts from output.txt +prompts = [] +try: + with open("output.txt") as f: + for line in f: + prompts.append(line.strip()) + print(f"Loaded {len(prompts)} prompts from output.txt") +except FileNotFoundError: + print("Error: output.txt file not found") + exit(-1) + +sampling_params = SamplingParams(temperature=0, top_p=0.95, max_tokens=10) + +llm = LLM( + model="meta-llama/Llama-3.2-1B-Instruct", + enforce_eager=True, + gpu_memory_utilization=0.8, + max_num_batched_tokens=64, + max_num_seqs=16, + kv_transfer_config=KVTransferConfig.from_cli( + '{"kv_connector":"SharedStorageConnector","kv_role":"kv_both",' + '"kv_connector_extra_config": {"shared_storage_path": "local_storage"}}' + )) #, max_model_len=2048, max_num_batched_tokens=2048) + +# 1ST generation (prefill instance) +outputs = llm.generate(prompts, sampling_params) + +for output in outputs: + prompt = output.prompt + generated_text = output.outputs[0].text + print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") diff --git a/examples/offline_inference/disaggregated-prefill-v1/prefill_example.py b/examples/offline_inference/disaggregated-prefill-v1/prefill_example.py new file mode 100644 index 00000000..f7cbf655 --- /dev/null +++ b/examples/offline_inference/disaggregated-prefill-v1/prefill_example.py @@ -0,0 +1,43 @@ +# SPDX-License-Identifier: Apache-2.0 + +from vllm import LLM, SamplingParams +from vllm.config import KVTransferConfig + +context = "Hi " * 1000 +context2 = "Hey " * 500 +prompts = [ + context + "Hello, my name is", + context + "The capital of France is", + context2 + "Your name is", + context2 + "The capital of China is", +] + +sampling_params = SamplingParams(temperature=0, top_p=0.95, max_tokens=1) + +llm = LLM(model="meta-llama/Llama-3.2-1B-Instruct", + enforce_eager=True, + gpu_memory_utilization=0.8, + kv_transfer_config=KVTransferConfig.from_cli( + '{"kv_connector":"SharedStorageConnector","kv_role":"kv_both", ' + '"kv_connector_extra_config": ' + '{"shared_storage_path": "local_storage"}}') + ) #, max_model_len=2048, max_num_batched_tokens=2048) + +# 1ST generation (prefill instance) +outputs = llm.generate( + prompts, + sampling_params, +) + +new_prompts = [] +for output in outputs: + prompt = output.prompt + generated_text = output.outputs[0].text + new_prompts.append(prompt + generated_text) + print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") + +# Write new_prompts to output.txt +with open("output.txt", "w") as f: + for prompt in new_prompts: + f.write(prompt + "\n") +print(f"Saved {len(new_prompts)} prompts to output.txt") diff --git a/examples/offline_inference/disaggregated-prefill-v1/run.sh b/examples/offline_inference/disaggregated-prefill-v1/run.sh new file mode 100644 index 00000000..0ebf45a1 --- /dev/null +++ b/examples/offline_inference/disaggregated-prefill-v1/run.sh @@ -0,0 +1,5 @@ +rm -rf local_storage/ +rm output.txt + +VLLM_ENABLE_V1_MULTIPROCESSING=0 CUDA_VISIBLE_DEVICES=0 python3 prefill_example.py +VLLM_ENABLE_V1_MULTIPROCESSING=0 CUDA_VISIBLE_DEVICES=0 python3 decode_example.py diff --git a/tests/v1/core/test_scheduler.py 
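For reference, the JSON string handed to KVTransferConfig.from_cli in both example scripts decodes to a plain dict; a minimal standalone sketch, using only the field names that appear in the examples above:

import json

kv_transfer_json = (
    '{"kv_connector":"SharedStorageConnector","kv_role":"kv_both",'
    '"kv_connector_extra_config": {"shared_storage_path": "local_storage"}}')

config = json.loads(kv_transfer_json)
assert config["kv_connector"] == "SharedStorageConnector"
assert config["kv_role"] == "kv_both"
# The connector persists KV blocks under this directory, which is why run.sh
# wipes local_storage/ (and output.txt) before launching the two scripts.
assert config["kv_connector_extra_config"]["shared_storage_path"] == "local_storage"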
b/tests/v1/core/test_scheduler.py index bc17ca32..691ca59b 100644 --- a/tests/v1/core/test_scheduler.py +++ b/tests/v1/core/test_scheduler.py @@ -1,10 +1,12 @@ # SPDX-License-Identifier: Apache-2.0 from typing import Optional +from unittest.mock import Mock import pytest import torch -from vllm.config import CacheConfig, ModelConfig, SchedulerConfig, VllmConfig +from vllm.config import (CacheConfig, KVTransferConfig, ModelConfig, + SchedulerConfig, VllmConfig) from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange from vllm.sampling_params import SamplingParams from vllm.v1.core.sched.output import SchedulerOutput @@ -25,6 +27,9 @@ def create_scheduler( enable_prefix_caching: Optional[bool] = None, long_prefill_token_threshold: int = 0, disable_chunked_mm_input: bool = False, + use_kv_connector: bool = False, + num_blocks: int = 10000, + block_size: int = 16, ) -> Scheduler: '''Create scheduler under test. @@ -60,31 +65,36 @@ def create_scheduler( 'enable_prefix_caching': enable_prefix_caching }) cache_config = CacheConfig( - block_size=16, + block_size=block_size, gpu_memory_utilization=0.9, swap_space=0, cache_dtype="auto", **kwargs_cache, ) + kv_transfer_config = KVTransferConfig( + kv_connector="SharedStorageConnector", + kv_role="kv_both", + kv_connector_extra_config={"shared_storage_path": "local_storage"}, + ) if use_kv_connector else None + vllm_config = VllmConfig( scheduler_config=scheduler_config, model_config=model_config, cache_config=cache_config, + kv_transfer_config=kv_transfer_config, ) kv_cache_config = KVCacheConfig( - num_blocks=10000, # A large number of blocks to hold all requests + num_blocks=num_blocks, # A large number of blocks to hold all requests tensors={}, kv_cache_groups=[ KVCacheGroupSpec(['layer'], - FullAttentionSpec(16, 1, 1, torch.float32, False)) + FullAttentionSpec(block_size, 1, 1, torch.float32, + False)) ], ) - cache_config.num_gpu_blocks = 10000 + cache_config.num_gpu_blocks = num_blocks return Scheduler( - scheduler_config, - model_config, - cache_config, - lora_config=None, + vllm_config=vllm_config, kv_cache_config=kv_cache_config, log_stats=True, structured_output_manager=StructuredOutputManager(vllm_config), @@ -761,3 +771,390 @@ def test_schedule_spec_decoding_stats(spec_tokens, output_tokens, expected): stats = scheduler_stats.spec_decoding_stats assert stats.num_draft_tokens == expected[0] assert stats.num_accepted_tokens == expected[1] + + +def _assert_right_scheduler_output( + output: SchedulerOutput, + num_requests: int, + expected_num_scheduled_tokens: int, +): + """Check if SchedulerOutput is correct after remote KV cache hit.""" + + # We should inject the kv_connector_metadata. + assert len(output.kv_connector_metadata.requests) == num_requests + + # Only num_tokens - matched_num_new_tokens should be scheduled. + for _, num_scheduled_tokens in output.num_scheduled_tokens.items(): + assert num_scheduled_tokens == expected_num_scheduled_tokens + + +def _assert_right_kv_cache_manager( + scheduler: Scheduler, + req_ids: list[str], + num_tokens: int, + block_size: int, + num_requests: int, + num_total_blocks: int, +): + """Check whether KVCacheManager is correct after allocate.""" + + # Make sure the request stats are right. 
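The KV-connector tests below rely on a simple mocking pattern: the scheduler-side connector's get_num_new_matched_tokens() is replaced so every request appears to have a fixed number of tokens available in the remote KV cache. A minimal sketch of that pattern outside the Scheduler:

from unittest.mock import Mock

BLOCK_SIZE = 16
NUM_MATCHED_NEW_TOKENS = BLOCK_SIZE * 2

connector = Mock(name="scheduler_connector")   # stands in for scheduler.connector
connector.get_num_new_matched_tokens = Mock(name="method")
connector.get_num_new_matched_tokens.return_value = NUM_MATCHED_NEW_TOKENS

# Any (request, num_computed_tokens) pair now reports a 32-token external hit.
assert connector.get_num_new_matched_tokens(object(), 0) == NUM_MATCHED_NEW_TOKENS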
+ EXPECTED_ACTUAL_BLOCKS = num_tokens // block_size + EXPECTED_TOTAL_BLOCKS = (EXPECTED_ACTUAL_BLOCKS + + scheduler.kv_cache_manager.num_preallocate_blocks) + for req_id in req_ids: + blocks = scheduler.kv_cache_manager.req_to_blocks[req_id] + hashes = scheduler.kv_cache_manager.req_to_block_hashes[req_id] + assert (scheduler.kv_cache_manager.num_cached_block[req_id] == + EXPECTED_ACTUAL_BLOCKS) + assert len(blocks) == EXPECTED_TOTAL_BLOCKS + assert len(hashes) == EXPECTED_ACTUAL_BLOCKS + + # Make sure we actually touched all the blocks. + BLOCKS_PER_REQ = (num_tokens / block_size + + scheduler.kv_cache_manager.num_preallocate_blocks) + assert (scheduler.kv_cache_manager.block_pool.get_num_free_blocks() == + num_total_blocks - num_requests * BLOCKS_PER_REQ) + + +def _step_until_done( + scheduler: Scheduler, + output: SchedulerOutput, + model_runner_output: ModelRunnerOutput, +): + """Loop over schedule(), update_from_output() until finished.""" + + all_finished = False + _ = scheduler.update_from_output(output, model_runner_output) + while not all_finished: + # Schedule + a few iterations until stopping. + output = scheduler.schedule() + assert len(scheduler.running) + for _, num_scheduled_tokens in output.num_scheduled_tokens.items(): + # We should be in the decode phase now. + assert num_scheduled_tokens == 1 + assert len(output.kv_connector_metadata.requests) == 0 + ecos = scheduler.update_from_output(output, model_runner_output) + all_done = True + for eco in ecos.outputs: + if eco.finish_reason is None: + all_done = False + all_finished = all_done + + +def test_kv_connector_basic(): + """ + Test whether Scheduler with KVConnector schedules tokens, allocates + memory, and cleans up requests as expected under normal operation. + """ + + # Setup Scheduler. + scheduler = create_scheduler( + enable_prefix_caching=True, + use_kv_connector=True, + ) + NUM_TOTAL_BLOCKS = ( + scheduler.kv_cache_manager.block_pool.get_num_free_blocks()) + BLOCK_SIZE = scheduler.cache_config.block_size + + # Mock External Cache Hit. + NUM_MATCHED_NEW_TOKENS = BLOCK_SIZE * 2 + scheduler.connector.get_num_new_matched_tokens = Mock(name="method") + scheduler.connector.get_num_new_matched_tokens.return_value = ( + NUM_MATCHED_NEW_TOKENS) + + ###################################################### + # FIRST SET OF REQUESTS - External Hit Only + NUM_REQUESTS = 2 + NUM_TOKENS = NUM_MATCHED_NEW_TOKENS * 2 + MAX_TOKENS = 3 + requests = create_requests(num_requests=NUM_REQUESTS, + num_tokens=NUM_TOKENS, + max_tokens=MAX_TOKENS) + req_ids = [] + req_to_index = {} + for i, request in enumerate(requests): + scheduler.add_request(request) + req_ids.append(request.request_id) + req_to_index[request.request_id] = i + + MODEL_RUNNER_OUTPUT = ModelRunnerOutput( + req_ids=req_ids, + req_id_to_index=req_to_index, + sampled_token_ids=[[1000]] * len(req_ids), + spec_token_ids=None, + logprobs=None, + prompt_logprobs_dict={}, + ) + + # Ensure ScheduleOutput is correct. + output = scheduler.schedule() + _assert_right_scheduler_output( + output=output, + num_requests=NUM_REQUESTS, + # Just the incremental tokens should be scheduled. + expected_num_scheduled_tokens=NUM_TOKENS - NUM_MATCHED_NEW_TOKENS, + ) + + # Ensure KVCacheManager is correct. + _assert_right_kv_cache_manager(scheduler, req_ids, NUM_TOKENS, BLOCK_SIZE, + NUM_REQUESTS, NUM_TOTAL_BLOCKS) + + # Continue Generation until done. + _step_until_done(scheduler, output, MODEL_RUNNER_OUTPUT) + _ = scheduler.schedule() + # Confirm we clean up the memory properly. 
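The helper above reduces to straightforward block arithmetic; a standalone restatement with the constants used by test_kv_connector_basic (block size 16, zero preallocated blocks assumed for this sketch; the test reads the real value off the KVCacheManager):

block_size = 16
num_preallocate_blocks = 0            # assumed for the sketch
num_tokens = block_size * 4           # NUM_TOKENS in test_kv_connector_basic
num_requests = 2
num_total_blocks = 10000              # default num_blocks in create_scheduler

expected_actual_blocks = num_tokens // block_size
expected_total_blocks = expected_actual_blocks + num_preallocate_blocks
blocks_per_req = num_tokens / block_size + num_preallocate_blocks
expected_free = num_total_blocks - num_requests * blocks_per_req

assert (expected_actual_blocks, expected_total_blocks) == (4, 4)
assert expected_free == 9992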
+ assert scheduler.kv_cache_manager.block_pool.get_num_free_blocks() \ + == NUM_TOTAL_BLOCKS + + ###################################################### + # SECOND SET OF REQUESTS - Local And External Hit + NUM_TOKENS_PREFIX = NUM_TOKENS + # We will get a local prefix cache hit for the first + # NUM_TOKENS_PREFIX tokens since they are used above. + NUM_TOKENS = NUM_TOKENS_PREFIX * 2 + requests = create_requests(num_requests=NUM_REQUESTS, + num_tokens=NUM_TOKENS, + max_tokens=MAX_TOKENS) + req_ids = [] + req_to_index = {} + for i, request in enumerate(requests): + scheduler.add_request(request) + req_ids.append(request.request_id) + req_to_index[request.request_id] = i + + MODEL_RUNNER_OUTPUT = ModelRunnerOutput( + req_ids=req_ids, + req_id_to_index=req_to_index, + sampled_token_ids=[[1000]] * len(req_ids), + spec_token_ids=None, + logprobs=None, + prompt_logprobs_dict={}, + ) + + # We should get a local cache hit of NUM_TOKENS_PREFIX and + # a remote KV cache hit of NUM_MATCHED_NEW_TOKENS. + output = scheduler.schedule() + _assert_right_scheduler_output( + output=output, + num_requests=NUM_REQUESTS, + # Just the incremental tokens after local + remote cache hit. + expected_num_scheduled_tokens=(NUM_TOKENS - NUM_TOKENS_PREFIX - + NUM_MATCHED_NEW_TOKENS)) + + # Ensure KVCacheManager is correct. + _assert_right_kv_cache_manager(scheduler, req_ids, NUM_TOKENS, BLOCK_SIZE, + NUM_REQUESTS, NUM_TOTAL_BLOCKS) + + # Continue Generation until done. + _step_until_done(scheduler, output, MODEL_RUNNER_OUTPUT) + _ = scheduler.schedule() + # Confirm we clean up the memory properly. + assert scheduler.kv_cache_manager.block_pool.get_num_free_blocks() \ + == NUM_TOTAL_BLOCKS + + +def test_kv_connector_unable_to_allocate(): + """ + Test whether scheduler with KVConnector is able to handle + unable to allocate (run out of blocks in allocate_slots(). + """ + + # Setup Scheduler With Mock External Cache Hit. + BLOCK_SIZE = 4 + NUM_BLOCKS = 10 + scheduler = create_scheduler( + enable_prefix_caching=True, + use_kv_connector=True, + block_size=BLOCK_SIZE, + num_blocks=NUM_BLOCKS, + ) + NUM_MATCHED_NEW_TOKENS = BLOCK_SIZE * 2 + scheduler.connector.get_num_new_matched_tokens = Mock(name="method") + scheduler.connector.get_num_new_matched_tokens.return_value = ( + NUM_MATCHED_NEW_TOKENS) + + # Create two requests. The second request will not be able to + # allocate slots because it will not have enough blocks. + NUM_REQUESTS = 2 + NUM_TOKENS = (NUM_BLOCKS // 2 + 1) * BLOCK_SIZE + MAX_TOKENS = 2 + requests = create_requests(num_requests=NUM_REQUESTS, + num_tokens=NUM_TOKENS, + max_tokens=MAX_TOKENS) + req_ids = [] + req_to_index = {} + for i, request in enumerate(requests): + scheduler.add_request(request) + req_ids.append(request.request_id) + req_to_index[request.request_id] = i + + MODEL_RUNNER_OUTPUT = ModelRunnerOutput( + req_ids=req_ids, + req_id_to_index=req_to_index, + sampled_token_ids=[[1000]] * len(req_ids), + spec_token_ids=None, + logprobs=None, + prompt_logprobs_dict={}, + ) + + # Just one request should be running. + output = scheduler.schedule() + _assert_right_scheduler_output(output, + num_requests=1, + expected_num_scheduled_tokens=NUM_TOKENS - + NUM_MATCHED_NEW_TOKENS) + assert len(scheduler.running) == 1 + assert len(scheduler.waiting) == 1 + + # All memory should be freed, with one request waiting. 
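The expected_num_scheduled_tokens in the second batch comes from subtracting both cache hits from the prompt length; with the constants used in test_kv_connector_basic:

BLOCK_SIZE = 16
NUM_MATCHED_NEW_TOKENS = BLOCK_SIZE * 2          # mocked remote (external) hit
NUM_TOKENS_PREFIX = NUM_MATCHED_NEW_TOKENS * 2   # prompt length of the first batch
NUM_TOKENS = NUM_TOKENS_PREFIX * 2               # prompt length of the second batch

# The locally cached prefix and the remote hit are both skipped; only the
# remaining suffix is scheduled for prefill.
expected_scheduled = NUM_TOKENS - NUM_TOKENS_PREFIX - NUM_MATCHED_NEW_TOKENS
assert expected_scheduled == 32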
+ _step_until_done(scheduler, output, MODEL_RUNNER_OUTPUT) + assert scheduler.kv_cache_manager.block_pool.get_num_free_blocks() \ + == NUM_BLOCKS - 1 + assert len(scheduler.running) == 0 + assert len(scheduler.waiting) == 1 + + # Just one request should be running. + output = scheduler.schedule() + _assert_right_scheduler_output(output, + num_requests=1, + expected_num_scheduled_tokens=NUM_TOKENS - + NUM_MATCHED_NEW_TOKENS) + assert len(scheduler.running) == 1 + assert len(scheduler.waiting) == 0 + + # All memory should be freed, with no requests waiting / running. + _step_until_done(scheduler, output, MODEL_RUNNER_OUTPUT) + assert scheduler.kv_cache_manager.block_pool.get_num_free_blocks() \ + == NUM_BLOCKS - 1 + assert len(scheduler.running) == 0 + assert len(scheduler.waiting) == 0 + + +def test_kv_connector_handles_preemption(): + """ + Test whether scheduler with KVConnector is able to handle + unable to allocate (run out of blocks in allocate_slots(). + """ + + # Setup Scheduler With Mock External Cache Hit. + BLOCK_SIZE = 2 + # NOTE: there is 1 null block, so this is 6 blocks. + NUM_BLOCKS = 7 + scheduler = create_scheduler( + enable_prefix_caching=True, + use_kv_connector=True, + block_size=BLOCK_SIZE, + num_blocks=NUM_BLOCKS, + ) + scheduler.kv_cache_manager.num_preallocate_blocks = 0 + + NUM_MATCHED_NEW_TOKENS = BLOCK_SIZE + scheduler.connector.get_num_new_matched_tokens = Mock(name="method") + scheduler.connector.get_num_new_matched_tokens.return_value = ( + NUM_MATCHED_NEW_TOKENS) + + # Create two requests. + # Both can be scheduled at first, but the second request + # will be preempted and re-scheduled. + NUM_REQUESTS = 2 + NUM_TOKENS = BLOCK_SIZE * 2 + 1 + MAX_TOKENS = BLOCK_SIZE * 2 + requests = create_requests(num_requests=NUM_REQUESTS, + num_tokens=NUM_TOKENS, + max_tokens=MAX_TOKENS) + req_ids = [] + req_to_index = {} + for i, request in enumerate(requests): + scheduler.add_request(request) + req_ids.append(request.request_id) + req_to_index[request.request_id] = i + + MODEL_RUNNER_OUTPUT = ModelRunnerOutput( + req_ids=req_ids, + req_id_to_index=req_to_index, + sampled_token_ids=[[1000]] * len(req_ids), + spec_token_ids=None, + logprobs=None, + prompt_logprobs_dict={}, + ) + + # All can be scheduled - 1st token. + output = scheduler.schedule() + _assert_right_scheduler_output( + output, + # 2 remote kv cache hits. + num_requests=2, + expected_num_scheduled_tokens=NUM_TOKENS - NUM_MATCHED_NEW_TOKENS) + assert len(scheduler.running) == 2 + _ = scheduler.update_from_output(output, MODEL_RUNNER_OUTPUT) + + # All can be scheduled - 2nd token. + output = scheduler.schedule() + _assert_right_scheduler_output( + output, + # no connector_metadata + num_requests=0, + expected_num_scheduled_tokens=1) + assert len(scheduler.running) == 2 + _ = scheduler.update_from_output(output, MODEL_RUNNER_OUTPUT) + + # This will generate a new block and cause a preemption - 3rd token. + output = scheduler.schedule() + _assert_right_scheduler_output( + output, + # no connector_metadata + num_requests=0, + expected_num_scheduled_tokens=1) + assert len(scheduler.running) == 1 + assert len(scheduler.waiting) == 1 + _ = scheduler.update_from_output(output, MODEL_RUNNER_OUTPUT) + assert len(scheduler.running) == 1 + assert len(scheduler.waiting) == 1 + + # Only 1 can be scheduled - 4th (and last token). 
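Why only one of the two requests fits at a time in test_kv_connector_unable_to_allocate, restated as a quick calculation (the pool reserves one null block, as the preemption test below notes):

BLOCK_SIZE = 4
NUM_BLOCKS = 10
NUM_TOKENS = (NUM_BLOCKS // 2 + 1) * BLOCK_SIZE   # 24 tokens per request

blocks_per_request = NUM_TOKENS // BLOCK_SIZE      # 6 blocks
usable_blocks = NUM_BLOCKS - 1                     # the null block is never handed out

assert blocks_per_request * 2 > usable_blocks      # so the second request must wait
assert blocks_per_request <= usable_blocks         # but a single request fits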
+ output = scheduler.schedule() + _assert_right_scheduler_output( + output, + # no connector_metadata + num_requests=0, + expected_num_scheduled_tokens=1) + assert len(scheduler.waiting) == 1 + assert len(scheduler.running) == 1 + _ = scheduler.update_from_output(output, MODEL_RUNNER_OUTPUT) + assert len(scheduler.running) == 0 + assert len(scheduler.waiting) == 1 + # All memory should be freed since nothing is running. + assert scheduler.kv_cache_manager.block_pool.get_num_free_blocks() \ + == NUM_BLOCKS - 1 + + # Restarts the preempted request - generate 3rd token. + # This will have a local and remote cache hit. + output = scheduler.schedule() + _assert_right_scheduler_output( + output, + # 1 remote kv_cache hit! + num_requests=1, + # Only 1 block was preempted and there is a single + # remote hit. So only single new token is scheduled. + expected_num_scheduled_tokens=1, + ) + assert len(scheduler.running) == 1 + assert len(scheduler.waiting) == 0 + _ = scheduler.update_from_output(output, MODEL_RUNNER_OUTPUT) + assert len(scheduler.running) == 1 + assert len(scheduler.waiting) == 0 + + # Only 1 can be scheduled - 4th (and last token). + output = scheduler.schedule() + _assert_right_scheduler_output( + output, + # no connector_metadata + num_requests=0, + expected_num_scheduled_tokens=1) + assert len(scheduler.running) == 1 + _ = scheduler.update_from_output(output, MODEL_RUNNER_OUTPUT) + assert len(scheduler.running) == 0 + # All memory should be freed since nothing is running. + assert scheduler.kv_cache_manager.block_pool.get_num_free_blocks() \ + == NUM_BLOCKS - 1 diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py index dbf4723e..68452f4c 100644 --- a/vllm/attention/layer.py +++ b/vllm/attention/layer.py @@ -10,6 +10,9 @@ import vllm.envs as envs from vllm.attention import AttentionType from vllm.attention.selector import backend_name_to_enum, get_attn_backend from vllm.config import CacheConfig, get_current_vllm_config +from vllm.distributed.kv_transfer import (get_kv_transfer_group, + has_kv_transfer_group, + is_v1_kv_transfer_group) from vllm.forward_context import ForwardContext, get_forward_context from vllm.model_executor.layers.linear import UnquantizedLinearMethod from vllm.model_executor.layers.quantization.base_config import ( @@ -329,17 +332,54 @@ class MultiHeadAttention(nn.Module): return out.reshape(bsz, q_len, -1) +def wait_for_kv_layer_from_connector(layer_name: str): + if not has_kv_transfer_group() or not is_v1_kv_transfer_group(): + return + + connector = get_kv_transfer_group() + + forward_context: ForwardContext = get_forward_context() + attn_metadata = forward_context.attn_metadata + if attn_metadata is None: + return + + connector.wait_for_layer_load(layer_name) + + +def maybe_save_kv_layer_to_connector( + layer_name: str, + kv_cache_layer: List[torch.Tensor], +): + if not has_kv_transfer_group() or not is_v1_kv_transfer_group(): + return + + connector = get_kv_transfer_group() + + forward_context: ForwardContext = get_forward_context() + attn_metadata = forward_context.attn_metadata + if attn_metadata is None: + return + + connector.save_kv_layer(layer_name, kv_cache_layer, attn_metadata) + + def unified_attention( query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, layer_name: str, ) -> torch.Tensor: + wait_for_kv_layer_from_connector(layer_name) + forward_context: ForwardContext = get_forward_context() attn_metadata = forward_context.attn_metadata self = forward_context.no_compile_layers[layer_name] kv_cache = 
self.kv_cache[forward_context.virtual_engine] - return self.impl.forward(self, query, key, value, kv_cache, attn_metadata) + output = self.impl.forward(self, query, key, value, kv_cache, + attn_metadata) + + maybe_save_kv_layer_to_connector(layer_name, kv_cache) + return output def unified_attention_fake( @@ -367,6 +407,7 @@ def unified_attention_with_output( output: torch.Tensor, layer_name: str, ) -> None: + wait_for_kv_layer_from_connector(layer_name) forward_context: ForwardContext = get_forward_context() attn_metadata = forward_context.attn_metadata self = forward_context.no_compile_layers[layer_name] @@ -379,6 +420,8 @@ def unified_attention_with_output( attn_metadata, output=output) + maybe_save_kv_layer_to_connector(layer_name, kv_cache) + def unified_attention_with_output_fake( query: torch.Tensor, diff --git a/vllm/distributed/kv_transfer/__init__.py b/vllm/distributed/kv_transfer/__init__.py index e69de29b..ec07c6fe 100644 --- a/vllm/distributed/kv_transfer/__init__.py +++ b/vllm/distributed/kv_transfer/__init__.py @@ -0,0 +1,11 @@ +# SPDX-License-Identifier: Apache-2.0 + +from vllm.distributed.kv_transfer.kv_transfer_state import ( + ensure_kv_transfer_initialized, get_kv_transfer_group, + has_kv_transfer_group, is_v1_kv_transfer_group) + +__all__ = [ + "get_kv_transfer_group", "has_kv_transfer_group", + "is_v1_kv_transfer_group", "ensure_kv_transfer_initialized", + "KVConnectorBaseType" +] diff --git a/vllm/distributed/kv_transfer/kv_connector/base.py b/vllm/distributed/kv_transfer/kv_connector/base.py index 57c764b4..0d1a3d40 100644 --- a/vllm/distributed/kv_transfer/kv_connector/base.py +++ b/vllm/distributed/kv_transfer/kv_connector/base.py @@ -12,6 +12,7 @@ from typing import TYPE_CHECKING, List, Tuple, Union import torch +from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorBase_V1 from vllm.sequence import IntermediateTensors if TYPE_CHECKING: @@ -121,3 +122,6 @@ class KVConnectorBase(ABC): """ raise NotImplementedError + + +KVConnectorBaseType = Union[KVConnectorBase, KVConnectorBase_V1] diff --git a/vllm/distributed/kv_transfer/kv_connector/factory.py b/vllm/distributed/kv_transfer/kv_connector/factory.py index e37ce6dc..665ea2f5 100644 --- a/vllm/distributed/kv_transfer/kv_connector/factory.py +++ b/vllm/distributed/kv_transfer/kv_connector/factory.py @@ -3,14 +3,22 @@ import importlib from typing import TYPE_CHECKING, Callable, Dict, Type +import vllm.envs as envs +from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBaseType +from vllm.distributed.kv_transfer.kv_connector.v1 import (KVConnectorBase_V1, + KVConnectorRole) +from vllm.logger import init_logger + from .base import KVConnectorBase if TYPE_CHECKING: from vllm.config import VllmConfig +logger = init_logger(__name__) + class KVConnectorFactory: - _registry: Dict[str, Callable[[], Type[KVConnectorBase]]] = {} + _registry: Dict[str, Callable[[], Type[KVConnectorBaseType]]] = {} @classmethod def register_connector(cls, name: str, module_path: str, @@ -19,22 +27,51 @@ class KVConnectorFactory: if name in cls._registry: raise ValueError(f"Connector '{name}' is already registered.") - def loader() -> Type[KVConnectorBase]: + def loader() -> Type[KVConnectorBaseType]: module = importlib.import_module(module_path) return getattr(module, class_name) cls._registry[name] = loader @classmethod - def create_connector(cls, rank: int, local_rank: int, - config: "VllmConfig") -> KVConnectorBase: + def create_connector_v0(cls, rank: int, local_rank: int, + config: "VllmConfig") -> 
KVConnectorBase: + if envs.VLLM_USE_V1: + raise ValueError("Attempting to initialize a V0 Connector, " + f"but found {envs.VLLM_USE_V1=}") + connector_name = config.kv_transfer_config.kv_connector if connector_name not in cls._registry: raise ValueError(f"Unsupported connector type: {connector_name}") connector_cls = cls._registry[connector_name]() + assert issubclass(connector_cls, KVConnectorBase) return connector_cls(rank, local_rank, config) + @classmethod + def create_connector_v1( + cls, + config: "VllmConfig", + role: KVConnectorRole, + ) -> KVConnectorBase_V1: + if not envs.VLLM_USE_V1: + raise ValueError("Attempting to initialize a V1 Connector, " + f"but found {envs.VLLM_USE_V1=}") + + connector_name = config.kv_transfer_config.kv_connector + connector_cls = cls._registry[connector_name]() + assert issubclass(connector_cls, KVConnectorBase_V1) + logger.info("Creating v1 connector with name: %s", connector_name) + # NOTE(Kuntai): v1 connector is explicitly separated into two roles. + # Scheduler connector: + # - Co-locate with scheduler process + # - Should only be used inside the Scheduler class + # Worker connector: + # - Co-locate with worker process + # - Should only be used inside the forward context & attention layer + # We build separately to enforce strict separation + return connector_cls(config, role) + # Register various connectors here. # The registration should not be done in each individual file, as we want to @@ -57,4 +94,9 @@ KVConnectorFactory.register_connector( KVConnectorFactory.register_connector( "MooncakeStoreConnector", "vllm.distributed.kv_transfer.kv_connector.mooncake_store_connector", - "MooncakeStoreConnector") \ No newline at end of file + "MooncakeStoreConnector") + +KVConnectorFactory.register_connector( + "SharedStorageConnector", + "vllm.distributed.kv_transfer.kv_connector.v1.shared_storage_connector", + "SharedStorageConnector") diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/__init__.py b/vllm/distributed/kv_transfer/kv_connector/v1/__init__.py new file mode 100644 index 00000000..a017b140 --- /dev/null +++ b/vllm/distributed/kv_transfer/kv_connector/v1/__init__.py @@ -0,0 +1,8 @@ +# SPDX-License-Identifier: Apache-2.0 +from vllm.distributed.kv_transfer.kv_connector.v1.base import ( + KVConnectorBase_V1, KVConnectorRole) + +__all__ = [ + "KVConnectorRole", + "KVConnectorBase_V1", +] diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/base.py b/vllm/distributed/kv_transfer/kv_connector/v1/base.py new file mode 100644 index 00000000..95967d2c --- /dev/null +++ b/vllm/distributed/kv_transfer/kv_connector/v1/base.py @@ -0,0 +1,209 @@ +# SPDX-License-Identifier: Apache-2.0 +""" +KVConnectorBase_V1 Class for Distributed KV Cache & Hidden State +communication in vLLM v1 + +The class provides the following primitives: + Scheduler-side: runs in the scheduler, binds metadata, which + is used by the worker-side to load/save KV cache. + get_num_new_matched_tokens() - get number of new tokens + that exist in the remote KV cache + update_state_after_alloc() - update KVConnector state after + temporary buffer alloc by the CacheManager. + + Worker-side: runs in each worker, loads/saves KV cache to/from + the Connector based on the metadata. 
+ start_load_kv() - starts loading all KVs (maybe async) + wait_for_layer_load() - blocks until layer i load is done + + save_kv_layer() - starts saving KV for layer i (maybe async) + wait_for_save() - blocks until all saves are done +""" + +import enum +from abc import ABC, abstractmethod +from dataclasses import dataclass +from typing import TYPE_CHECKING + +import torch + +from vllm.logger import init_logger +from vllm.v1.core.sched.output import SchedulerOutput + +if TYPE_CHECKING: + from vllm.attention.backends.abstract import AttentionMetadata + from vllm.config import VllmConfig + from vllm.forward_context import ForwardContext + from vllm.v1.request import Request + +logger = init_logger(__name__) + + +class KVConnectorRole(enum.Enum): + # Connector running in the scheduler process + SCHEDULER = 0 + + # Connector running in the worker process + WORKER = 1 + + +@dataclass +class KVConnectorMetadata: + pass + + +class KVConnectorBase_V1(ABC): + + def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole): + logger.warning( + "Initializing KVConnectorBase_V1. This API is experimental and " + "subject to change in the future as we iterate the design.") + self._connector_metadata = KVConnectorMetadata() + self._vllm_config = vllm_config + self._role = role + + @property + def role(self) -> KVConnectorRole: + return self._role + + def bind_connector_metadata( + self, connector_metadata: KVConnectorMetadata) -> None: + """Set the connector metadata from the scheduler. + + This function should be called by the model runner every time + before the model execution. The metadata will be used for runtime + KV cache loading and saving. + + Args: + connector_metadata (dict): the connector metadata. + """ + self._connector_metadata = connector_metadata + + def clear_connector_metadata(self) -> None: + """Clear the connector metadata. + + This function should be called by the model runner every time + after the model execution. + """ + self._connector_metadata = KVConnectorMetadata() + + def _get_connector_metadata(self) -> KVConnectorMetadata: + """Get the connector metadata. + + This function should only be called inside the connector. + + Returns: + ConnectorMetadata: the connector metadata. + """ + return self._connector_metadata + + # ============================== + # Worker-side methods + # ============================== + + @abstractmethod + def start_load_kv(self, forward_context: "ForwardContext", + **kwargs) -> None: + """ + Start loading the KV cache from the connector to vLLM's paged + KV buffer. This is called from the forward context before the + forward pass to enable async loading during model execution. + + Args: + forward_context (ForwardContext): the forward context. + **kwargs: additional arguments for the load operation + + Note: + The number of elements in kv_caches and layer_names should be + the same. + + """ + pass + + @abstractmethod + def wait_for_layer_load(self, layer_name: str) -> None: + """ + Block until the KV for a specific layer is loaded into vLLM's + paged buffer. This is called from within attention layer to ensure + async copying from start_load_kv is complete. + + This interface will be useful for layer-by-layer pipelining. + + Args: + layer_name: the name of that layer + """ + pass + + @abstractmethod + def save_kv_layer(self, layer_name: str, kv_layer: torch.Tensor, + attn_metadata: "AttentionMetadata", **kwargs) -> None: + """ + Start saving a layer of KV cache from vLLM's paged buffer + to the connector. 
This is called from within attention layer to + enable async copying during execution. + + Args: + layer_name (str): the name of the layer. + kv_layer (torch.Tensor): the paged KV buffer of the current + layer in vLLM. + attn_metadata (AttentionMetadata): the attention metadata. + **kwargs: additional arguments for the save operation. + """ + pass + + @abstractmethod + def wait_for_save(self): + """ + Block until all the save operations is done. This is called + as the forward context exits to ensure that the async saving + from save_kv_layer is complete before finishing the forward. + + This prevents overwrites of paged KV buffer before saving done. + """ + pass + + # ============================== + # Scheduler-side methods + # ============================== + @abstractmethod + def get_num_new_matched_tokens( + self, + request: "Request", + num_computed_tokens: int, + ) -> int: + """ + Get number of new tokens that can be loaded from the + external KV cache beyond the num_computed_tokens. + + Args: + request (Request): the request object. + num_computed_tokens (int): the number of locally + computed tokens for this request + + Returns: + the number of tokens that can be loaded from the + external KV cache beyond what is already computed. + """ + pass + + @abstractmethod + def update_state_after_alloc(self, request: "Request", + num_external_tokens: int): + """ + Update KVConnector state after block allocation. + """ + pass + + @abstractmethod + def build_connector_meta( + self, scheduler_output: SchedulerOutput) -> KVConnectorMetadata: + """ + Build the connector metadata for this step. + + This function should NOT modify fields in the scheduler_output. + Also, calling this function will reset the state of the connector. + + Args: + scheduler_output (SchedulerOutput): the scheduler output object. 
+ """ + pass diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/shared_storage_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/shared_storage_connector.py new file mode 100644 index 00000000..1d204078 --- /dev/null +++ b/vllm/distributed/kv_transfer/kv_connector/v1/shared_storage_connector.py @@ -0,0 +1,382 @@ +# SPDX-License-Identifier: Apache-2.0 +import hashlib +import os +from dataclasses import dataclass +from typing import TYPE_CHECKING + +import safetensors +import torch + +from vllm.config import VllmConfig +from vllm.distributed.kv_transfer.kv_connector.v1.base import ( + KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole) +from vllm.logger import init_logger +from vllm.v1.attention.backends.mla.common import MLACommonMetadata +from vllm.v1.core.sched.output import SchedulerOutput + +if TYPE_CHECKING: + from vllm.attention.backends.abstract import AttentionMetadata + from vllm.forward_context import ForwardContext + from vllm.v1.request import Request + +logger = init_logger(__name__) + + +@dataclass +class ReqMeta: + # Request tokens + token_ids: torch.Tensor + # Slot mappings, should have the same length as token_ids + slot_mapping: torch.Tensor + # Is store or load + is_store: bool + + @staticmethod + def make_meta(token_ids: list[int], block_ids: list[int], block_size: int, + is_store: bool) -> "ReqMeta": + valid_num_tokens = align_to_block_size(len(token_ids), block_size) + token_ids_tensor = torch.tensor(token_ids)[:valid_num_tokens] + block_ids_tensor = torch.tensor(block_ids) + num_blocks = block_ids_tensor.shape[0] + block_offsets = torch.arange(0, block_size) + slot_mapping = block_offsets.reshape((1, block_size)) + \ + block_ids_tensor.reshape((num_blocks, 1)) * block_size + slot_mapping = slot_mapping.flatten()[:valid_num_tokens] + return ReqMeta( + token_ids=token_ids_tensor, + slot_mapping=slot_mapping, + is_store=is_store, + ) + + +@dataclass +class SharedStorageConnectorMetadata(KVConnectorMetadata): + requests: list[ReqMeta] + + def __init__(self): + self.requests = [] + + def add_request( + self, + token_ids: list[int], + block_ids: list[int], + block_size: int, + is_store: bool, + ) -> None: + self.requests.append( + ReqMeta.make_meta(token_ids, block_ids, block_size, is_store)) + + +class SharedStorageConnector(KVConnectorBase_V1): + # NOTE: This is Simple debug implementation of the KV connector. + # It save / load the KV cache to / from the disk. + # It does extra work which will overwrite the existing prefix-cache in GPU + # - to remove the overhead, need to add some "mask" in the ReqMeta class + + def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole): + super().__init__(vllm_config=vllm_config, role=role) + self._block_size = vllm_config.cache_config.block_size + self._requests_need_load: dict[str, Request] = {} + transfer_config = vllm_config.kv_transfer_config + self._storage_path = transfer_config.get_from_extra_config( + "shared_storage_path", "/tmp") + logger.info(vllm_config.kv_transfer_config) + logger.info("Shared storage path is %s", self._storage_path) + + def start_load_kv(self, forward_context: "ForwardContext", + **kwargs) -> None: + """Start loading the KV cache from the connector buffer to vLLM's + paged KV buffer. + + Args: + forward_context (ForwardContext): the forward context. + **kwargs: additional arguments for the load operation + + Note: + The number of elements in kv_caches and layer_names should be + the same. 
+ """ + attn_metadata = forward_context.attn_metadata + + def inject_kv_into_layer( + dst_kv_cache_layer: torch.Tensor, + src_kv_cache: torch.Tensor, + slot_mapping: torch.Tensor, + ) -> None: + """Inject the KV cache into the layer. + + Args: + dst_kv_cache_layer (torch.Tensor): the destination KV cache + layer. In shape [2, num_pages, page_size, xxx] if not + using MLA, [num_pages, page_size, xxx] otherwise. + src_kv_cache (torch.Tensor): the source KV cache. In shape + [2, num_tokens, xxx] if not using MLA, [num_tokens, xxx] + otherwise. + slot_mapping (torch.Tensor): the slot mapping. In shape + [num_tokens]. + """ + dst_kv_cache_layer_shape = dst_kv_cache_layer.shape + if isinstance(attn_metadata, MLACommonMetadata): + num_pages = dst_kv_cache_layer_shape[0] + page_size = dst_kv_cache_layer_shape[1] + dst_kv_cache_layer = dst_kv_cache_layer.reshape( + num_pages * page_size, -1) + dst_kv_cache_layer[slot_mapping, ...] = src_kv_cache + dst_kv_cache_layer.reshape(dst_kv_cache_layer_shape) + else: + num_pages = dst_kv_cache_layer_shape[1] + page_size = dst_kv_cache_layer_shape[2] + dst_kv_cache_layer = dst_kv_cache_layer.reshape( + 2, num_pages * page_size, -1) + dst_kv_cache_layer[:, slot_mapping, ...] = src_kv_cache + dst_kv_cache_layer.reshape(dst_kv_cache_layer_shape) + + # Get the metadata + metadata: KVConnectorMetadata = \ + self._get_connector_metadata() + assert isinstance(metadata, SharedStorageConnectorMetadata) + + if metadata is None: + logger.warning( + "In connector.start_load_kv, but the connector metadata is None" + ) + return + + attn_metadata = forward_context.attn_metadata + if attn_metadata is None: + logger.warning( + "In connector.start_load_kv, but the attn_metadata is None") + return + + # Load the KV for each request each layer + for request in metadata.requests: + if request.is_store: + continue + logger.info("Inject KV cache of %d tokens to the paged memory", + len(request.slot_mapping)) + for layer_name in forward_context.no_compile_layers: + attn_layer = forward_context.no_compile_layers[layer_name] + kv_cache_layer = attn_layer.kv_cache[\ + forward_context.virtual_engine] + + filename = self._generate_filename_debug( + layer_name, request.token_ids) + kv_cache = safetensors.torch.load_file( + filename)["kv_cache"].cuda() + inject_kv_into_layer(kv_cache_layer, kv_cache, + request.slot_mapping) + + def wait_for_layer_load(self, layer_name: str) -> None: + """Blocking until the KV for a specific layer is loaded into vLLM's + paged buffer. + + This interface will be useful for layer-by-layer pipelining. + + Args: + layer_name: the name of that layer + """ + return + + def save_kv_layer(self, layer_name: str, kv_layer: torch.Tensor, + attn_metadata: "AttentionMetadata", **kwargs) -> None: + """Start saving the KV cache of the layer from vLLM's paged buffer + to the connector. + + Args: + layer_name (str): the name of the layer. + kv_layer (torch.Tensor): the paged KV buffer of the current + layer in vLLM. + attn_metadata (AttentionMetadata): the attention metadata. + **kwargs: additional arguments for the save operation. + """ + + def extract_kv_from_layer( + layer: torch.Tensor, + slot_mapping: torch.Tensor, + ) -> torch.Tensor: + """Extract the KV cache from the layer. + + Assume the shape of the layer is (2, num_pages, page_size, xxx) + if MLA is not used, and (num_pages, page_size, xxx) otherwise. 
+ """ + if isinstance(attn_metadata, MLACommonMetadata): + num_pages, page_size = layer.shape[0], layer.shape[1] + return layer.reshape(num_pages * page_size, -1)[slot_mapping, + ...] + num_pages, page_size = layer.shape[1], layer.shape[2] + return layer.reshape(2, num_pages * page_size, -1)[:, slot_mapping, + ...] + + connector_metadata = self._get_connector_metadata() + assert isinstance(connector_metadata, SharedStorageConnectorMetadata) + for request in connector_metadata.requests: + if request.is_store: + filename = self._generate_filename_debug( + layer_name, request.token_ids) + kv_cache = extract_kv_from_layer(kv_layer, + request.slot_mapping) + tensors = {"kv_cache": kv_cache.detach().cpu()} + safetensors.torch.save_file(tensors, filename) + + def wait_for_save(self): + return + + def get_num_new_matched_tokens( + self, + request: "Request", + num_computed_tokens: int, + ) -> int: + """ + Get number of new tokens that can be loaded from the + external KV cache beyond the num_computed_tokens. + + Args: + request (Request): the request object. + num_computed_tokens (int): the number of locally + computed tokens for this request + + Returns: + the number of tokens that can be loaded from the + external KV cache beyond what is already computed. + """ + + # NOTE: in this debug implementation, we assume that the prompt is + # cached_prompt + newly_generated_single_token + # Therefore, we use prompt_token_ids[:-1] to determine the folder name + + # NOTE: in current v1 scheduler, the num_computed_tokens is aligned + # with the block granularity. And it expects the returned blocks and + # num_computed_tokens to also be aligned with the block granularity. + if not self._found_match_for_request(request): + return 0 + + logger.info("External Cache Hit!") + + # Now, first num_tokens_to_check tokens are hit, we need to prepare + # the metadata for the worker connector to correctly load the KV + num_tokens_to_check = align_to_block_size( + len(request.prompt_token_ids) - 1, self._block_size) + + return num_tokens_to_check - num_computed_tokens + + def update_state_after_alloc(self, request: "Request", + num_external_tokens: int): + """ + Update KVConnector state after block allocation. + + If blocks were allocated, add to _requests_need_load, + such that we load the KVs in the next forward pass. + """ + if num_external_tokens > 0: + self._requests_need_load[request.request_id] = request + + def build_connector_meta( + self, + scheduler_output: SchedulerOutput, + ) -> KVConnectorMetadata: + """Build the connector metadata for this step. + + This function should NOT modify any fields in the scheduler_output. + Also, calling this function will reset the state of the connector. + + Args: + scheduler_output (SchedulerOutput): the scheduler output object. + """ + meta = SharedStorageConnectorMetadata() + + total_need_load = 0 + for new_req in scheduler_output.scheduled_new_reqs: + if new_req.req_id in self._requests_need_load: + meta.add_request(token_ids=new_req.prompt_token_ids, + block_ids=new_req.block_ids, + block_size=self._block_size, + is_store=False) + total_need_load += 1 + else: + # NOTE: here, we set the store and load being exclusive, + # but a single request can have both store and load. + # NOTE(rob): for this debug implementation, we only cache + # the original prompt tokens. 
+ if not self._found_match_for_request(new_req): + meta.add_request(token_ids=new_req.prompt_token_ids, + block_ids=new_req.block_ids, + block_size=self._block_size, + is_store=True) + + for cached_req in scheduler_output.scheduled_cached_reqs: + # NOTE(rob): here we rely on the resumed requests being + # the first N requests in the list scheduled_cache_reqs. + if not cached_req.resumed_from_preemption: + break + if cached_req.req_id in self._requests_need_load: + # NOTE(rob): cached_req_data does not have the full + # list of token ids (only new tokens). So we look it + # up in the actual request object. + request = self._requests_need_load[cached_req.req_id] + total_tokens = (len(cached_req.new_token_ids) + + cached_req.num_computed_tokens) + token_ids = request.all_token_ids[:total_tokens] + + # NOTE(rob): For resumed req, new_block_ids is all + # of the block_ids for the request. + block_ids = cached_req.new_block_ids + + meta.add_request(token_ids=token_ids, + block_ids=block_ids, + block_size=self._block_size, + is_store=False) + total_need_load += 1 + + assert total_need_load == len(self._requests_need_load) + self._requests_need_load.clear() + return meta + + # ============================== + # Helper functions + # ============================== + + def _found_match_for_request( + self, + request: "Request", + ) -> bool: + """Check if the cache is hit for the request. + """ + num_tokens_to_check = align_to_block_size( + len(request.prompt_token_ids) - 1, self._block_size) + foldername = self._generate_foldername_debug(torch.tensor( + request.prompt_token_ids)[:num_tokens_to_check], + create_folder=False) + return os.path.exists(foldername) + + def _generate_foldername_debug( + self, + input_ids: torch.Tensor, + create_folder=False, + ) -> str: + """Generate a folder name based on the hash of the bytes of the input + ids. + """ + input_ids_bytes = input_ids.numpy().tobytes() + input_ids_hash = hashlib.md5(input_ids_bytes).hexdigest() + foldername = os.path.join(self._storage_path, input_ids_hash) + if create_folder: + os.makedirs(foldername, exist_ok=True) + return foldername + + def _generate_filename_debug( + self, + layer_name: str, + input_ids: torch.Tensor, + ) -> str: + """Generate a file name based on the layer name and the hash + of the bytes of the input ids. + """ + foldername = self._generate_foldername_debug(input_ids, + create_folder=True) + return os.path.join(foldername, f"{layer_name}.safetensors") + + +def align_to_block_size(num_tokens: int, block_size) -> int: + """Align the number of tokens to the block size. + """ + return (num_tokens - 1) // block_size * block_size diff --git a/vllm/distributed/kv_transfer/kv_transfer_agent.py b/vllm/distributed/kv_transfer/kv_connector_agent.py similarity index 97% rename from vllm/distributed/kv_transfer/kv_transfer_agent.py rename to vllm/distributed/kv_transfer/kv_connector_agent.py index 1e80e0bd..9d714509 100644 --- a/vllm/distributed/kv_transfer/kv_transfer_agent.py +++ b/vllm/distributed/kv_transfer/kv_connector_agent.py @@ -46,7 +46,7 @@ class KVTransferAgent: assert self.config.kv_transfer_config.is_kv_transfer_instance, "KV"\ "TransferAgent should only be used when kv_connector is set." 
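The helper functions above boil down to a content-addressed layout on disk: an md5 over the (block-aligned) prompt token bytes names the folder, and each attention layer gets its own safetensors file inside it. A sketch with an illustrative layer name:

import hashlib
import os

import torch

storage_path = "local_storage"                  # from kv_connector_extra_config
token_ids = torch.arange(0, 992)                # block-aligned prompt prefix
layer_name = "model.layers.0.self_attn.attn"    # illustrative layer name

folder = os.path.join(storage_path,
                      hashlib.md5(token_ids.numpy().tobytes()).hexdigest())
filename = os.path.join(folder, f"{layer_name}.safetensors")
print(filename)   # local_storage/<md5-hex>/model.layers.0.self_attn.attn.safetensors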
- self.connector = KVConnectorFactory.create_connector( + self.connector = KVConnectorFactory.create_connector_v0( rank, local_rank, config) def send_kv_caches_and_hidden_states( diff --git a/vllm/distributed/kv_transfer/kv_transfer_state.py b/vllm/distributed/kv_transfer/kv_transfer_state.py new file mode 100644 index 00000000..25d2f2cf --- /dev/null +++ b/vllm/distributed/kv_transfer/kv_transfer_state.py @@ -0,0 +1,70 @@ +# SPDX-License-Identifier: Apache-2.0 +from typing import TYPE_CHECKING, Optional + +from vllm import envs +from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBaseType +from vllm.distributed.kv_transfer.kv_connector.factory import ( + KVConnectorFactory) +from vllm.distributed.kv_transfer.kv_connector.v1 import (KVConnectorBase_V1, + KVConnectorRole) +from vllm.distributed.parallel_state import get_world_group + +if TYPE_CHECKING: + from vllm.config import VllmConfig + +_KV_CONNECTOR_AGENT: Optional[KVConnectorBaseType] = None + + +def get_kv_transfer_group() -> KVConnectorBaseType: + assert _KV_CONNECTOR_AGENT is not None, ( + "disaggregated KV cache transfer parallel group is not initialized") + return _KV_CONNECTOR_AGENT + + +def has_kv_transfer_group() -> bool: + return _KV_CONNECTOR_AGENT is not None + + +def is_v1_kv_transfer_group( + connector: Optional[KVConnectorBaseType] = None) -> bool: + """Check if the KV connector is the v1 connector. + If the argument is None, it will check the global KV connector + + Args: + connector: The KV connector to check. If None, it will check the + global KV connector. + + Note: + This function will no-longer be needed after the v1 KV connector + becomes the default. + """ + if connector is None: + connector = _KV_CONNECTOR_AGENT + + if connector is None: + return False + + return isinstance(connector, KVConnectorBase_V1) + + +def ensure_kv_transfer_initialized(vllm_config: "VllmConfig") -> None: + """ + Initialize KV cache transfer parallel group. 
+ """ + + global _KV_CONNECTOR_AGENT + + if vllm_config.kv_transfer_config is None: + return + + if (vllm_config.kv_transfer_config.is_kv_transfer_instance + and _KV_CONNECTOR_AGENT is None): + if envs.VLLM_USE_V1: + _KV_CONNECTOR_AGENT = KVConnectorFactory.create_connector_v1( + config=vllm_config, role=KVConnectorRole.WORKER) + else: + _KV_CONNECTOR_AGENT = KVConnectorFactory.create_connector_v0( + rank=get_world_group().rank, + local_rank=get_world_group().local_rank, + config=vllm_config, + ) diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index e0eeeffb..d0ac7e92 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -29,15 +29,13 @@ from collections import namedtuple from contextlib import contextmanager, nullcontext from dataclasses import dataclass from multiprocessing import shared_memory -from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, - Union) +from typing import Any, Callable, Dict, List, Optional, Tuple, Union from unittest.mock import patch import torch import torch.distributed from torch.distributed import Backend, ProcessGroup -import vllm.distributed.kv_transfer.kv_transfer_agent as kv_transfer import vllm.envs as envs from vllm.distributed.device_communicators.base_device_communicator import ( DeviceCommunicatorBase) @@ -46,9 +44,6 @@ from vllm.logger import init_logger from vllm.utils import (direct_register_custom_op, resolve_obj_by_qualname, supports_custom_op) -if TYPE_CHECKING: - from vllm.config import VllmConfig - @dataclass class GraphCaptureContext: @@ -772,14 +767,6 @@ def get_pp_group() -> GroupCoordinator: # kept for backward compatibility get_pipeline_model_parallel_group = get_pp_group -_KV_TRANSFER: Optional[kv_transfer.KVTransferAgent] = None - - -def get_kv_transfer_group() -> kv_transfer.KVTransferAgent: - assert _KV_TRANSFER is not None, ( - "disaggregated KV cache transfer parallel group is not initialized") - return _KV_TRANSFER - @contextmanager def graph_capture(device: torch.device): @@ -962,26 +949,6 @@ def initialize_model_parallel( _DP.rank_in_group, _PP.rank_in_group, _TP.rank_in_group) -def ensure_kv_transfer_initialized(vllm_config: "VllmConfig") -> None: - """ - Initialize KV cache transfer parallel group. - """ - - global _KV_TRANSFER - - if vllm_config.kv_transfer_config is None: - return - - if all([ - vllm_config.kv_transfer_config.is_kv_transfer_instance, - _KV_TRANSFER is None - ]): - _KV_TRANSFER = kv_transfer.KVTransferAgent( - rank=get_world_group().rank, - local_rank=get_world_group().local_rank, - config=vllm_config) - - def ensure_model_parallel_initialized( tensor_model_parallel_size: int, pipeline_model_parallel_size: int, diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 85b3ddfc..7c1bde0f 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -1527,12 +1527,6 @@ class EngineArgs: recommend_to_remove=False) return False - # No Disaggregated Prefill so far. - if self.kv_transfer_config != EngineArgs.kv_transfer_config: - _raise_or_fallback(feature_name="--kv-transfer-config", - recommend_to_remove=False) - return False - # No FlashInfer or XFormers so far. 
V1_BACKENDS = [ "FLASH_ATTN_VLLM_V1", "FLASH_ATTN", "PALLAS", "PALLAS_VLLM_V1", diff --git a/vllm/forward_context.py b/vllm/forward_context.py index e195a03c..06790d8e 100644 --- a/vllm/forward_context.py +++ b/vllm/forward_context.py @@ -11,6 +11,10 @@ import torch.distributed as dist import vllm.envs as envs from vllm.config import VllmConfig +from vllm.distributed.kv_transfer import (get_kv_transfer_group, + has_kv_transfer_group, + is_v1_kv_transfer_group) +from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorBase_V1 from vllm.logger import init_logger if TYPE_CHECKING: @@ -98,6 +102,17 @@ def set_forward_context(attn_metadata: Any, virtual_engine=virtual_engine, attn_metadata=attn_metadata, dp_metadata=dp_metadata) + + # KVConnector: trigger (possibly async) load before forward. + # Each attn layer will block until the reading is complete. + trigger_kv_transfer = (attn_metadata is not None + and has_kv_transfer_group() + and is_v1_kv_transfer_group()) + if trigger_kv_transfer: + kv_connector = get_kv_transfer_group() + assert isinstance(kv_connector, KVConnectorBase_V1) + kv_connector.start_load_kv(_forward_context) + try: yield finally: @@ -133,4 +148,12 @@ def set_forward_context(attn_metadata: Any, logger.info(("Batchsize forward time stats " "(batchsize, count, median_time(ms)): %s"), forward_stats) + + # KVConnector: each attn layer triggers (possibly async) save. + # Ensure all those operations complete before forward() is done. + if trigger_kv_transfer: + kv_connector = get_kv_transfer_group() + assert isinstance(kv_connector, KVConnectorBase_V1) + kv_connector.wait_for_save() + _forward_context = prev_context diff --git a/vllm/v1/core/kv_cache_manager.py b/vllm/v1/core/kv_cache_manager.py index 33761cf7..6e5f969d 100644 --- a/vllm/v1/core/kv_cache_manager.py +++ b/vllm/v1/core/kv_cache_manager.py @@ -171,8 +171,9 @@ class KVCacheManager: Args: request: The request to allocate slots. - num_tokens: The number of tokens to allocate. Note that this does - not include the tokens that have already been computed. + num_tokens: The number of tokens to allocate, including external + tokens. Note that this does not include tokens that have + already been computed locally (i.e. new_computed_blocks). new_computed_blocks: A list of new computed blocks just hitting the prefix caching. num_lookahead_tokens: The number of speculative tokens to allocate. diff --git a/vllm/v1/core/sched/output.py b/vllm/v1/core/sched/output.py index dc0d2d59..1d3f1f41 100644 --- a/vllm/v1/core/sched/output.py +++ b/vllm/v1/core/sched/output.py @@ -9,6 +9,8 @@ if TYPE_CHECKING: import numpy as np import numpy.typing as npt + from vllm.distributed.kv_transfer.kv_connector.v1.base import ( + KVConnectorMetadata) from vllm.lora.request import LoRARequest from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange from vllm.sampling_params import SamplingParams @@ -121,3 +123,6 @@ class SchedulerOutput: structured_output_request_ids: dict[str, int] # the bitmask for the whole batch grammar_bitmask: Optional[npt.NDArray[np.int32]] + + # KV Cache Connector metadata. 
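The forward_context.py change above wraps the model forward pass between a single start_load_kv() call and a final wait_for_save(); a minimal sketch of that lifecycle with a stub connector (the stub names are illustrative):

from contextlib import contextmanager

class _StubConnector:
    def start_load_kv(self, forward_context) -> None:
        print("start_load_kv: kick off (possibly async) KV loads")
    def wait_for_save(self) -> None:
        print("wait_for_save: block until all per-layer saves finish")

@contextmanager
def forward_context_sketch(attn_metadata, connector):
    # Loads are only triggered for real forward passes (attn_metadata set).
    trigger = attn_metadata is not None and connector is not None
    if trigger:
        connector.start_load_kv(forward_context=None)
    try:
        yield
    finally:
        if trigger:
            connector.wait_for_save()

with forward_context_sketch(attn_metadata=object(), connector=_StubConnector()):
    print("model forward runs here; each attention layer waits/saves per layer")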
+ kv_connector_metadata: Optional[KVConnectorMetadata] = None diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index a8157487..7e658d13 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -7,8 +7,10 @@ from collections import deque from collections.abc import Iterable from typing import Optional, Union -from vllm.config import (CacheConfig, LoRAConfig, ModelConfig, SchedulerConfig, - SpeculativeConfig) +from vllm.config import VllmConfig +from vllm.distributed.kv_transfer.kv_connector.factory import ( + KVConnectorFactory) +from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorRole from vllm.logger import init_logger from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry from vllm.v1.core.encoder_cache_manager import (EncoderCacheManager, @@ -34,20 +36,17 @@ class Scheduler(SchedulerInterface): def __init__( self, - scheduler_config: SchedulerConfig, - model_config: ModelConfig, - cache_config: CacheConfig, - lora_config: Optional[LoRAConfig], + vllm_config: VllmConfig, kv_cache_config: KVCacheConfig, structured_output_manager: StructuredOutputManager, - speculative_config: SpeculativeConfig = None, mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY, include_finished_set: bool = False, log_stats: bool = False, ) -> None: - self.scheduler_config = scheduler_config - self.cache_config = cache_config - self.lora_config = lora_config + self.vllm_config = vllm_config + self.scheduler_config = vllm_config.scheduler_config + self.cache_config = vllm_config.cache_config + self.lora_config = vllm_config.lora_config self.kv_cache_config = kv_cache_config self.log_stats = log_stats self.structured_output_manager = structured_output_manager @@ -64,11 +63,22 @@ class Scheduler(SchedulerInterface): self.scheduler_config.max_num_batched_tokens self.max_model_len = self.scheduler_config.max_model_len + # Create KVConnector for the Scheduler. Note that each Worker + # will have a corresponding KVConnector with Role=WORKER. + # KV Connector pushes/pull of remote KVs for P/D and offloading. + self.connector = None + if self.vllm_config.kv_transfer_config is not None: + self.connector = KVConnectorFactory.create_connector_v1( + config=self.vllm_config, role=KVConnectorRole.SCHEDULER) + + num_gpu_blocks = self.cache_config.num_gpu_blocks + assert num_gpu_blocks is not None and num_gpu_blocks > 0 + # Create the KV cache manager. self.kv_cache_manager = KVCacheManager( kv_cache_config=kv_cache_config, max_model_len=self.max_model_len, - enable_caching=cache_config.enable_prefix_caching, + enable_caching=self.cache_config.enable_prefix_caching, caching_hash_algo=self.cache_config.prefix_caching_hash_algo, log_stats=self.log_stats) self.block_size = self.cache_config.block_size @@ -99,8 +109,8 @@ class Scheduler(SchedulerInterface): # This can be changed when we make encoder cache for embedding caching # across requests. 
encoder_compute_budget, encoder_cache_size = compute_encoder_budget( - model_config=model_config, - scheduler_config=scheduler_config, + model_config=vllm_config.model_config, + scheduler_config=vllm_config.scheduler_config, mm_registry=mm_registry, ) @@ -115,6 +125,7 @@ class Scheduler(SchedulerInterface): cache_size=encoder_cache_size) self.num_lookahead_tokens = 0 + speculative_config = vllm_config.speculative_config if speculative_config and speculative_config.method == "eagle": self.num_lookahead_tokens = \ speculative_config.num_speculative_tokens @@ -304,6 +315,16 @@ class Scheduler(SchedulerInterface): # Get already-cached tokens. computed_blocks, num_computed_tokens = \ self.kv_cache_manager.get_computed_blocks(request) + + # Get externally-cached tokens if using a KVConnector. + num_external_tokens = ( + 0 if self.connector is None else + self.connector.get_num_new_matched_tokens( + request, num_computed_tokens)) + + # Total computed tokens (local + external). + num_computed_tokens += num_external_tokens + # Number of tokens to be scheduled. # We use `request.num_tokens` instead of # `request.num_prompt_tokens` to consider the resumed requests, @@ -330,11 +351,21 @@ class Scheduler(SchedulerInterface): new_encoder_budget = encoder_budget new_blocks = self.kv_cache_manager.allocate_slots( - request, num_new_tokens, computed_blocks) + request, num_new_tokens + num_external_tokens, + computed_blocks) if new_blocks is None: # The request cannot be scheduled. break + # KVConnector: update internal state after allocation. + # This information is used to determine if a load is + # needed for this request. + if self.connector is not None: + self.connector.update_state_after_alloc( + request, + num_external_tokens, + ) + self.waiting.popleft() if request.use_structured_output: structured_output_request_ids[ @@ -443,6 +474,14 @@ class Scheduler(SchedulerInterface): grammar_bitmask=grammar_bitmask, ) + # NOTE(Kuntai): this function is designed for multiple purposes: + # 1. Plan the KV cache store + # 2. Wrap up all the KV cache load / save ops into an opaque object + # 3. Clear the internal states of the connector + if self.connector is not None: + meta = self.connector.build_connector_meta(scheduler_output) + scheduler_output.kv_connector_metadata = meta + # Advance the number of computed tokens for the request AFTER # the request is scheduled. # 1. The scheduler_output of the current step has to include the @@ -508,6 +547,9 @@ class Scheduler(SchedulerInterface): If an encoder input cannot be scheduled due to cache or budget limitations, the method adjusts `num_new_tokens` to schedule only the decoder tokens up to just before the unschedulable encoder input. + + Note that num_computed_tokens includes both locally cached + blocks and externally cached blocks (via KVConnector). 
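The schedule() changes above fold remote KV hits into admission: the scheduler asks the connector how many matched tokens already exist externally, adds them to the locally computed count, allocates slots for the new plus external tokens, and then lets the connector record what was allocated so it can decide whether a load is needed. A compressed toy version of that control flow; the Toy* classes are stand-ins, not the real vLLM types, and only the call order mirrors the diff:

# Toy walk-through of the admission path in the scheduler hunk above.
class ToyRequest:
    def __init__(self, num_tokens):
        self.num_tokens = num_tokens


class ToyKVCacheManager:
    def __init__(self, block_size=16):
        self.block_size = block_size

    def get_computed_blocks(self, request):
        return [], 0  # no local prefix-cache hit in this demo

    def allocate_slots(self, request, num_tokens, computed_blocks):
        # Pretend allocation always succeeds.
        num_blocks = (num_tokens + self.block_size - 1) // self.block_size
        return ["<block>"] * num_blocks


class ToySchedulerConnector:
    def get_num_new_matched_tokens(self, request, num_computed_tokens):
        return 32  # pretend 32 tokens are already in remote storage

    def update_state_after_alloc(self, request, num_external_tokens):
        print(f"plan load of {num_external_tokens} external tokens")


def try_schedule(request, kv_cache_manager, connector):
    # 1. Locally cached (prefix-cache) tokens.
    computed_blocks, num_computed_tokens = (
        kv_cache_manager.get_computed_blocks(request))

    # 2. Externally cached tokens, if a connector is configured.
    num_external_tokens = (0 if connector is None else
                           connector.get_num_new_matched_tokens(
                               request, num_computed_tokens))
    num_computed_tokens += num_external_tokens

    # 3. Allocate slots. External tokens count as computed, but still need
    #    local blocks to be written into, so they are added back here.
    num_new_tokens = request.num_tokens - num_computed_tokens
    new_blocks = kv_cache_manager.allocate_slots(
        request, num_new_tokens + num_external_tokens, computed_blocks)
    if new_blocks is None:
        return None  # cannot schedule this step

    # 4. Let the connector record the allocation and plan the load.
    if connector is not None:
        connector.update_state_after_alloc(request, num_external_tokens)
    return num_new_tokens


print(try_schedule(ToyRequest(100), ToyKVCacheManager(), ToySchedulerConnector()))

The subtlety mirrored in step 3 is that externally matched tokens are treated as already computed when sizing num_new_tokens, yet they still need local KV blocks, which is why the allocate_slots call in the diff passes num_new_tokens + num_external_tokens.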
""" encoder_inputs_to_schedule: list[int] = [] mm_positions = request.mm_positions diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index ba5e5050..9c4036ef 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -92,12 +92,8 @@ class EngineCore: vllm_config.scheduler_config.scheduler_cls) self.scheduler: SchedulerInterface = Scheduler( - scheduler_config=vllm_config.scheduler_config, - model_config=vllm_config.model_config, - cache_config=vllm_config.cache_config, - lora_config=vllm_config.lora_config, + vllm_config=vllm_config, kv_cache_config=kv_cache_config, - speculative_config=vllm_config.speculative_config, structured_output_manager=self.structured_output_manager, include_finished_set=vllm_config.parallel_config.data_parallel_size > 1, diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index baf0dfb9..ac0701c4 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -13,6 +13,8 @@ import torch.nn as nn from vllm.attention import AttentionType, get_attn_backend from vllm.attention.layer import Attention from vllm.config import CompilationLevel, VllmConfig +from vllm.distributed.kv_transfer import (get_kv_transfer_group, + has_kv_transfer_group) from vllm.distributed.parallel_state import get_pp_group, graph_capture from vllm.forward_context import set_forward_context from vllm.logger import init_logger @@ -987,6 +989,11 @@ class GPUModelRunner(LoRAModelRunnerMixin): scheduler_output: "SchedulerOutput", intermediate_tensors: Optional[IntermediateTensors] = None, ) -> Union[ModelRunnerOutput, torch.Tensor]: + # Update KVConnector with the KVConnector metadata forward(). + if has_kv_transfer_group(): + get_kv_transfer_group().bind_connector_metadata( + scheduler_output.kv_connector_metadata) + self._update_states(scheduler_output) if not scheduler_output.total_num_scheduled_tokens: # Return empty ModelRunnerOutput if there's no work to do. @@ -1228,6 +1235,10 @@ class GPUModelRunner(LoRAModelRunnerMixin): # in the next step. del draft_probs + # Clear KVConnector state after all KVs are generated. + if has_kv_transfer_group(): + get_kv_transfer_group().clear_connector_metadata() + return ModelRunnerOutput( req_ids=self.input_batch.req_ids, req_id_to_index=self.input_batch.req_id_to_index, diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py index 2972e0ff..3a29f8d0 100644 --- a/vllm/v1/worker/gpu_worker.py +++ b/vllm/v1/worker/gpu_worker.py @@ -9,11 +9,12 @@ import torch.distributed import torch.nn as nn import vllm.envs as envs -from vllm.config import ParallelConfig, VllmConfig +from vllm.config import VllmConfig from vllm.device_allocator.cumem import CuMemAllocator from vllm.distributed import (ensure_model_parallel_initialized, init_distributed_environment, set_custom_all_reduce) +from vllm.distributed.kv_transfer import ensure_kv_transfer_initialized from vllm.distributed.parallel_state import get_pp_group from vllm.logger import init_logger from vllm.lora.request import LoRARequest @@ -110,7 +111,7 @@ class Worker(WorkerBase): raise RuntimeError( f"Not support device type: {self.device_config.device}") # Initialize the distributed environment. - init_worker_distributed_environment(self.parallel_config, self.rank, + init_worker_distributed_environment(self.vllm_config, self.rank, self.distributed_init_method, self.local_rank) # Set random seed. 
@@ -285,12 +286,13 @@ class Worker(WorkerBase): def init_worker_distributed_environment( - parallel_config: ParallelConfig, + vllm_config: VllmConfig, rank: int, distributed_init_method: Optional[str] = None, local_rank: int = -1, ) -> None: """Initialize the distributed environment.""" + parallel_config = vllm_config.parallel_config set_custom_all_reduce(not parallel_config.disable_custom_all_reduce) init_distributed_environment(parallel_config.world_size, rank, @@ -299,6 +301,8 @@ def init_worker_distributed_environment( ensure_model_parallel_initialized(parallel_config.tensor_parallel_size, parallel_config.pipeline_parallel_size) + ensure_kv_transfer_initialized(vllm_config) + def _check_if_gpu_supports_dtype(torch_dtype: torch.dtype): # Check if the GPU supports the dtype. diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 9524a69f..49b0ba1b 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -23,7 +23,8 @@ from vllm.attention.backends.abstract import AttentionState from vllm.attention.backends.utils import CommonAttentionState from vllm.config import CompilationLevel, VllmConfig from vllm.core.scheduler import SchedulerOutputs -from vllm.distributed import get_kv_transfer_group, get_pp_group +from vllm.distributed import get_pp_group +from vllm.distributed.kv_transfer import get_kv_transfer_group from vllm.distributed.parallel_state import (get_tensor_model_parallel_rank, graph_capture) from vllm.forward_context import get_forward_context, set_forward_context diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index d59f20f4..9ea003be 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -10,10 +10,10 @@ import torch.distributed import vllm.envs as envs from vllm.config import VllmConfig from vllm.device_allocator.cumem import CuMemAllocator -from vllm.distributed import (ensure_kv_transfer_initialized, - ensure_model_parallel_initialized, +from vllm.distributed import (ensure_model_parallel_initialized, init_distributed_environment, set_custom_all_reduce) +from vllm.distributed.kv_transfer import ensure_kv_transfer_initialized from vllm.logger import init_logger from vllm.lora.request import LoRARequest from vllm.model_executor import set_random_seed
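Taken together with the worker-side setup (ensure_kv_transfer_initialized now runs right after model-parallel initialization, which is why init_worker_distributed_environment receives the full VllmConfig instead of just ParallelConfig), the connector surface exercised by this change can be summarized as the toy base class below. The method names come from the hunks above, but the signatures are simplified and the real KVConnectorBase_V1 may differ:

# Consolidated toy view of the scheduler-side and worker-side hooks wired in
# by this diff. Illustrative only; not the real KVConnectorBase_V1.
class ToyKVConnectorBase:
    # --- Scheduler side (Role=SCHEDULER), called from Scheduler.schedule() ---
    def get_num_new_matched_tokens(self, request, num_computed_tokens) -> int:
        """How many additional tokens are already available externally."""
        return 0

    def update_state_after_alloc(self, request, num_external_tokens) -> None:
        """Record the allocation so a load can be planned for this request."""

    def build_connector_meta(self, scheduler_output):
        """Wrap all KV load/save ops for this step into an opaque object."""
        return None

    # --- Worker side (Role=WORKER), called from the model runner / forward ---
    def bind_connector_metadata(self, metadata) -> None:
        """Receive the scheduler's plan for the current step."""

    def start_load_kv(self, forward_context) -> None:
        """Kick off (possibly async) KV loads before the model runs."""

    def wait_for_save(self) -> None:
        """Block until all per-layer KV saves have completed."""

    def clear_connector_metadata(self) -> None:
        """Drop per-step state after all KVs are generated."""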