[P/D][V1] KV Connector API V1 (#15960)

Signed-off-by: ApostaC <yihua98@uchicago.edu>
Signed-off-by: rshaw@neuralmagic.com <robertgshaw2@gmail.com>
Signed-off-by: remi <remi@mistral.ai>
Co-authored-by: rshaw@neuralmagic.com <robertgshaw2@gmail.com>
Co-authored-by: Robert Shaw <114415538+robertgshaw2-redhat@users.noreply.github.com>
Co-authored-by: Rémi Delacourt <54138269+Flechman@users.noreply.github.com>
Co-authored-by: Tyler Michael Smith <tysmith@redhat.com>
Author: Yihua Cheng
Date: 2025-04-17 15:22:40 -05:00
Parent: 0377b8310b
Commit: 3408e47159

24 changed files with 1377 additions and 83 deletions


@ -0,0 +1,36 @@
# SPDX-License-Identifier: Apache-2.0

from vllm import LLM, SamplingParams
from vllm.config import KVTransferConfig

# Read prompts from output.txt
prompts = []
try:
    with open("output.txt") as f:
        for line in f:
            prompts.append(line.strip())
    print(f"Loaded {len(prompts)} prompts from output.txt")
except FileNotFoundError:
    print("Error: output.txt file not found")
    exit(-1)

sampling_params = SamplingParams(temperature=0, top_p=0.95, max_tokens=10)

llm = LLM(
    model="meta-llama/Llama-3.2-1B-Instruct",
    enforce_eager=True,
    gpu_memory_utilization=0.8,
    max_num_batched_tokens=64,
    max_num_seqs=16,
    kv_transfer_config=KVTransferConfig.from_cli(
        '{"kv_connector":"SharedStorageConnector","kv_role":"kv_both",'
        '"kv_connector_extra_config": {"shared_storage_path": "local_storage"}}'
    ))

# 2ND generation (decode instance): the KV cache for these prompts is loaded
# from local_storage by the SharedStorageConnector.
outputs = llm.generate(prompts, sampling_params)

for output in outputs:
    prompt = output.prompt
    generated_text = output.outputs[0].text
    print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")


@ -0,0 +1,43 @@
# SPDX-License-Identifier: Apache-2.0

from vllm import LLM, SamplingParams
from vllm.config import KVTransferConfig

context = "Hi " * 1000
context2 = "Hey " * 500
prompts = [
    context + "Hello, my name is",
    context + "The capital of France is",
    context2 + "Your name is",
    context2 + "The capital of China is",
]

sampling_params = SamplingParams(temperature=0, top_p=0.95, max_tokens=1)

llm = LLM(model="meta-llama/Llama-3.2-1B-Instruct",
          enforce_eager=True,
          gpu_memory_utilization=0.8,
          kv_transfer_config=KVTransferConfig.from_cli(
              '{"kv_connector":"SharedStorageConnector","kv_role":"kv_both", '
              '"kv_connector_extra_config": '
              '{"shared_storage_path": "local_storage"}}'))

# 1ST generation (prefill instance): the KV cache for these prompts is saved
# to local_storage by the SharedStorageConnector.
outputs = llm.generate(
    prompts,
    sampling_params,
)

new_prompts = []
for output in outputs:
    prompt = output.prompt
    generated_text = output.outputs[0].text
    new_prompts.append(prompt + generated_text)
    print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")

# Write new_prompts to output.txt
with open("output.txt", "w") as f:
    for prompt in new_prompts:
        f.write(prompt + "\n")

print(f"Saved {len(new_prompts)} prompts to output.txt")


@ -0,0 +1,5 @@
rm -rf local_storage/
rm output.txt
VLLM_ENABLE_V1_MULTIPROCESSING=0 CUDA_VISIBLE_DEVICES=0 python3 prefill_example.py
VLLM_ENABLE_V1_MULTIPROCESSING=0 CUDA_VISIBLE_DEVICES=0 python3 decode_example.py
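After prefill_example.py finishes, the connector will have written one folder per block-aligned prompt (named by an MD5 hash of the prompt token ids) under local_storage/, with one .safetensors file per attention layer, as implemented in SharedStorageConnector later in this commit. Below is a hedged sketch for inspecting one saved layer; <prompt_hash> and <layer_name> are placeholders for names actually found on disk, not real values.

# Sketch only: inspect one KV cache file written by the SharedStorageConnector.
import safetensors.torch

kv = safetensors.torch.load_file(
    "local_storage/<prompt_hash>/<layer_name>.safetensors")["kv_cache"]
# For non-MLA attention the saved tensor is [2, num_tokens, ...]
# (key and value stacked); for MLA it is [num_tokens, ...].
print(kv.shape)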


@ -1,10 +1,12 @@
 # SPDX-License-Identifier: Apache-2.0
 from typing import Optional
+from unittest.mock import Mock

 import pytest
 import torch

-from vllm.config import CacheConfig, ModelConfig, SchedulerConfig, VllmConfig
+from vllm.config import (CacheConfig, KVTransferConfig, ModelConfig,
+                         SchedulerConfig, VllmConfig)
 from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
 from vllm.sampling_params import SamplingParams
 from vllm.v1.core.sched.output import SchedulerOutput
@ -25,6 +27,9 @@ def create_scheduler(
     enable_prefix_caching: Optional[bool] = None,
     long_prefill_token_threshold: int = 0,
     disable_chunked_mm_input: bool = False,
+    use_kv_connector: bool = False,
+    num_blocks: int = 10000,
+    block_size: int = 16,
 ) -> Scheduler:
     '''Create scheduler under test.
@ -60,31 +65,36 @@ def create_scheduler(
         'enable_prefix_caching': enable_prefix_caching
     })
     cache_config = CacheConfig(
-        block_size=16,
+        block_size=block_size,
         gpu_memory_utilization=0.9,
         swap_space=0,
         cache_dtype="auto",
         **kwargs_cache,
     )
+    kv_transfer_config = KVTransferConfig(
+        kv_connector="SharedStorageConnector",
+        kv_role="kv_both",
+        kv_connector_extra_config={"shared_storage_path": "local_storage"},
+    ) if use_kv_connector else None
     vllm_config = VllmConfig(
         scheduler_config=scheduler_config,
         model_config=model_config,
         cache_config=cache_config,
+        kv_transfer_config=kv_transfer_config,
     )
     kv_cache_config = KVCacheConfig(
-        num_blocks=10000,  # A large number of blocks to hold all requests
+        num_blocks=num_blocks,  # A large number of blocks to hold all requests
         tensors={},
         kv_cache_groups=[
-            KVCacheGroupSpec(['layer'],
-                             FullAttentionSpec(16, 1, 1, torch.float32, False))
+            KVCacheGroupSpec(['layer'],
+                             FullAttentionSpec(block_size, 1, 1, torch.float32,
+                                               False))
         ],
     )
-    cache_config.num_gpu_blocks = 10000
+    cache_config.num_gpu_blocks = num_blocks
     return Scheduler(
-        scheduler_config,
-        model_config,
-        cache_config,
-        lora_config=None,
+        vllm_config=vllm_config,
        kv_cache_config=kv_cache_config,
        log_stats=True,
        structured_output_manager=StructuredOutputManager(vllm_config),
@ -761,3 +771,390 @@ def test_schedule_spec_decoding_stats(spec_tokens, output_tokens, expected):
         stats = scheduler_stats.spec_decoding_stats
         assert stats.num_draft_tokens == expected[0]
         assert stats.num_accepted_tokens == expected[1]
def _assert_right_scheduler_output(
output: SchedulerOutput,
num_requests: int,
expected_num_scheduled_tokens: int,
):
"""Check if SchedulerOutput is correct after remote KV cache hit."""
# We should inject the kv_connector_metadata.
assert len(output.kv_connector_metadata.requests) == num_requests
# Only num_tokens - matched_num_new_tokens should be scheduled.
for _, num_scheduled_tokens in output.num_scheduled_tokens.items():
assert num_scheduled_tokens == expected_num_scheduled_tokens
def _assert_right_kv_cache_manager(
scheduler: Scheduler,
req_ids: list[str],
num_tokens: int,
block_size: int,
num_requests: int,
num_total_blocks: int,
):
"""Check whether KVCacheManager is correct after allocate."""
# Make sure the request stats are right.
EXPECTED_ACTUAL_BLOCKS = num_tokens // block_size
EXPECTED_TOTAL_BLOCKS = (EXPECTED_ACTUAL_BLOCKS +
scheduler.kv_cache_manager.num_preallocate_blocks)
for req_id in req_ids:
blocks = scheduler.kv_cache_manager.req_to_blocks[req_id]
hashes = scheduler.kv_cache_manager.req_to_block_hashes[req_id]
assert (scheduler.kv_cache_manager.num_cached_block[req_id] ==
EXPECTED_ACTUAL_BLOCKS)
assert len(blocks) == EXPECTED_TOTAL_BLOCKS
assert len(hashes) == EXPECTED_ACTUAL_BLOCKS
# Make sure we actually touched all the blocks.
BLOCKS_PER_REQ = (num_tokens / block_size +
scheduler.kv_cache_manager.num_preallocate_blocks)
assert (scheduler.kv_cache_manager.block_pool.get_num_free_blocks() ==
num_total_blocks - num_requests * BLOCKS_PER_REQ)
def _step_until_done(
scheduler: Scheduler,
output: SchedulerOutput,
model_runner_output: ModelRunnerOutput,
):
"""Loop over schedule(), update_from_output() until finished."""
all_finished = False
_ = scheduler.update_from_output(output, model_runner_output)
while not all_finished:
# Schedule + a few iterations until stopping.
output = scheduler.schedule()
assert len(scheduler.running)
for _, num_scheduled_tokens in output.num_scheduled_tokens.items():
# We should be in the decode phase now.
assert num_scheduled_tokens == 1
assert len(output.kv_connector_metadata.requests) == 0
ecos = scheduler.update_from_output(output, model_runner_output)
all_done = True
for eco in ecos.outputs:
if eco.finish_reason is None:
all_done = False
all_finished = all_done
def test_kv_connector_basic():
"""
Test whether Scheduler with KVConnector schedules tokens, allocates
memory, and cleans up requests as expected under normal operation.
"""
# Setup Scheduler.
scheduler = create_scheduler(
enable_prefix_caching=True,
use_kv_connector=True,
)
NUM_TOTAL_BLOCKS = (
scheduler.kv_cache_manager.block_pool.get_num_free_blocks())
BLOCK_SIZE = scheduler.cache_config.block_size
# Mock External Cache Hit.
NUM_MATCHED_NEW_TOKENS = BLOCK_SIZE * 2
scheduler.connector.get_num_new_matched_tokens = Mock(name="method")
scheduler.connector.get_num_new_matched_tokens.return_value = (
NUM_MATCHED_NEW_TOKENS)
######################################################
# FIRST SET OF REQUESTS - External Hit Only
NUM_REQUESTS = 2
NUM_TOKENS = NUM_MATCHED_NEW_TOKENS * 2
MAX_TOKENS = 3
requests = create_requests(num_requests=NUM_REQUESTS,
num_tokens=NUM_TOKENS,
max_tokens=MAX_TOKENS)
req_ids = []
req_to_index = {}
for i, request in enumerate(requests):
scheduler.add_request(request)
req_ids.append(request.request_id)
req_to_index[request.request_id] = i
MODEL_RUNNER_OUTPUT = ModelRunnerOutput(
req_ids=req_ids,
req_id_to_index=req_to_index,
sampled_token_ids=[[1000]] * len(req_ids),
spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={},
)
# Ensure ScheduleOutput is correct.
output = scheduler.schedule()
_assert_right_scheduler_output(
output=output,
num_requests=NUM_REQUESTS,
# Just the incremental tokens should be scheduled.
expected_num_scheduled_tokens=NUM_TOKENS - NUM_MATCHED_NEW_TOKENS,
)
# Ensure KVCacheManager is correct.
_assert_right_kv_cache_manager(scheduler, req_ids, NUM_TOKENS, BLOCK_SIZE,
NUM_REQUESTS, NUM_TOTAL_BLOCKS)
# Continue Generation until done.
_step_until_done(scheduler, output, MODEL_RUNNER_OUTPUT)
_ = scheduler.schedule()
# Confirm we clean up the memory properly.
assert scheduler.kv_cache_manager.block_pool.get_num_free_blocks() \
== NUM_TOTAL_BLOCKS
######################################################
# SECOND SET OF REQUESTS - Local And External Hit
NUM_TOKENS_PREFIX = NUM_TOKENS
# We will get a local prefix cache hit for the first
# NUM_TOKENS_PREFIX tokens since they are used above.
NUM_TOKENS = NUM_TOKENS_PREFIX * 2
requests = create_requests(num_requests=NUM_REQUESTS,
num_tokens=NUM_TOKENS,
max_tokens=MAX_TOKENS)
req_ids = []
req_to_index = {}
for i, request in enumerate(requests):
scheduler.add_request(request)
req_ids.append(request.request_id)
req_to_index[request.request_id] = i
MODEL_RUNNER_OUTPUT = ModelRunnerOutput(
req_ids=req_ids,
req_id_to_index=req_to_index,
sampled_token_ids=[[1000]] * len(req_ids),
spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={},
)
# We should get a local cache hit of NUM_TOKENS_PREFIX and
# a remote KV cache hit of NUM_MATCHED_NEW_TOKENS.
output = scheduler.schedule()
_assert_right_scheduler_output(
output=output,
num_requests=NUM_REQUESTS,
# Just the incremental tokens after local + remote cache hit.
expected_num_scheduled_tokens=(NUM_TOKENS - NUM_TOKENS_PREFIX -
NUM_MATCHED_NEW_TOKENS))
# Ensure KVCacheManager is correct.
_assert_right_kv_cache_manager(scheduler, req_ids, NUM_TOKENS, BLOCK_SIZE,
NUM_REQUESTS, NUM_TOTAL_BLOCKS)
# Continue Generation until done.
_step_until_done(scheduler, output, MODEL_RUNNER_OUTPUT)
_ = scheduler.schedule()
# Confirm we clean up the memory properly.
assert scheduler.kv_cache_manager.block_pool.get_num_free_blocks() \
== NUM_TOTAL_BLOCKS
def test_kv_connector_unable_to_allocate():
"""
Test whether scheduler with KVConnector is able to handle
unable to allocate (run out of blocks in allocate_slots().
"""
# Setup Scheduler With Mock External Cache Hit.
BLOCK_SIZE = 4
NUM_BLOCKS = 10
scheduler = create_scheduler(
enable_prefix_caching=True,
use_kv_connector=True,
block_size=BLOCK_SIZE,
num_blocks=NUM_BLOCKS,
)
NUM_MATCHED_NEW_TOKENS = BLOCK_SIZE * 2
scheduler.connector.get_num_new_matched_tokens = Mock(name="method")
scheduler.connector.get_num_new_matched_tokens.return_value = (
NUM_MATCHED_NEW_TOKENS)
# Create two requests. The second request will not be able to
# allocate slots because it will not have enough blocks.
NUM_REQUESTS = 2
NUM_TOKENS = (NUM_BLOCKS // 2 + 1) * BLOCK_SIZE
MAX_TOKENS = 2
requests = create_requests(num_requests=NUM_REQUESTS,
num_tokens=NUM_TOKENS,
max_tokens=MAX_TOKENS)
req_ids = []
req_to_index = {}
for i, request in enumerate(requests):
scheduler.add_request(request)
req_ids.append(request.request_id)
req_to_index[request.request_id] = i
MODEL_RUNNER_OUTPUT = ModelRunnerOutput(
req_ids=req_ids,
req_id_to_index=req_to_index,
sampled_token_ids=[[1000]] * len(req_ids),
spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={},
)
# Just one request should be running.
output = scheduler.schedule()
_assert_right_scheduler_output(output,
num_requests=1,
expected_num_scheduled_tokens=NUM_TOKENS -
NUM_MATCHED_NEW_TOKENS)
assert len(scheduler.running) == 1
assert len(scheduler.waiting) == 1
# All memory should be freed, with one request waiting.
_step_until_done(scheduler, output, MODEL_RUNNER_OUTPUT)
assert scheduler.kv_cache_manager.block_pool.get_num_free_blocks() \
== NUM_BLOCKS - 1
assert len(scheduler.running) == 0
assert len(scheduler.waiting) == 1
# Just one request should be running.
output = scheduler.schedule()
_assert_right_scheduler_output(output,
num_requests=1,
expected_num_scheduled_tokens=NUM_TOKENS -
NUM_MATCHED_NEW_TOKENS)
assert len(scheduler.running) == 1
assert len(scheduler.waiting) == 0
# All memory should be freed, with no requests waiting / running.
_step_until_done(scheduler, output, MODEL_RUNNER_OUTPUT)
assert scheduler.kv_cache_manager.block_pool.get_num_free_blocks() \
== NUM_BLOCKS - 1
assert len(scheduler.running) == 0
assert len(scheduler.waiting) == 0
def test_kv_connector_handles_preemption():
"""
Test whether scheduler with KVConnector is able to handle
unable to allocate (run out of blocks in allocate_slots().
"""
# Setup Scheduler With Mock External Cache Hit.
BLOCK_SIZE = 2
# NOTE: there is 1 null block, so this is 6 blocks.
NUM_BLOCKS = 7
scheduler = create_scheduler(
enable_prefix_caching=True,
use_kv_connector=True,
block_size=BLOCK_SIZE,
num_blocks=NUM_BLOCKS,
)
scheduler.kv_cache_manager.num_preallocate_blocks = 0
NUM_MATCHED_NEW_TOKENS = BLOCK_SIZE
scheduler.connector.get_num_new_matched_tokens = Mock(name="method")
scheduler.connector.get_num_new_matched_tokens.return_value = (
NUM_MATCHED_NEW_TOKENS)
# Create two requests.
# Both can be scheduled at first, but the second request
# will be preempted and re-scheduled.
NUM_REQUESTS = 2
NUM_TOKENS = BLOCK_SIZE * 2 + 1
MAX_TOKENS = BLOCK_SIZE * 2
requests = create_requests(num_requests=NUM_REQUESTS,
num_tokens=NUM_TOKENS,
max_tokens=MAX_TOKENS)
req_ids = []
req_to_index = {}
for i, request in enumerate(requests):
scheduler.add_request(request)
req_ids.append(request.request_id)
req_to_index[request.request_id] = i
MODEL_RUNNER_OUTPUT = ModelRunnerOutput(
req_ids=req_ids,
req_id_to_index=req_to_index,
sampled_token_ids=[[1000]] * len(req_ids),
spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={},
)
# All can be scheduled - 1st token.
output = scheduler.schedule()
_assert_right_scheduler_output(
output,
# 2 remote kv cache hits.
num_requests=2,
expected_num_scheduled_tokens=NUM_TOKENS - NUM_MATCHED_NEW_TOKENS)
assert len(scheduler.running) == 2
_ = scheduler.update_from_output(output, MODEL_RUNNER_OUTPUT)
# All can be scheduled - 2nd token.
output = scheduler.schedule()
_assert_right_scheduler_output(
output,
# no connector_metadata
num_requests=0,
expected_num_scheduled_tokens=1)
assert len(scheduler.running) == 2
_ = scheduler.update_from_output(output, MODEL_RUNNER_OUTPUT)
# This will generate a new block and cause a preemption - 3rd token.
output = scheduler.schedule()
_assert_right_scheduler_output(
output,
# no connector_metadata
num_requests=0,
expected_num_scheduled_tokens=1)
assert len(scheduler.running) == 1
assert len(scheduler.waiting) == 1
_ = scheduler.update_from_output(output, MODEL_RUNNER_OUTPUT)
assert len(scheduler.running) == 1
assert len(scheduler.waiting) == 1
# Only 1 can be scheduled - 4th (and last token).
output = scheduler.schedule()
_assert_right_scheduler_output(
output,
# no connector_metadata
num_requests=0,
expected_num_scheduled_tokens=1)
assert len(scheduler.waiting) == 1
assert len(scheduler.running) == 1
_ = scheduler.update_from_output(output, MODEL_RUNNER_OUTPUT)
assert len(scheduler.running) == 0
assert len(scheduler.waiting) == 1
# All memory should be freed since nothing is running.
assert scheduler.kv_cache_manager.block_pool.get_num_free_blocks() \
== NUM_BLOCKS - 1
# Restarts the preempted request - generate 3rd token.
# This will have a local and remote cache hit.
output = scheduler.schedule()
_assert_right_scheduler_output(
output,
# 1 remote kv_cache hit!
num_requests=1,
# Only 1 block was preempted and there is a single
# remote hit. So only single new token is scheduled.
expected_num_scheduled_tokens=1,
)
assert len(scheduler.running) == 1
assert len(scheduler.waiting) == 0
_ = scheduler.update_from_output(output, MODEL_RUNNER_OUTPUT)
assert len(scheduler.running) == 1
assert len(scheduler.waiting) == 0
# Only 1 can be scheduled - 4th (and last token).
output = scheduler.schedule()
_assert_right_scheduler_output(
output,
# no connector_metadata
num_requests=0,
expected_num_scheduled_tokens=1)
assert len(scheduler.running) == 1
_ = scheduler.update_from_output(output, MODEL_RUNNER_OUTPUT)
assert len(scheduler.running) == 0
# All memory should be freed since nothing is running.
assert scheduler.kv_cache_manager.block_pool.get_num_free_blocks() \
== NUM_BLOCKS - 1


@ -10,6 +10,9 @@ import vllm.envs as envs
 from vllm.attention import AttentionType
 from vllm.attention.selector import backend_name_to_enum, get_attn_backend
 from vllm.config import CacheConfig, get_current_vllm_config
+from vllm.distributed.kv_transfer import (get_kv_transfer_group,
+                                          has_kv_transfer_group,
+                                          is_v1_kv_transfer_group)
 from vllm.forward_context import ForwardContext, get_forward_context
 from vllm.model_executor.layers.linear import UnquantizedLinearMethod
 from vllm.model_executor.layers.quantization.base_config import (
@ -329,17 +332,54 @@ class MultiHeadAttention(nn.Module):
         return out.reshape(bsz, q_len, -1)


+def wait_for_kv_layer_from_connector(layer_name: str):
+    if not has_kv_transfer_group() or not is_v1_kv_transfer_group():
+        return
+
+    connector = get_kv_transfer_group()
+
+    forward_context: ForwardContext = get_forward_context()
+    attn_metadata = forward_context.attn_metadata
+    if attn_metadata is None:
+        return
+    connector.wait_for_layer_load(layer_name)
+
+
+def maybe_save_kv_layer_to_connector(
+    layer_name: str,
+    kv_cache_layer: List[torch.Tensor],
+):
+    if not has_kv_transfer_group() or not is_v1_kv_transfer_group():
+        return
+
+    connector = get_kv_transfer_group()
+
+    forward_context: ForwardContext = get_forward_context()
+    attn_metadata = forward_context.attn_metadata
+    if attn_metadata is None:
+        return
+    connector.save_kv_layer(layer_name, kv_cache_layer, attn_metadata)
+
+
 def unified_attention(
     query: torch.Tensor,
     key: torch.Tensor,
     value: torch.Tensor,
     layer_name: str,
 ) -> torch.Tensor:
+    wait_for_kv_layer_from_connector(layer_name)
     forward_context: ForwardContext = get_forward_context()
     attn_metadata = forward_context.attn_metadata
     self = forward_context.no_compile_layers[layer_name]
     kv_cache = self.kv_cache[forward_context.virtual_engine]
-    return self.impl.forward(self, query, key, value, kv_cache, attn_metadata)
+    output = self.impl.forward(self, query, key, value, kv_cache,
+                               attn_metadata)
+
+    maybe_save_kv_layer_to_connector(layer_name, kv_cache)
+    return output


 def unified_attention_fake(
@ -367,6 +407,7 @@ def unified_attention_with_output(
     output: torch.Tensor,
     layer_name: str,
 ) -> None:
+    wait_for_kv_layer_from_connector(layer_name)
     forward_context: ForwardContext = get_forward_context()
     attn_metadata = forward_context.attn_metadata
     self = forward_context.no_compile_layers[layer_name]
@ -379,6 +420,8 @@ def unified_attention_with_output(
                       attn_metadata,
                       output=output)

+    maybe_save_kv_layer_to_connector(layer_name, kv_cache)
+

 def unified_attention_with_output_fake(
     query: torch.Tensor,


@ -0,0 +1,11 @@
# SPDX-License-Identifier: Apache-2.0

from vllm.distributed.kv_transfer.kv_transfer_state import (
    ensure_kv_transfer_initialized, get_kv_transfer_group,
    has_kv_transfer_group, is_v1_kv_transfer_group)

__all__ = [
    "get_kv_transfer_group", "has_kv_transfer_group",
    "is_v1_kv_transfer_group", "ensure_kv_transfer_initialized",
    "KVConnectorBaseType"
]


@ -12,6 +12,7 @@ from typing import TYPE_CHECKING, List, Tuple, Union
 import torch

+from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorBase_V1
 from vllm.sequence import IntermediateTensors

 if TYPE_CHECKING:
@ -121,3 +122,6 @@ class KVConnectorBase(ABC):
         """
         raise NotImplementedError
+
+
+KVConnectorBaseType = Union[KVConnectorBase, KVConnectorBase_V1]


@ -3,14 +3,22 @@
 import importlib
 from typing import TYPE_CHECKING, Callable, Dict, Type

+import vllm.envs as envs
+from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBaseType
+from vllm.distributed.kv_transfer.kv_connector.v1 import (KVConnectorBase_V1,
+                                                          KVConnectorRole)
+from vllm.logger import init_logger
+
 from .base import KVConnectorBase

 if TYPE_CHECKING:
     from vllm.config import VllmConfig

+logger = init_logger(__name__)
+

 class KVConnectorFactory:
-    _registry: Dict[str, Callable[[], Type[KVConnectorBase]]] = {}
+    _registry: Dict[str, Callable[[], Type[KVConnectorBaseType]]] = {}

     @classmethod
     def register_connector(cls, name: str, module_path: str,
@ -19,22 +27,51 @@ class KVConnectorFactory:
         if name in cls._registry:
             raise ValueError(f"Connector '{name}' is already registered.")

-        def loader() -> Type[KVConnectorBase]:
+        def loader() -> Type[KVConnectorBaseType]:
             module = importlib.import_module(module_path)
             return getattr(module, class_name)

         cls._registry[name] = loader

     @classmethod
-    def create_connector(cls, rank: int, local_rank: int,
-                         config: "VllmConfig") -> KVConnectorBase:
+    def create_connector_v0(cls, rank: int, local_rank: int,
+                            config: "VllmConfig") -> KVConnectorBase:
+        if envs.VLLM_USE_V1:
+            raise ValueError("Attempting to initialize a V0 Connector, "
+                             f"but found {envs.VLLM_USE_V1=}")
+
         connector_name = config.kv_transfer_config.kv_connector
         if connector_name not in cls._registry:
             raise ValueError(f"Unsupported connector type: {connector_name}")

         connector_cls = cls._registry[connector_name]()
+        assert issubclass(connector_cls, KVConnectorBase)
         return connector_cls(rank, local_rank, config)

+    @classmethod
+    def create_connector_v1(
+        cls,
+        config: "VllmConfig",
+        role: KVConnectorRole,
+    ) -> KVConnectorBase_V1:
+        if not envs.VLLM_USE_V1:
+            raise ValueError("Attempting to initialize a V1 Connector, "
+                             f"but found {envs.VLLM_USE_V1=}")
+
+        connector_name = config.kv_transfer_config.kv_connector
+        connector_cls = cls._registry[connector_name]()
+        assert issubclass(connector_cls, KVConnectorBase_V1)
+        logger.info("Creating v1 connector with name: %s", connector_name)
+        # NOTE(Kuntai): v1 connector is explicitly separated into two roles.
+        # Scheduler connector:
+        # - Co-locate with scheduler process
+        # - Should only be used inside the Scheduler class
+        # Worker connector:
+        # - Co-locate with worker process
+        # - Should only be used inside the forward context & attention layer
+        # We build separately to enforce strict separation
+        return connector_cls(config, role)
+

 # Register various connectors here.
 # The registration should not be done in each individual file, as we want to
@ -57,4 +94,9 @@ KVConnectorFactory.register_connector(
 KVConnectorFactory.register_connector(
     "MooncakeStoreConnector",
     "vllm.distributed.kv_transfer.kv_connector.mooncake_store_connector",
     "MooncakeStoreConnector")
+
+KVConnectorFactory.register_connector(
+    "SharedStorageConnector",
+    "vllm.distributed.kv_transfer.kv_connector.v1.shared_storage_connector",
+    "SharedStorageConnector")


@ -0,0 +1,8 @@
# SPDX-License-Identifier: Apache-2.0

from vllm.distributed.kv_transfer.kv_connector.v1.base import (
    KVConnectorBase_V1, KVConnectorRole)

__all__ = [
    "KVConnectorRole",
    "KVConnectorBase_V1",
]


@ -0,0 +1,209 @@
# SPDX-License-Identifier: Apache-2.0
"""
KVConnectorBase_V1 Class for Distributed KV Cache & Hidden State
communication in vLLM v1
The class provides the following primitives:
Scheduler-side: runs in the scheduler, binds metadata, which
is used by the worker-side to load/save KV cache.
get_num_new_matched_tokens() - get number of new tokens
that exist in the remote KV cache
update_state_after_alloc() - update KVConnector state after
temporary buffer alloc by the CacheManager.
Worker-side: runs in each worker, loads/saves KV cache to/from
the Connector based on the metadata.
start_load_kv() - starts loading all KVs (maybe async)
wait_for_layer_load() - blocks until layer i load is done
save_kv_layer() - starts saving KV for layer i (maybe async)
wait_for_save() - blocks until all saves are done
"""
import enum
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import TYPE_CHECKING
import torch
from vllm.logger import init_logger
from vllm.v1.core.sched.output import SchedulerOutput
if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.config import VllmConfig
from vllm.forward_context import ForwardContext
from vllm.v1.request import Request
logger = init_logger(__name__)
class KVConnectorRole(enum.Enum):
# Connector running in the scheduler process
SCHEDULER = 0
# Connector running in the worker process
WORKER = 1
@dataclass
class KVConnectorMetadata:
pass
class KVConnectorBase_V1(ABC):
def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
logger.warning(
"Initializing KVConnectorBase_V1. This API is experimental and "
"subject to change in the future as we iterate the design.")
self._connector_metadata = KVConnectorMetadata()
self._vllm_config = vllm_config
self._role = role
@property
def role(self) -> KVConnectorRole:
return self._role
def bind_connector_metadata(
self, connector_metadata: KVConnectorMetadata) -> None:
"""Set the connector metadata from the scheduler.
This function should be called by the model runner every time
before the model execution. The metadata will be used for runtime
KV cache loading and saving.
Args:
connector_metadata (dict): the connector metadata.
"""
self._connector_metadata = connector_metadata
def clear_connector_metadata(self) -> None:
"""Clear the connector metadata.
This function should be called by the model runner every time
after the model execution.
"""
self._connector_metadata = KVConnectorMetadata()
def _get_connector_metadata(self) -> KVConnectorMetadata:
"""Get the connector metadata.
This function should only be called inside the connector.
Returns:
ConnectorMetadata: the connector metadata.
"""
return self._connector_metadata
# ==============================
# Worker-side methods
# ==============================
@abstractmethod
def start_load_kv(self, forward_context: "ForwardContext",
**kwargs) -> None:
"""
Start loading the KV cache from the connector to vLLM's paged
KV buffer. This is called from the forward context before the
forward pass to enable async loading during model execution.
Args:
forward_context (ForwardContext): the forward context.
**kwargs: additional arguments for the load operation
Note:
The number of elements in kv_caches and layer_names should be
the same.
"""
pass
@abstractmethod
def wait_for_layer_load(self, layer_name: str) -> None:
"""
Block until the KV for a specific layer is loaded into vLLM's
paged buffer. This is called from within attention layer to ensure
async copying from start_load_kv is complete.
This interface will be useful for layer-by-layer pipelining.
Args:
layer_name: the name of that layer
"""
pass
@abstractmethod
def save_kv_layer(self, layer_name: str, kv_layer: torch.Tensor,
attn_metadata: "AttentionMetadata", **kwargs) -> None:
"""
Start saving a layer of KV cache from vLLM's paged buffer
to the connector. This is called from within attention layer to
enable async copying during execution.
Args:
layer_name (str): the name of the layer.
kv_layer (torch.Tensor): the paged KV buffer of the current
layer in vLLM.
attn_metadata (AttentionMetadata): the attention metadata.
**kwargs: additional arguments for the save operation.
"""
pass
@abstractmethod
def wait_for_save(self):
"""
Block until all the save operations is done. This is called
as the forward context exits to ensure that the async saving
from save_kv_layer is complete before finishing the forward.
This prevents overwrites of paged KV buffer before saving done.
"""
pass
# ==============================
# Scheduler-side methods
# ==============================
@abstractmethod
def get_num_new_matched_tokens(
self,
request: "Request",
num_computed_tokens: int,
) -> int:
"""
Get number of new tokens that can be loaded from the
external KV cache beyond the num_computed_tokens.
Args:
request (Request): the request object.
num_computed_tokens (int): the number of locally
computed tokens for this request
Returns:
the number of tokens that can be loaded from the
external KV cache beyond what is already computed.
"""
pass
@abstractmethod
def update_state_after_alloc(self, request: "Request",
num_external_tokens: int):
"""
Update KVConnector state after block allocation.
"""
pass
@abstractmethod
def build_connector_meta(
self, scheduler_output: SchedulerOutput) -> KVConnectorMetadata:
"""
Build the connector metadata for this step.
This function should NOT modify fields in the scheduler_output.
Also, calling this function will reset the state of the connector.
Args:
scheduler_output (SchedulerOutput): the scheduler output object.
"""
pass
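For orientation, the following is a hedged sketch (not part of this PR) of the smallest possible subclass of the interface above: a no-op connector that reports no remote hits and never moves any KV data. It only exercises the abstract methods defined by KVConnectorBase_V1.

# Sketch: a do-nothing KV connector implementing every abstract method.
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
    KVConnectorBase_V1, KVConnectorMetadata)


class NoOpConnector(KVConnectorBase_V1):
    """Illustration only: never reports remote hits, never loads or saves."""

    # Worker-side hooks: nothing to load or save.
    def start_load_kv(self, forward_context, **kwargs) -> None:
        pass

    def wait_for_layer_load(self, layer_name: str) -> None:
        pass

    def save_kv_layer(self, layer_name, kv_layer, attn_metadata,
                      **kwargs) -> None:
        pass

    def wait_for_save(self):
        pass

    # Scheduler-side hooks: no external tokens are ever available.
    def get_num_new_matched_tokens(self, request,
                                   num_computed_tokens: int) -> int:
        return 0

    def update_state_after_alloc(self, request, num_external_tokens: int):
        pass

    def build_connector_meta(self, scheduler_output) -> KVConnectorMetadata:
        return KVConnectorMetadata()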


@ -0,0 +1,382 @@
# SPDX-License-Identifier: Apache-2.0
import hashlib
import os
from dataclasses import dataclass
from typing import TYPE_CHECKING
import safetensors
import torch
from vllm.config import VllmConfig
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole)
from vllm.logger import init_logger
from vllm.v1.attention.backends.mla.common import MLACommonMetadata
from vllm.v1.core.sched.output import SchedulerOutput
if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.forward_context import ForwardContext
from vllm.v1.request import Request
logger = init_logger(__name__)
@dataclass
class ReqMeta:
# Request tokens
token_ids: torch.Tensor
# Slot mappings, should have the same length as token_ids
slot_mapping: torch.Tensor
# Is store or load
is_store: bool
@staticmethod
def make_meta(token_ids: list[int], block_ids: list[int], block_size: int,
is_store: bool) -> "ReqMeta":
valid_num_tokens = align_to_block_size(len(token_ids), block_size)
token_ids_tensor = torch.tensor(token_ids)[:valid_num_tokens]
block_ids_tensor = torch.tensor(block_ids)
num_blocks = block_ids_tensor.shape[0]
block_offsets = torch.arange(0, block_size)
slot_mapping = block_offsets.reshape((1, block_size)) + \
block_ids_tensor.reshape((num_blocks, 1)) * block_size
slot_mapping = slot_mapping.flatten()[:valid_num_tokens]
return ReqMeta(
token_ids=token_ids_tensor,
slot_mapping=slot_mapping,
is_store=is_store,
)
@dataclass
class SharedStorageConnectorMetadata(KVConnectorMetadata):
requests: list[ReqMeta]
def __init__(self):
self.requests = []
def add_request(
self,
token_ids: list[int],
block_ids: list[int],
block_size: int,
is_store: bool,
) -> None:
self.requests.append(
ReqMeta.make_meta(token_ids, block_ids, block_size, is_store))
class SharedStorageConnector(KVConnectorBase_V1):
# NOTE: This is Simple debug implementation of the KV connector.
# It save / load the KV cache to / from the disk.
# It does extra work which will overwrite the existing prefix-cache in GPU
# - to remove the overhead, need to add some "mask" in the ReqMeta class
def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
super().__init__(vllm_config=vllm_config, role=role)
self._block_size = vllm_config.cache_config.block_size
self._requests_need_load: dict[str, Request] = {}
transfer_config = vllm_config.kv_transfer_config
self._storage_path = transfer_config.get_from_extra_config(
"shared_storage_path", "/tmp")
logger.info(vllm_config.kv_transfer_config)
logger.info("Shared storage path is %s", self._storage_path)
def start_load_kv(self, forward_context: "ForwardContext",
**kwargs) -> None:
"""Start loading the KV cache from the connector buffer to vLLM's
paged KV buffer.
Args:
forward_context (ForwardContext): the forward context.
**kwargs: additional arguments for the load operation
Note:
The number of elements in kv_caches and layer_names should be
the same.
"""
attn_metadata = forward_context.attn_metadata
def inject_kv_into_layer(
dst_kv_cache_layer: torch.Tensor,
src_kv_cache: torch.Tensor,
slot_mapping: torch.Tensor,
) -> None:
"""Inject the KV cache into the layer.
Args:
dst_kv_cache_layer (torch.Tensor): the destination KV cache
layer. In shape [2, num_pages, page_size, xxx] if not
using MLA, [num_pages, page_size, xxx] otherwise.
src_kv_cache (torch.Tensor): the source KV cache. In shape
[2, num_tokens, xxx] if not using MLA, [num_tokens, xxx]
otherwise.
slot_mapping (torch.Tensor): the slot mapping. In shape
[num_tokens].
"""
dst_kv_cache_layer_shape = dst_kv_cache_layer.shape
if isinstance(attn_metadata, MLACommonMetadata):
num_pages = dst_kv_cache_layer_shape[0]
page_size = dst_kv_cache_layer_shape[1]
dst_kv_cache_layer = dst_kv_cache_layer.reshape(
num_pages * page_size, -1)
dst_kv_cache_layer[slot_mapping, ...] = src_kv_cache
dst_kv_cache_layer.reshape(dst_kv_cache_layer_shape)
else:
num_pages = dst_kv_cache_layer_shape[1]
page_size = dst_kv_cache_layer_shape[2]
dst_kv_cache_layer = dst_kv_cache_layer.reshape(
2, num_pages * page_size, -1)
dst_kv_cache_layer[:, slot_mapping, ...] = src_kv_cache
dst_kv_cache_layer.reshape(dst_kv_cache_layer_shape)
# Get the metadata
metadata: KVConnectorMetadata = \
self._get_connector_metadata()
assert isinstance(metadata, SharedStorageConnectorMetadata)
if metadata is None:
logger.warning(
"In connector.start_load_kv, but the connector metadata is None"
)
return
attn_metadata = forward_context.attn_metadata
if attn_metadata is None:
logger.warning(
"In connector.start_load_kv, but the attn_metadata is None")
return
# Load the KV for each request each layer
for request in metadata.requests:
if request.is_store:
continue
logger.info("Inject KV cache of %d tokens to the paged memory",
len(request.slot_mapping))
for layer_name in forward_context.no_compile_layers:
attn_layer = forward_context.no_compile_layers[layer_name]
kv_cache_layer = attn_layer.kv_cache[\
forward_context.virtual_engine]
filename = self._generate_filename_debug(
layer_name, request.token_ids)
kv_cache = safetensors.torch.load_file(
filename)["kv_cache"].cuda()
inject_kv_into_layer(kv_cache_layer, kv_cache,
request.slot_mapping)
def wait_for_layer_load(self, layer_name: str) -> None:
"""Blocking until the KV for a specific layer is loaded into vLLM's
paged buffer.
This interface will be useful for layer-by-layer pipelining.
Args:
layer_name: the name of that layer
"""
return
def save_kv_layer(self, layer_name: str, kv_layer: torch.Tensor,
attn_metadata: "AttentionMetadata", **kwargs) -> None:
"""Start saving the KV cache of the layer from vLLM's paged buffer
to the connector.
Args:
layer_name (str): the name of the layer.
kv_layer (torch.Tensor): the paged KV buffer of the current
layer in vLLM.
attn_metadata (AttentionMetadata): the attention metadata.
**kwargs: additional arguments for the save operation.
"""
def extract_kv_from_layer(
layer: torch.Tensor,
slot_mapping: torch.Tensor,
) -> torch.Tensor:
"""Extract the KV cache from the layer.
Assume the shape of the layer is (2, num_pages, page_size, xxx)
if MLA is not used, and (num_pages, page_size, xxx) otherwise.
"""
if isinstance(attn_metadata, MLACommonMetadata):
num_pages, page_size = layer.shape[0], layer.shape[1]
return layer.reshape(num_pages * page_size, -1)[slot_mapping,
...]
num_pages, page_size = layer.shape[1], layer.shape[2]
return layer.reshape(2, num_pages * page_size, -1)[:, slot_mapping,
...]
connector_metadata = self._get_connector_metadata()
assert isinstance(connector_metadata, SharedStorageConnectorMetadata)
for request in connector_metadata.requests:
if request.is_store:
filename = self._generate_filename_debug(
layer_name, request.token_ids)
kv_cache = extract_kv_from_layer(kv_layer,
request.slot_mapping)
tensors = {"kv_cache": kv_cache.detach().cpu()}
safetensors.torch.save_file(tensors, filename)
def wait_for_save(self):
return
def get_num_new_matched_tokens(
self,
request: "Request",
num_computed_tokens: int,
) -> int:
"""
Get number of new tokens that can be loaded from the
external KV cache beyond the num_computed_tokens.
Args:
request (Request): the request object.
num_computed_tokens (int): the number of locally
computed tokens for this request
Returns:
the number of tokens that can be loaded from the
external KV cache beyond what is already computed.
"""
# NOTE: in this debug implementation, we assume that the prompt is
# cached_prompt + newly_generated_single_token
# Therefore, we use prompt_token_ids[:-1] to determine the folder name
# NOTE: in current v1 scheduler, the num_computed_tokens is aligned
# with the block granularity. And it expects the returned blocks and
# num_computed_tokens to also be aligned with the block granularity.
if not self._found_match_for_request(request):
return 0
logger.info("External Cache Hit!")
# Now, first num_tokens_to_check tokens are hit, we need to prepare
# the metadata for the worker connector to correctly load the KV
num_tokens_to_check = align_to_block_size(
len(request.prompt_token_ids) - 1, self._block_size)
return num_tokens_to_check - num_computed_tokens
def update_state_after_alloc(self, request: "Request",
num_external_tokens: int):
"""
Update KVConnector state after block allocation.
If blocks were allocated, add to _requests_need_load,
such that we load the KVs in the next forward pass.
"""
if num_external_tokens > 0:
self._requests_need_load[request.request_id] = request
def build_connector_meta(
self,
scheduler_output: SchedulerOutput,
) -> KVConnectorMetadata:
"""Build the connector metadata for this step.
This function should NOT modify any fields in the scheduler_output.
Also, calling this function will reset the state of the connector.
Args:
scheduler_output (SchedulerOutput): the scheduler output object.
"""
meta = SharedStorageConnectorMetadata()
total_need_load = 0
for new_req in scheduler_output.scheduled_new_reqs:
if new_req.req_id in self._requests_need_load:
meta.add_request(token_ids=new_req.prompt_token_ids,
block_ids=new_req.block_ids,
block_size=self._block_size,
is_store=False)
total_need_load += 1
else:
# NOTE: here, we set the store and load being exclusive,
# but a single request can have both store and load.
# NOTE(rob): for this debug implementation, we only cache
# the original prompt tokens.
if not self._found_match_for_request(new_req):
meta.add_request(token_ids=new_req.prompt_token_ids,
block_ids=new_req.block_ids,
block_size=self._block_size,
is_store=True)
for cached_req in scheduler_output.scheduled_cached_reqs:
# NOTE(rob): here we rely on the resumed requests being
# the first N requests in the list scheduled_cache_reqs.
if not cached_req.resumed_from_preemption:
break
if cached_req.req_id in self._requests_need_load:
# NOTE(rob): cached_req_data does not have the full
# list of token ids (only new tokens). So we look it
# up in the actual request object.
request = self._requests_need_load[cached_req.req_id]
total_tokens = (len(cached_req.new_token_ids) +
cached_req.num_computed_tokens)
token_ids = request.all_token_ids[:total_tokens]
# NOTE(rob): For resumed req, new_block_ids is all
# of the block_ids for the request.
block_ids = cached_req.new_block_ids
meta.add_request(token_ids=token_ids,
block_ids=block_ids,
block_size=self._block_size,
is_store=False)
total_need_load += 1
assert total_need_load == len(self._requests_need_load)
self._requests_need_load.clear()
return meta
# ==============================
# Helper functions
# ==============================
def _found_match_for_request(
self,
request: "Request",
) -> bool:
"""Check if the cache is hit for the request.
"""
num_tokens_to_check = align_to_block_size(
len(request.prompt_token_ids) - 1, self._block_size)
foldername = self._generate_foldername_debug(torch.tensor(
request.prompt_token_ids)[:num_tokens_to_check],
create_folder=False)
return os.path.exists(foldername)
def _generate_foldername_debug(
self,
input_ids: torch.Tensor,
create_folder=False,
) -> str:
"""Generate a folder name based on the hash of the bytes of the input
ids.
"""
input_ids_bytes = input_ids.numpy().tobytes()
input_ids_hash = hashlib.md5(input_ids_bytes).hexdigest()
foldername = os.path.join(self._storage_path, input_ids_hash)
if create_folder:
os.makedirs(foldername, exist_ok=True)
return foldername
def _generate_filename_debug(
self,
layer_name: str,
input_ids: torch.Tensor,
) -> str:
"""Generate a file name based on the layer name and the hash
of the bytes of the input ids.
"""
foldername = self._generate_foldername_debug(input_ids,
create_folder=True)
return os.path.join(foldername, f"{layer_name}.safetensors")
def align_to_block_size(num_tokens: int, block_size) -> int:
"""Align the number of tokens to the block size.
"""
return (num_tokens - 1) // block_size * block_size


@ -46,7 +46,7 @@ class KVTransferAgent:
         assert self.config.kv_transfer_config.is_kv_transfer_instance, "KV"\
             "TransferAgent should only be used when kv_connector is set."

-        self.connector = KVConnectorFactory.create_connector(
+        self.connector = KVConnectorFactory.create_connector_v0(
             rank, local_rank, config)

     def send_kv_caches_and_hidden_states(


@ -0,0 +1,70 @@
# SPDX-License-Identifier: Apache-2.0
from typing import TYPE_CHECKING, Optional
from vllm import envs
from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBaseType
from vllm.distributed.kv_transfer.kv_connector.factory import (
KVConnectorFactory)
from vllm.distributed.kv_transfer.kv_connector.v1 import (KVConnectorBase_V1,
KVConnectorRole)
from vllm.distributed.parallel_state import get_world_group
if TYPE_CHECKING:
from vllm.config import VllmConfig
_KV_CONNECTOR_AGENT: Optional[KVConnectorBaseType] = None
def get_kv_transfer_group() -> KVConnectorBaseType:
assert _KV_CONNECTOR_AGENT is not None, (
"disaggregated KV cache transfer parallel group is not initialized")
return _KV_CONNECTOR_AGENT
def has_kv_transfer_group() -> bool:
return _KV_CONNECTOR_AGENT is not None
def is_v1_kv_transfer_group(
connector: Optional[KVConnectorBaseType] = None) -> bool:
"""Check if the KV connector is the v1 connector.
If the argument is None, it will check the global KV connector
Args:
connector: The KV connector to check. If None, it will check the
global KV connector.
Note:
This function will no-longer be needed after the v1 KV connector
becomes the default.
"""
if connector is None:
connector = _KV_CONNECTOR_AGENT
if connector is None:
return False
return isinstance(connector, KVConnectorBase_V1)
def ensure_kv_transfer_initialized(vllm_config: "VllmConfig") -> None:
"""
Initialize KV cache transfer parallel group.
"""
global _KV_CONNECTOR_AGENT
if vllm_config.kv_transfer_config is None:
return
if (vllm_config.kv_transfer_config.is_kv_transfer_instance
and _KV_CONNECTOR_AGENT is None):
if envs.VLLM_USE_V1:
_KV_CONNECTOR_AGENT = KVConnectorFactory.create_connector_v1(
config=vllm_config, role=KVConnectorRole.WORKER)
else:
_KV_CONNECTOR_AGENT = KVConnectorFactory.create_connector_v0(
rank=get_world_group().rank,
local_rank=get_world_group().local_rank,
config=vllm_config,
)


@ -29,15 +29,13 @@ from collections import namedtuple
 from contextlib import contextmanager, nullcontext
 from dataclasses import dataclass
 from multiprocessing import shared_memory
-from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple,
-                    Union)
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 from unittest.mock import patch

 import torch
 import torch.distributed
 from torch.distributed import Backend, ProcessGroup

-import vllm.distributed.kv_transfer.kv_transfer_agent as kv_transfer
 import vllm.envs as envs
 from vllm.distributed.device_communicators.base_device_communicator import (
     DeviceCommunicatorBase)
@ -46,9 +44,6 @@ from vllm.logger import init_logger
 from vllm.utils import (direct_register_custom_op, resolve_obj_by_qualname,
                         supports_custom_op)

-if TYPE_CHECKING:
-    from vllm.config import VllmConfig
-

 @dataclass
 class GraphCaptureContext:
@ -772,14 +767,6 @@ def get_pp_group() -> GroupCoordinator:
 # kept for backward compatibility
 get_pipeline_model_parallel_group = get_pp_group

-_KV_TRANSFER: Optional[kv_transfer.KVTransferAgent] = None
-
-
-def get_kv_transfer_group() -> kv_transfer.KVTransferAgent:
-    assert _KV_TRANSFER is not None, (
-        "disaggregated KV cache transfer parallel group is not initialized")
-    return _KV_TRANSFER
-

 @contextmanager
 def graph_capture(device: torch.device):
@ -962,26 +949,6 @@ def initialize_model_parallel(
         _DP.rank_in_group, _PP.rank_in_group, _TP.rank_in_group)


-def ensure_kv_transfer_initialized(vllm_config: "VllmConfig") -> None:
-    """
-    Initialize KV cache transfer parallel group.
-    """
-
-    global _KV_TRANSFER
-
-    if vllm_config.kv_transfer_config is None:
-        return
-
-    if all([
-            vllm_config.kv_transfer_config.is_kv_transfer_instance,
-            _KV_TRANSFER is None
-    ]):
-        _KV_TRANSFER = kv_transfer.KVTransferAgent(
-            rank=get_world_group().rank,
-            local_rank=get_world_group().local_rank,
-            config=vllm_config)
-

 def ensure_model_parallel_initialized(
     tensor_model_parallel_size: int,
     pipeline_model_parallel_size: int,


@ -1527,12 +1527,6 @@ class EngineArgs:
                                recommend_to_remove=False)
             return False

-        # No Disaggregated Prefill so far.
-        if self.kv_transfer_config != EngineArgs.kv_transfer_config:
-            _raise_or_fallback(feature_name="--kv-transfer-config",
-                               recommend_to_remove=False)
-            return False
-
         # No FlashInfer or XFormers so far.
         V1_BACKENDS = [
             "FLASH_ATTN_VLLM_V1", "FLASH_ATTN", "PALLAS", "PALLAS_VLLM_V1",


@ -11,6 +11,10 @@ import torch.distributed as dist

 import vllm.envs as envs
 from vllm.config import VllmConfig
+from vllm.distributed.kv_transfer import (get_kv_transfer_group,
+                                          has_kv_transfer_group,
+                                          is_v1_kv_transfer_group)
+from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorBase_V1
 from vllm.logger import init_logger

 if TYPE_CHECKING:
@ -98,6 +102,17 @@ def set_forward_context(attn_metadata: Any,
         virtual_engine=virtual_engine,
         attn_metadata=attn_metadata,
         dp_metadata=dp_metadata)

+    # KVConnector: trigger (possibly async) load before forward.
+    # Each attn layer will block until the reading is complete.
+    trigger_kv_transfer = (attn_metadata is not None
+                           and has_kv_transfer_group()
+                           and is_v1_kv_transfer_group())
+    if trigger_kv_transfer:
+        kv_connector = get_kv_transfer_group()
+        assert isinstance(kv_connector, KVConnectorBase_V1)
+        kv_connector.start_load_kv(_forward_context)
+
     try:
         yield
     finally:
@ -133,4 +148,12 @@ def set_forward_context(attn_metadata: Any,
             logger.info(("Batchsize forward time stats "
                          "(batchsize, count, median_time(ms)): %s"),
                         forward_stats)
+
+        # KVConnector: each attn layer triggers (possibly async) save.
+        # Ensure all those operations complete before forward() is done.
+        if trigger_kv_transfer:
+            kv_connector = get_kv_transfer_group()
+            assert isinstance(kv_connector, KVConnectorBase_V1)
+            kv_connector.wait_for_save()
+
         _forward_context = prev_context
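Taken together with the attention-layer hooks and the bind/clear methods documented on KVConnectorBase_V1, the worker-side flow for one model step looks roughly like the sketch below. The helper run_one_step and its arguments are hypothetical scaffolding; only the connector and set_forward_context calls come from this PR.

# Sketch: per-step connector lifecycle on the worker side (illustration only).
from vllm.distributed.kv_transfer import get_kv_transfer_group
from vllm.forward_context import set_forward_context


def run_one_step(model, input_ids, positions, attn_metadata, scheduler_output,
                 vllm_config):
    """Hypothetical helper showing where the connector hooks fire."""
    connector = get_kv_transfer_group()

    # 1. Bind the metadata built by the scheduler-side connector.
    connector.bind_connector_metadata(scheduler_output.kv_connector_metadata)

    with set_forward_context(attn_metadata, vllm_config):
        # 2. On entry, set_forward_context calls connector.start_load_kv().
        # 3. Each attention layer then calls wait_for_kv_layer_from_connector()
        #    before attention and maybe_save_kv_layer_to_connector() after it.
        hidden_states = model(input_ids, positions)
        # 4. On exit, set_forward_context calls connector.wait_for_save().

    # 5. Clear the per-step metadata.
    connector.clear_connector_metadata()
    return hidden_states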


@ -171,8 +171,9 @@ class KVCacheManager:
         Args:
             request: The request to allocate slots.
-            num_tokens: The number of tokens to allocate. Note that this does
-                not include the tokens that have already been computed.
+            num_tokens: The number of tokens to allocate, including external
+                tokens. Note that this does not include tokens that have
+                already been computed locally (i.e. new_computed_blocks).
             new_computed_blocks: A list of new computed blocks just hitting the
                 prefix caching.
             num_lookahead_tokens: The number of speculative tokens to allocate.
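For concreteness, here is a small worked example of the token accounting this docstring describes, matching the allocate_slots() call in the scheduler diff later in this commit; all numbers are illustrative.

# Illustrative accounting only; values are made up.
block_size = 16
prompt_len = 96

num_local_computed = 32    # from kv_cache_manager.get_computed_blocks()
num_external = 32          # from connector.get_num_new_matched_tokens()

# The scheduler counts local + external tokens as already computed ...
num_computed_tokens = num_local_computed + num_external
# ... so only the remainder is scheduled for prefill this step,
num_new_tokens = prompt_len - num_computed_tokens        # 32
# ... but allocate_slots() still reserves space for the external tokens
# that the connector will load into the paged KV cache.
num_slots_requested = num_new_tokens + num_external      # 64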


@ -9,6 +9,8 @@ if TYPE_CHECKING:
     import numpy as np
     import numpy.typing as npt

+    from vllm.distributed.kv_transfer.kv_connector.v1.base import (
+        KVConnectorMetadata)
     from vllm.lora.request import LoRARequest
     from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
     from vllm.sampling_params import SamplingParams
@ -121,3 +123,6 @@ class SchedulerOutput:
     structured_output_request_ids: dict[str, int]
     # the bitmask for the whole batch
     grammar_bitmask: Optional[npt.NDArray[np.int32]]
+
+    # KV Cache Connector metadata.
+    kv_connector_metadata: Optional[KVConnectorMetadata] = None


@@ -7,8 +7,10 @@ from collections import deque
 from collections.abc import Iterable
 from typing import Optional, Union

-from vllm.config import (CacheConfig, LoRAConfig, ModelConfig, SchedulerConfig,
-                         SpeculativeConfig)
+from vllm.config import VllmConfig
+from vllm.distributed.kv_transfer.kv_connector.factory import (
+    KVConnectorFactory)
+from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorRole
 from vllm.logger import init_logger
 from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
 from vllm.v1.core.encoder_cache_manager import (EncoderCacheManager,

@@ -34,20 +36,17 @@ class Scheduler(SchedulerInterface):

     def __init__(
         self,
-        scheduler_config: SchedulerConfig,
-        model_config: ModelConfig,
-        cache_config: CacheConfig,
-        lora_config: Optional[LoRAConfig],
+        vllm_config: VllmConfig,
         kv_cache_config: KVCacheConfig,
         structured_output_manager: StructuredOutputManager,
-        speculative_config: SpeculativeConfig = None,
         mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
         include_finished_set: bool = False,
         log_stats: bool = False,
     ) -> None:
-        self.scheduler_config = scheduler_config
-        self.cache_config = cache_config
-        self.lora_config = lora_config
+        self.vllm_config = vllm_config
+        self.scheduler_config = vllm_config.scheduler_config
+        self.cache_config = vllm_config.cache_config
+        self.lora_config = vllm_config.lora_config
         self.kv_cache_config = kv_cache_config
         self.log_stats = log_stats
         self.structured_output_manager = structured_output_manager

@@ -64,11 +63,22 @@ class Scheduler(SchedulerInterface):
             self.scheduler_config.max_num_batched_tokens
         self.max_model_len = self.scheduler_config.max_model_len

+        # Create KVConnector for the Scheduler. Note that each Worker
+        # will have a corresponding KVConnector with Role=WORKER.
+        # KV Connector pushes/pull of remote KVs for P/D and offloading.
+        self.connector = None
+        if self.vllm_config.kv_transfer_config is not None:
+            self.connector = KVConnectorFactory.create_connector_v1(
+                config=self.vllm_config, role=KVConnectorRole.SCHEDULER)
+
+        num_gpu_blocks = self.cache_config.num_gpu_blocks
+        assert num_gpu_blocks is not None and num_gpu_blocks > 0
+
         # Create the KV cache manager.
         self.kv_cache_manager = KVCacheManager(
             kv_cache_config=kv_cache_config,
             max_model_len=self.max_model_len,
-            enable_caching=cache_config.enable_prefix_caching,
+            enable_caching=self.cache_config.enable_prefix_caching,
             caching_hash_algo=self.cache_config.prefix_caching_hash_algo,
             log_stats=self.log_stats)
         self.block_size = self.cache_config.block_size

@@ -99,8 +109,8 @@ class Scheduler(SchedulerInterface):
         # This can be changed when we make encoder cache for embedding caching
         # across requests.
         encoder_compute_budget, encoder_cache_size = compute_encoder_budget(
-            model_config=model_config,
-            scheduler_config=scheduler_config,
+            model_config=vllm_config.model_config,
+            scheduler_config=vllm_config.scheduler_config,
             mm_registry=mm_registry,
         )

@@ -115,6 +125,7 @@ class Scheduler(SchedulerInterface):
             cache_size=encoder_cache_size)

         self.num_lookahead_tokens = 0
+        speculative_config = vllm_config.speculative_config
         if speculative_config and speculative_config.method == "eagle":
             self.num_lookahead_tokens = \
                 speculative_config.num_speculative_tokens

@@ -304,6 +315,16 @@ class Scheduler(SchedulerInterface):
                 # Get already-cached tokens.
                 computed_blocks, num_computed_tokens = \
                     self.kv_cache_manager.get_computed_blocks(request)
+
+                # Get externally-cached tokens if using a KVConnector.
+                num_external_tokens = (
+                    0 if self.connector is None else
+                    self.connector.get_num_new_matched_tokens(
+                        request, num_computed_tokens))
+
+                # Total computed tokens (local + external).
+                num_computed_tokens += num_external_tokens
+
                 # Number of tokens to be scheduled.
                 # We use `request.num_tokens` instead of
                 # `request.num_prompt_tokens` to consider the resumed requests,

@@ -330,11 +351,21 @@ class Scheduler(SchedulerInterface):
                     new_encoder_budget = encoder_budget

                 new_blocks = self.kv_cache_manager.allocate_slots(
-                    request, num_new_tokens, computed_blocks)
+                    request, num_new_tokens + num_external_tokens,
+                    computed_blocks)
                 if new_blocks is None:
                     # The request cannot be scheduled.
                     break

+                # KVConnector: update internal state after allocation.
+                # This information is used to determine if a load is
+                # needed for this request.
+                if self.connector is not None:
+                    self.connector.update_state_after_alloc(
+                        request,
+                        num_external_tokens,
+                    )
+
                 self.waiting.popleft()
                 if request.use_structured_output:
                     structured_output_request_ids[

@@ -443,6 +474,14 @@ class Scheduler(SchedulerInterface):
             grammar_bitmask=grammar_bitmask,
         )

+        # NOTE(Kuntai): this function is designed for multiple purposes:
+        # 1. Plan the KV cache store
+        # 2. Wrap up all the KV cache load / save ops into an opaque object
+        # 3. Clear the internal states of the connector
+        if self.connector is not None:
+            meta = self.connector.build_connector_meta(scheduler_output)
+            scheduler_output.kv_connector_metadata = meta
+
         # Advance the number of computed tokens for the request AFTER
         # the request is scheduled.
         # 1. The scheduler_output of the current step has to include the

@@ -508,6 +547,9 @@ class Scheduler(SchedulerInterface):
         If an encoder input cannot be scheduled due to cache or budget
         limitations, the method adjusts `num_new_tokens` to schedule only the
         decoder tokens up to just before the unschedulable encoder input.
+
+        Note that num_computed_tokens includes both locally cached
+        blocks and externally cached blocks (via KVConnector).
         """
         encoder_inputs_to_schedule: list[int] = []
         mm_positions = request.mm_positions

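The scheduler-side hooks above run in a fixed per-step order: get_num_new_matched_tokens() reports externally cached tokens, allocate_slots() must cover the new tokens plus the external ones, update_state_after_alloc() tells the connector a load is coming, and build_connector_meta() finally wraps every planned load/save into the opaque metadata carried by SchedulerOutput. Below is a condensed, self-contained Python sketch of that control flow; StubKVCacheManager, StubConnector, and the simplified Request/SchedulerOutput are illustrative stand-ins, not vLLM's real implementations.

from dataclasses import dataclass
from typing import Optional


class StubKVCacheManager:
    """Stand-in for the KV cache manager; pretends there is no local hit."""

    def get_computed_blocks(self, request):
        return [], 0

    def allocate_slots(self, request, num_tokens, computed_blocks):
        return ["block-0", "block-1"]  # pretend allocation always succeeds


class StubConnector:
    """Stand-in scheduler-side connector with the three hooks from this file."""

    def get_num_new_matched_tokens(self, request, num_computed_tokens):
        return 16  # pretend 16 tokens are already cached externally

    def update_state_after_alloc(self, request, num_external_tokens):
        print(f"{request.request_id}: will load {num_external_tokens} external tokens")

    def build_connector_meta(self, scheduler_output):
        return {"loads": {"req-0": 16}}  # opaque to the scheduler


@dataclass
class Request:
    request_id: str
    num_tokens: int


@dataclass
class SchedulerOutput:
    kv_connector_metadata: Optional[object] = None


def schedule_one(request, kv_cache_manager, connector):
    # Local prefix-cache hits first.
    computed_blocks, num_computed_tokens = (
        kv_cache_manager.get_computed_blocks(request))

    # Then ask the connector for externally cached tokens and count them too.
    num_external_tokens = (0 if connector is None else
                           connector.get_num_new_matched_tokens(
                               request, num_computed_tokens))
    num_computed_tokens += num_external_tokens

    # Allocate space for the new tokens plus the external KVs to be loaded in.
    num_new_tokens = request.num_tokens - num_computed_tokens
    new_blocks = kv_cache_manager.allocate_slots(
        request, num_new_tokens + num_external_tokens, computed_blocks)
    if new_blocks is None:
        return False  # cannot schedule this step

    # Tell the connector that a load will be needed for this request.
    if connector is not None:
        connector.update_state_after_alloc(request, num_external_tokens)
    return True


connector = StubConnector()
scheduler_output = SchedulerOutput()
schedule_one(Request("req-0", num_tokens=48), StubKVCacheManager(), connector)
# After the whole batch is planned, fold all load/save ops into opaque
# metadata that travels to the workers on the SchedulerOutput.
scheduler_output.kv_connector_metadata = connector.build_connector_meta(
    scheduler_output)
print(scheduler_output.kv_connector_metadata)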
View File

@@ -92,12 +92,8 @@ class EngineCore:
             vllm_config.scheduler_config.scheduler_cls)
         self.scheduler: SchedulerInterface = Scheduler(
-            scheduler_config=vllm_config.scheduler_config,
-            model_config=vllm_config.model_config,
-            cache_config=vllm_config.cache_config,
-            lora_config=vllm_config.lora_config,
+            vllm_config=vllm_config,
             kv_cache_config=kv_cache_config,
-            speculative_config=vllm_config.speculative_config,
             structured_output_manager=self.structured_output_manager,
             include_finished_set=vllm_config.parallel_config.data_parallel_size
             > 1,

View File

@@ -13,6 +13,8 @@ import torch.nn as nn
 from vllm.attention import AttentionType, get_attn_backend
 from vllm.attention.layer import Attention
 from vllm.config import CompilationLevel, VllmConfig
+from vllm.distributed.kv_transfer import (get_kv_transfer_group,
+                                          has_kv_transfer_group)
 from vllm.distributed.parallel_state import get_pp_group, graph_capture
 from vllm.forward_context import set_forward_context
 from vllm.logger import init_logger

@@ -987,6 +989,11 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         scheduler_output: "SchedulerOutput",
         intermediate_tensors: Optional[IntermediateTensors] = None,
     ) -> Union[ModelRunnerOutput, torch.Tensor]:
+        # Update KVConnector with the KVConnector metadata forward().
+        if has_kv_transfer_group():
+            get_kv_transfer_group().bind_connector_metadata(
+                scheduler_output.kv_connector_metadata)
+
         self._update_states(scheduler_output)
         if not scheduler_output.total_num_scheduled_tokens:
             # Return empty ModelRunnerOutput if there's no work to do.

@@ -1228,6 +1235,10 @@ class GPUModelRunner(LoRAModelRunnerMixin):
             # in the next step.
             del draft_probs

+        # Clear KVConnector state after all KVs are generated.
+        if has_kv_transfer_group():
+            get_kv_transfer_group().clear_connector_metadata()
+
         return ModelRunnerOutput(
             req_ids=self.input_batch.req_ids,
             req_id_to_index=self.input_batch.req_id_to_index,

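On the worker side, the metadata built by the scheduler is bound to the KV transfer group before the forward pass and cleared once all KVs for the step have been generated. A small self-contained sketch of that bind, forward, clear ordering, with a stub object standing in for vLLM's KV transfer group:

class StubWorkerConnector:
    """Stand-in for the worker-side KV transfer group, illustration only."""

    def __init__(self):
        self._metadata = None

    def bind_connector_metadata(self, metadata):
        self._metadata = metadata  # visible to KV load/save hooks during forward

    def clear_connector_metadata(self):
        self._metadata = None


def execute_model_step(connector, kv_connector_metadata, run_forward):
    # Same ordering as the diff: bind -> forward (loads/saves happen inside)
    # -> clear, so stale metadata never leaks into the next step.
    connector.bind_connector_metadata(kv_connector_metadata)
    output = run_forward()
    connector.clear_connector_metadata()
    return output


connector = StubWorkerConnector()
print(execute_model_step(connector, {"req-0": 16}, lambda: "model output"))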
View File

@@ -9,11 +9,12 @@ import torch.distributed
 import torch.nn as nn

 import vllm.envs as envs
-from vllm.config import ParallelConfig, VllmConfig
+from vllm.config import VllmConfig
 from vllm.device_allocator.cumem import CuMemAllocator
 from vllm.distributed import (ensure_model_parallel_initialized,
                               init_distributed_environment,
                               set_custom_all_reduce)
+from vllm.distributed.kv_transfer import ensure_kv_transfer_initialized
 from vllm.distributed.parallel_state import get_pp_group
 from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest

@@ -110,7 +111,7 @@ class Worker(WorkerBase):
             raise RuntimeError(
                 f"Not support device type: {self.device_config.device}")
         # Initialize the distributed environment.
-        init_worker_distributed_environment(self.parallel_config, self.rank,
+        init_worker_distributed_environment(self.vllm_config, self.rank,
                                             self.distributed_init_method,
                                             self.local_rank)
         # Set random seed.

@@ -285,12 +286,13 @@ class Worker(WorkerBase):

 def init_worker_distributed_environment(
-    parallel_config: ParallelConfig,
+    vllm_config: VllmConfig,
     rank: int,
     distributed_init_method: Optional[str] = None,
     local_rank: int = -1,
 ) -> None:
     """Initialize the distributed environment."""
+    parallel_config = vllm_config.parallel_config
     set_custom_all_reduce(not parallel_config.disable_custom_all_reduce)

     init_distributed_environment(parallel_config.world_size, rank,

@@ -299,6 +301,8 @@ def init_worker_distributed_environment(
     ensure_model_parallel_initialized(parallel_config.tensor_parallel_size,
                                       parallel_config.pipeline_parallel_size)

+    ensure_kv_transfer_initialized(vllm_config)
+

 def _check_if_gpu_supports_dtype(torch_dtype: torch.dtype):

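The signature change from ParallelConfig to VllmConfig is what lets the worker bring up KV transfer at this point: kv_transfer_config lives on the top-level config, not on ParallelConfig. A small illustrative sketch under stub types (StubParallelConfig, StubVllmConfig, and the print-based ensure_kv_transfer_initialized are placeholders, not vLLM's real implementations):

from dataclasses import dataclass
from typing import Optional


@dataclass
class StubParallelConfig:
    world_size: int = 1
    tensor_parallel_size: int = 1
    pipeline_parallel_size: int = 1


@dataclass
class StubVllmConfig:
    parallel_config: StubParallelConfig
    kv_transfer_config: Optional[dict] = None


def ensure_kv_transfer_initialized(vllm_config):
    # Placeholder: the real call creates the worker-side KVConnector
    # (Role=WORKER) when a kv_transfer_config is present.
    if vllm_config.kv_transfer_config is not None:
        print("initializing worker-side KV connector")


def init_worker_distributed_environment(vllm_config, rank):
    # The parallel config is now derived instead of passed in, so the full
    # config is still available when KV transfer needs it.
    parallel_config = vllm_config.parallel_config
    print(f"world_size={parallel_config.world_size}, rank={rank}")
    ensure_kv_transfer_initialized(vllm_config)


cfg = StubVllmConfig(StubParallelConfig(),
                     kv_transfer_config={"kv_connector": "SharedStorageConnector"})
init_worker_distributed_environment(cfg, rank=0)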
View File

@@ -23,7 +23,8 @@ from vllm.attention.backends.abstract import AttentionState
 from vllm.attention.backends.utils import CommonAttentionState
 from vllm.config import CompilationLevel, VllmConfig
 from vllm.core.scheduler import SchedulerOutputs
-from vllm.distributed import get_kv_transfer_group, get_pp_group
+from vllm.distributed import get_pp_group
+from vllm.distributed.kv_transfer import get_kv_transfer_group
 from vllm.distributed.parallel_state import (get_tensor_model_parallel_rank,
                                              graph_capture)
 from vllm.forward_context import get_forward_context, set_forward_context

View File

@@ -10,10 +10,10 @@ import torch.distributed
 import vllm.envs as envs
 from vllm.config import VllmConfig
 from vllm.device_allocator.cumem import CuMemAllocator
-from vllm.distributed import (ensure_kv_transfer_initialized,
-                              ensure_model_parallel_initialized,
+from vllm.distributed import (ensure_model_parallel_initialized,
                               init_distributed_environment,
                               set_custom_all_reduce)
+from vllm.distributed.kv_transfer import ensure_kv_transfer_initialized
 from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
 from vllm.model_executor import set_random_seed