[P/D][V1] KV Connector API V1 (#15960)
Signed-off-by: ApostaC <yihua98@uchicago.edu>
Signed-off-by: rshaw@neuralmagic.com <robertgshaw2@gmail.com>
Signed-off-by: remi <remi@mistral.ai>
Co-authored-by: rshaw@neuralmagic.com <robertgshaw2@gmail.com>
Co-authored-by: Robert Shaw <114415538+robertgshaw2-redhat@users.noreply.github.com>
Co-authored-by: Rémi Delacourt <54138269+Flechman@users.noreply.github.com>
Co-authored-by: Tyler Michael Smith <tysmith@redhat.com>

Parent commit: 0377b8310b
This commit: 3408e47159
@@ -0,0 +1,36 @@
# SPDX-License-Identifier: Apache-2.0

from vllm import LLM, SamplingParams
from vllm.config import KVTransferConfig

# Read prompts from output.txt
prompts = []
try:
    with open("output.txt") as f:
        for line in f:
            prompts.append(line.strip())
    print(f"Loaded {len(prompts)} prompts from output.txt")
except FileNotFoundError:
    print("Error: output.txt file not found")
    exit(-1)

sampling_params = SamplingParams(temperature=0, top_p=0.95, max_tokens=10)

llm = LLM(
    model="meta-llama/Llama-3.2-1B-Instruct",
    enforce_eager=True,
    gpu_memory_utilization=0.8,
    max_num_batched_tokens=64,
    max_num_seqs=16,
    kv_transfer_config=KVTransferConfig.from_cli(
        '{"kv_connector":"SharedStorageConnector","kv_role":"kv_both",'
        '"kv_connector_extra_config": {"shared_storage_path": "local_storage"}}'
    ))  #, max_model_len=2048, max_num_batched_tokens=2048)

# 1ST generation (prefill instance)
outputs = llm.generate(prompts, sampling_params)

for output in outputs:
    prompt = output.prompt
    generated_text = output.outputs[0].text
    print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
@@ -0,0 +1,43 @@
# SPDX-License-Identifier: Apache-2.0

from vllm import LLM, SamplingParams
from vllm.config import KVTransferConfig

context = "Hi " * 1000
context2 = "Hey " * 500
prompts = [
    context + "Hello, my name is",
    context + "The capital of France is",
    context2 + "Your name is",
    context2 + "The capital of China is",
]

sampling_params = SamplingParams(temperature=0, top_p=0.95, max_tokens=1)

llm = LLM(model="meta-llama/Llama-3.2-1B-Instruct",
          enforce_eager=True,
          gpu_memory_utilization=0.8,
          kv_transfer_config=KVTransferConfig.from_cli(
              '{"kv_connector":"SharedStorageConnector","kv_role":"kv_both", '
              '"kv_connector_extra_config": '
              '{"shared_storage_path": "local_storage"}}')
          )  #, max_model_len=2048, max_num_batched_tokens=2048)

# 1ST generation (prefill instance)
outputs = llm.generate(
    prompts,
    sampling_params,
)

new_prompts = []
for output in outputs:
    prompt = output.prompt
    generated_text = output.outputs[0].text
    new_prompts.append(prompt + generated_text)
    print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")

# Write new_prompts to output.txt
with open("output.txt", "w") as f:
    for prompt in new_prompts:
        f.write(prompt + "\n")
    print(f"Saved {len(new_prompts)} prompts to output.txt")
@@ -0,0 +1,5 @@
rm -rf local_storage/
rm output.txt

VLLM_ENABLE_V1_MULTIPROCESSING=0 CUDA_VISIBLE_DEVICES=0 python3 prefill_example.py
VLLM_ENABLE_V1_MULTIPROCESSING=0 CUDA_VISIBLE_DEVICES=0 python3 decode_example.py
@@ -1,10 +1,12 @@
 # SPDX-License-Identifier: Apache-2.0
 from typing import Optional
+from unittest.mock import Mock
 
 import pytest
 import torch
 
-from vllm.config import CacheConfig, ModelConfig, SchedulerConfig, VllmConfig
+from vllm.config import (CacheConfig, KVTransferConfig, ModelConfig,
+                         SchedulerConfig, VllmConfig)
 from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
 from vllm.sampling_params import SamplingParams
 from vllm.v1.core.sched.output import SchedulerOutput
@@ -25,6 +27,9 @@ def create_scheduler(
     enable_prefix_caching: Optional[bool] = None,
     long_prefill_token_threshold: int = 0,
     disable_chunked_mm_input: bool = False,
+    use_kv_connector: bool = False,
+    num_blocks: int = 10000,
+    block_size: int = 16,
 ) -> Scheduler:
     '''Create scheduler under test.
 
@@ -60,31 +65,36 @@ def create_scheduler(
         'enable_prefix_caching': enable_prefix_caching
     })
     cache_config = CacheConfig(
-        block_size=16,
+        block_size=block_size,
         gpu_memory_utilization=0.9,
         swap_space=0,
         cache_dtype="auto",
         **kwargs_cache,
     )
+    kv_transfer_config = KVTransferConfig(
+        kv_connector="SharedStorageConnector",
+        kv_role="kv_both",
+        kv_connector_extra_config={"shared_storage_path": "local_storage"},
+    ) if use_kv_connector else None
+
     vllm_config = VllmConfig(
         scheduler_config=scheduler_config,
         model_config=model_config,
         cache_config=cache_config,
+        kv_transfer_config=kv_transfer_config,
     )
     kv_cache_config = KVCacheConfig(
-        num_blocks=10000,  # A large number of blocks to hold all requests
+        num_blocks=num_blocks,  # A large number of blocks to hold all requests
         tensors={},
         kv_cache_groups=[
             KVCacheGroupSpec(['layer'],
-                             FullAttentionSpec(16, 1, 1, torch.float32, False))
+                             FullAttentionSpec(block_size, 1, 1, torch.float32,
+                                               False))
         ],
     )
-    cache_config.num_gpu_blocks = 10000
+    cache_config.num_gpu_blocks = num_blocks
     return Scheduler(
-        scheduler_config,
-        model_config,
-        cache_config,
-        lora_config=None,
+        vllm_config=vllm_config,
         kv_cache_config=kv_cache_config,
         log_stats=True,
         structured_output_manager=StructuredOutputManager(vllm_config),
@@ -761,3 +771,390 @@ def test_schedule_spec_decoding_stats(spec_tokens, output_tokens, expected):
    stats = scheduler_stats.spec_decoding_stats
    assert stats.num_draft_tokens == expected[0]
    assert stats.num_accepted_tokens == expected[1]


def _assert_right_scheduler_output(
    output: SchedulerOutput,
    num_requests: int,
    expected_num_scheduled_tokens: int,
):
    """Check if SchedulerOutput is correct after remote KV cache hit."""

    # We should inject the kv_connector_metadata.
    assert len(output.kv_connector_metadata.requests) == num_requests

    # Only num_tokens - matched_num_new_tokens should be scheduled.
    for _, num_scheduled_tokens in output.num_scheduled_tokens.items():
        assert num_scheduled_tokens == expected_num_scheduled_tokens


def _assert_right_kv_cache_manager(
    scheduler: Scheduler,
    req_ids: list[str],
    num_tokens: int,
    block_size: int,
    num_requests: int,
    num_total_blocks: int,
):
    """Check whether KVCacheManager is correct after allocate."""

    # Make sure the request stats are right.
    EXPECTED_ACTUAL_BLOCKS = num_tokens // block_size
    EXPECTED_TOTAL_BLOCKS = (EXPECTED_ACTUAL_BLOCKS +
                             scheduler.kv_cache_manager.num_preallocate_blocks)
    for req_id in req_ids:
        blocks = scheduler.kv_cache_manager.req_to_blocks[req_id]
        hashes = scheduler.kv_cache_manager.req_to_block_hashes[req_id]
        assert (scheduler.kv_cache_manager.num_cached_block[req_id] ==
                EXPECTED_ACTUAL_BLOCKS)
        assert len(blocks) == EXPECTED_TOTAL_BLOCKS
        assert len(hashes) == EXPECTED_ACTUAL_BLOCKS

    # Make sure we actually touched all the blocks.
    BLOCKS_PER_REQ = (num_tokens / block_size +
                      scheduler.kv_cache_manager.num_preallocate_blocks)
    assert (scheduler.kv_cache_manager.block_pool.get_num_free_blocks() ==
            num_total_blocks - num_requests * BLOCKS_PER_REQ)


def _step_until_done(
    scheduler: Scheduler,
    output: SchedulerOutput,
    model_runner_output: ModelRunnerOutput,
):
    """Loop over schedule(), update_from_output() until finished."""

    all_finished = False
    _ = scheduler.update_from_output(output, model_runner_output)
    while not all_finished:
        # Schedule + a few iterations until stopping.
        output = scheduler.schedule()
        assert len(scheduler.running)
        for _, num_scheduled_tokens in output.num_scheduled_tokens.items():
            # We should be in the decode phase now.
            assert num_scheduled_tokens == 1
        assert len(output.kv_connector_metadata.requests) == 0
        ecos = scheduler.update_from_output(output, model_runner_output)
        all_done = True
        for eco in ecos.outputs:
            if eco.finish_reason is None:
                all_done = False
        all_finished = all_done


def test_kv_connector_basic():
    """
    Test whether Scheduler with KVConnector schedules tokens, allocates
    memory, and cleans up requests as expected under normal operation.
    """

    # Setup Scheduler.
    scheduler = create_scheduler(
        enable_prefix_caching=True,
        use_kv_connector=True,
    )
    NUM_TOTAL_BLOCKS = (
        scheduler.kv_cache_manager.block_pool.get_num_free_blocks())
    BLOCK_SIZE = scheduler.cache_config.block_size

    # Mock External Cache Hit.
    NUM_MATCHED_NEW_TOKENS = BLOCK_SIZE * 2
    scheduler.connector.get_num_new_matched_tokens = Mock(name="method")
    scheduler.connector.get_num_new_matched_tokens.return_value = (
        NUM_MATCHED_NEW_TOKENS)

    ######################################################
    # FIRST SET OF REQUESTS - External Hit Only
    NUM_REQUESTS = 2
    NUM_TOKENS = NUM_MATCHED_NEW_TOKENS * 2
    MAX_TOKENS = 3
    requests = create_requests(num_requests=NUM_REQUESTS,
                               num_tokens=NUM_TOKENS,
                               max_tokens=MAX_TOKENS)
    req_ids = []
    req_to_index = {}
    for i, request in enumerate(requests):
        scheduler.add_request(request)
        req_ids.append(request.request_id)
        req_to_index[request.request_id] = i

    MODEL_RUNNER_OUTPUT = ModelRunnerOutput(
        req_ids=req_ids,
        req_id_to_index=req_to_index,
        sampled_token_ids=[[1000]] * len(req_ids),
        spec_token_ids=None,
        logprobs=None,
        prompt_logprobs_dict={},
    )

    # Ensure ScheduleOutput is correct.
    output = scheduler.schedule()
    _assert_right_scheduler_output(
        output=output,
        num_requests=NUM_REQUESTS,
        # Just the incremental tokens should be scheduled.
        expected_num_scheduled_tokens=NUM_TOKENS - NUM_MATCHED_NEW_TOKENS,
    )

    # Ensure KVCacheManager is correct.
    _assert_right_kv_cache_manager(scheduler, req_ids, NUM_TOKENS, BLOCK_SIZE,
                                   NUM_REQUESTS, NUM_TOTAL_BLOCKS)

    # Continue Generation until done.
    _step_until_done(scheduler, output, MODEL_RUNNER_OUTPUT)
    _ = scheduler.schedule()
    # Confirm we clean up the memory properly.
    assert scheduler.kv_cache_manager.block_pool.get_num_free_blocks() \
        == NUM_TOTAL_BLOCKS

    ######################################################
    # SECOND SET OF REQUESTS - Local And External Hit
    NUM_TOKENS_PREFIX = NUM_TOKENS
    # We will get a local prefix cache hit for the first
    # NUM_TOKENS_PREFIX tokens since they are used above.
    NUM_TOKENS = NUM_TOKENS_PREFIX * 2
    requests = create_requests(num_requests=NUM_REQUESTS,
                               num_tokens=NUM_TOKENS,
                               max_tokens=MAX_TOKENS)
    req_ids = []
    req_to_index = {}
    for i, request in enumerate(requests):
        scheduler.add_request(request)
        req_ids.append(request.request_id)
        req_to_index[request.request_id] = i

    MODEL_RUNNER_OUTPUT = ModelRunnerOutput(
        req_ids=req_ids,
        req_id_to_index=req_to_index,
        sampled_token_ids=[[1000]] * len(req_ids),
        spec_token_ids=None,
        logprobs=None,
        prompt_logprobs_dict={},
    )

    # We should get a local cache hit of NUM_TOKENS_PREFIX and
    # a remote KV cache hit of NUM_MATCHED_NEW_TOKENS.
    output = scheduler.schedule()
    _assert_right_scheduler_output(
        output=output,
        num_requests=NUM_REQUESTS,
        # Just the incremental tokens after local + remote cache hit.
        expected_num_scheduled_tokens=(NUM_TOKENS - NUM_TOKENS_PREFIX -
                                       NUM_MATCHED_NEW_TOKENS))

    # Ensure KVCacheManager is correct.
    _assert_right_kv_cache_manager(scheduler, req_ids, NUM_TOKENS, BLOCK_SIZE,
                                   NUM_REQUESTS, NUM_TOTAL_BLOCKS)

    # Continue Generation until done.
    _step_until_done(scheduler, output, MODEL_RUNNER_OUTPUT)
    _ = scheduler.schedule()
    # Confirm we clean up the memory properly.
    assert scheduler.kv_cache_manager.block_pool.get_num_free_blocks() \
        == NUM_TOTAL_BLOCKS


def test_kv_connector_unable_to_allocate():
    """
    Test whether scheduler with KVConnector is able to handle
    unable to allocate (run out of blocks in allocate_slots().
    """

    # Setup Scheduler With Mock External Cache Hit.
    BLOCK_SIZE = 4
    NUM_BLOCKS = 10
    scheduler = create_scheduler(
        enable_prefix_caching=True,
        use_kv_connector=True,
        block_size=BLOCK_SIZE,
        num_blocks=NUM_BLOCKS,
    )
    NUM_MATCHED_NEW_TOKENS = BLOCK_SIZE * 2
    scheduler.connector.get_num_new_matched_tokens = Mock(name="method")
    scheduler.connector.get_num_new_matched_tokens.return_value = (
        NUM_MATCHED_NEW_TOKENS)

    # Create two requests. The second request will not be able to
    # allocate slots because it will not have enough blocks.
    NUM_REQUESTS = 2
    NUM_TOKENS = (NUM_BLOCKS // 2 + 1) * BLOCK_SIZE
    MAX_TOKENS = 2
    requests = create_requests(num_requests=NUM_REQUESTS,
                               num_tokens=NUM_TOKENS,
                               max_tokens=MAX_TOKENS)
    req_ids = []
    req_to_index = {}
    for i, request in enumerate(requests):
        scheduler.add_request(request)
        req_ids.append(request.request_id)
        req_to_index[request.request_id] = i

    MODEL_RUNNER_OUTPUT = ModelRunnerOutput(
        req_ids=req_ids,
        req_id_to_index=req_to_index,
        sampled_token_ids=[[1000]] * len(req_ids),
        spec_token_ids=None,
        logprobs=None,
        prompt_logprobs_dict={},
    )

    # Just one request should be running.
    output = scheduler.schedule()
    _assert_right_scheduler_output(output,
                                   num_requests=1,
                                   expected_num_scheduled_tokens=NUM_TOKENS -
                                   NUM_MATCHED_NEW_TOKENS)
    assert len(scheduler.running) == 1
    assert len(scheduler.waiting) == 1

    # All memory should be freed, with one request waiting.
    _step_until_done(scheduler, output, MODEL_RUNNER_OUTPUT)
    assert scheduler.kv_cache_manager.block_pool.get_num_free_blocks() \
        == NUM_BLOCKS - 1
    assert len(scheduler.running) == 0
    assert len(scheduler.waiting) == 1

    # Just one request should be running.
    output = scheduler.schedule()
    _assert_right_scheduler_output(output,
                                   num_requests=1,
                                   expected_num_scheduled_tokens=NUM_TOKENS -
                                   NUM_MATCHED_NEW_TOKENS)
    assert len(scheduler.running) == 1
    assert len(scheduler.waiting) == 0

    # All memory should be freed, with no requests waiting / running.
    _step_until_done(scheduler, output, MODEL_RUNNER_OUTPUT)
    assert scheduler.kv_cache_manager.block_pool.get_num_free_blocks() \
        == NUM_BLOCKS - 1
    assert len(scheduler.running) == 0
    assert len(scheduler.waiting) == 0


def test_kv_connector_handles_preemption():
    """
    Test whether scheduler with KVConnector is able to handle
    unable to allocate (run out of blocks in allocate_slots().
    """

    # Setup Scheduler With Mock External Cache Hit.
    BLOCK_SIZE = 2
    # NOTE: there is 1 null block, so this is 6 blocks.
    NUM_BLOCKS = 7
    scheduler = create_scheduler(
        enable_prefix_caching=True,
        use_kv_connector=True,
        block_size=BLOCK_SIZE,
        num_blocks=NUM_BLOCKS,
    )
    scheduler.kv_cache_manager.num_preallocate_blocks = 0

    NUM_MATCHED_NEW_TOKENS = BLOCK_SIZE
    scheduler.connector.get_num_new_matched_tokens = Mock(name="method")
    scheduler.connector.get_num_new_matched_tokens.return_value = (
        NUM_MATCHED_NEW_TOKENS)

    # Create two requests.
    # Both can be scheduled at first, but the second request
    # will be preempted and re-scheduled.
    NUM_REQUESTS = 2
    NUM_TOKENS = BLOCK_SIZE * 2 + 1
    MAX_TOKENS = BLOCK_SIZE * 2
    requests = create_requests(num_requests=NUM_REQUESTS,
                               num_tokens=NUM_TOKENS,
                               max_tokens=MAX_TOKENS)
    req_ids = []
    req_to_index = {}
    for i, request in enumerate(requests):
        scheduler.add_request(request)
        req_ids.append(request.request_id)
        req_to_index[request.request_id] = i

    MODEL_RUNNER_OUTPUT = ModelRunnerOutput(
        req_ids=req_ids,
        req_id_to_index=req_to_index,
        sampled_token_ids=[[1000]] * len(req_ids),
        spec_token_ids=None,
        logprobs=None,
        prompt_logprobs_dict={},
    )

    # All can be scheduled - 1st token.
    output = scheduler.schedule()
    _assert_right_scheduler_output(
        output,
        # 2 remote kv cache hits.
        num_requests=2,
        expected_num_scheduled_tokens=NUM_TOKENS - NUM_MATCHED_NEW_TOKENS)
    assert len(scheduler.running) == 2
    _ = scheduler.update_from_output(output, MODEL_RUNNER_OUTPUT)

    # All can be scheduled - 2nd token.
    output = scheduler.schedule()
    _assert_right_scheduler_output(
        output,
        # no connector_metadata
        num_requests=0,
        expected_num_scheduled_tokens=1)
    assert len(scheduler.running) == 2
    _ = scheduler.update_from_output(output, MODEL_RUNNER_OUTPUT)

    # This will generate a new block and cause a preemption - 3rd token.
    output = scheduler.schedule()
    _assert_right_scheduler_output(
        output,
        # no connector_metadata
        num_requests=0,
        expected_num_scheduled_tokens=1)
    assert len(scheduler.running) == 1
    assert len(scheduler.waiting) == 1
    _ = scheduler.update_from_output(output, MODEL_RUNNER_OUTPUT)
    assert len(scheduler.running) == 1
    assert len(scheduler.waiting) == 1

    # Only 1 can be scheduled - 4th (and last token).
    output = scheduler.schedule()
    _assert_right_scheduler_output(
        output,
        # no connector_metadata
        num_requests=0,
        expected_num_scheduled_tokens=1)
    assert len(scheduler.waiting) == 1
    assert len(scheduler.running) == 1
    _ = scheduler.update_from_output(output, MODEL_RUNNER_OUTPUT)
    assert len(scheduler.running) == 0
    assert len(scheduler.waiting) == 1
    # All memory should be freed since nothing is running.
    assert scheduler.kv_cache_manager.block_pool.get_num_free_blocks() \
        == NUM_BLOCKS - 1

    # Restarts the preempted request - generate 3rd token.
    # This will have a local and remote cache hit.
    output = scheduler.schedule()
    _assert_right_scheduler_output(
        output,
        # 1 remote kv_cache hit!
        num_requests=1,
        # Only 1 block was preempted and there is a single
        # remote hit. So only single new token is scheduled.
        expected_num_scheduled_tokens=1,
    )
    assert len(scheduler.running) == 1
    assert len(scheduler.waiting) == 0
    _ = scheduler.update_from_output(output, MODEL_RUNNER_OUTPUT)
    assert len(scheduler.running) == 1
    assert len(scheduler.waiting) == 0

    # Only 1 can be scheduled - 4th (and last token).
    output = scheduler.schedule()
    _assert_right_scheduler_output(
        output,
        # no connector_metadata
        num_requests=0,
        expected_num_scheduled_tokens=1)
    assert len(scheduler.running) == 1
    _ = scheduler.update_from_output(output, MODEL_RUNNER_OUTPUT)
    assert len(scheduler.running) == 0
    # All memory should be freed since nothing is running.
    assert scheduler.kv_cache_manager.block_pool.get_num_free_blocks() \
        == NUM_BLOCKS - 1
@@ -10,6 +10,9 @@ import vllm.envs as envs
 from vllm.attention import AttentionType
 from vllm.attention.selector import backend_name_to_enum, get_attn_backend
 from vllm.config import CacheConfig, get_current_vllm_config
+from vllm.distributed.kv_transfer import (get_kv_transfer_group,
+                                          has_kv_transfer_group,
+                                          is_v1_kv_transfer_group)
 from vllm.forward_context import ForwardContext, get_forward_context
 from vllm.model_executor.layers.linear import UnquantizedLinearMethod
 from vllm.model_executor.layers.quantization.base_config import (
@@ -329,17 +332,54 @@ class MultiHeadAttention(nn.Module):
         return out.reshape(bsz, q_len, -1)
 
 
+def wait_for_kv_layer_from_connector(layer_name: str):
+    if not has_kv_transfer_group() or not is_v1_kv_transfer_group():
+        return
+
+    connector = get_kv_transfer_group()
+
+    forward_context: ForwardContext = get_forward_context()
+    attn_metadata = forward_context.attn_metadata
+    if attn_metadata is None:
+        return
+
+    connector.wait_for_layer_load(layer_name)
+
+
+def maybe_save_kv_layer_to_connector(
+    layer_name: str,
+    kv_cache_layer: List[torch.Tensor],
+):
+    if not has_kv_transfer_group() or not is_v1_kv_transfer_group():
+        return
+
+    connector = get_kv_transfer_group()
+
+    forward_context: ForwardContext = get_forward_context()
+    attn_metadata = forward_context.attn_metadata
+    if attn_metadata is None:
+        return
+
+    connector.save_kv_layer(layer_name, kv_cache_layer, attn_metadata)
+
+
 def unified_attention(
     query: torch.Tensor,
     key: torch.Tensor,
     value: torch.Tensor,
     layer_name: str,
 ) -> torch.Tensor:
+    wait_for_kv_layer_from_connector(layer_name)
+
     forward_context: ForwardContext = get_forward_context()
     attn_metadata = forward_context.attn_metadata
     self = forward_context.no_compile_layers[layer_name]
     kv_cache = self.kv_cache[forward_context.virtual_engine]
-    return self.impl.forward(self, query, key, value, kv_cache, attn_metadata)
+    output = self.impl.forward(self, query, key, value, kv_cache,
+                               attn_metadata)
+
+    maybe_save_kv_layer_to_connector(layer_name, kv_cache)
+    return output
 
 
 def unified_attention_fake(
@@ -367,6 +407,7 @@ def unified_attention_with_output(
     output: torch.Tensor,
     layer_name: str,
 ) -> None:
+    wait_for_kv_layer_from_connector(layer_name)
     forward_context: ForwardContext = get_forward_context()
     attn_metadata = forward_context.attn_metadata
     self = forward_context.no_compile_layers[layer_name]
@@ -379,6 +420,8 @@ def unified_attention_with_output(
                       attn_metadata,
                       output=output)
 
+    maybe_save_kv_layer_to_connector(layer_name, kv_cache)
+
 
 def unified_attention_with_output_fake(
     query: torch.Tensor,
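For orientation, the sketch below restates the ordering these hooks impose on every attention layer. It is illustrative only, not part of the commit; `attention_kernel` is a hypothetical stand-in for the backend's `self.impl.forward(...)` call, and both helpers fall through to no-ops when no V1 KV transfer group has been initialized.

# Sketch only: per-layer ordering implied by the hooks added above.
def attention_with_connector_hooks(layer_name, query, key, value, kv_cache,
                                   attention_kernel, attn_metadata):
    # 1. Block until the connector has finished loading this layer's KV
    #    (no-op unless a V1 KV transfer group is initialized).
    wait_for_kv_layer_from_connector(layer_name)
    # 2. Run the normal attention computation against the paged KV cache.
    output = attention_kernel(query, key, value, kv_cache, attn_metadata)
    # 3. Kick off (possibly async) saving of this layer's KV to the connector.
    maybe_save_kv_layer_to_connector(layer_name, kv_cache)
    return output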
@@ -0,0 +1,11 @@
# SPDX-License-Identifier: Apache-2.0

from vllm.distributed.kv_transfer.kv_transfer_state import (
    ensure_kv_transfer_initialized, get_kv_transfer_group,
    has_kv_transfer_group, is_v1_kv_transfer_group)

__all__ = [
    "get_kv_transfer_group", "has_kv_transfer_group",
    "is_v1_kv_transfer_group", "ensure_kv_transfer_initialized",
    "KVConnectorBaseType"
]
@@ -12,6 +12,7 @@ from typing import TYPE_CHECKING, List, Tuple, Union
 
 import torch
 
+from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorBase_V1
 from vllm.sequence import IntermediateTensors
 
 if TYPE_CHECKING:
@@ -121,3 +122,6 @@ class KVConnectorBase(ABC):
         """
 
         raise NotImplementedError
+
+
+KVConnectorBaseType = Union[KVConnectorBase, KVConnectorBase_V1]
@@ -3,14 +3,22 @@
 import importlib
 from typing import TYPE_CHECKING, Callable, Dict, Type
 
+import vllm.envs as envs
+from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBaseType
+from vllm.distributed.kv_transfer.kv_connector.v1 import (KVConnectorBase_V1,
+                                                          KVConnectorRole)
+from vllm.logger import init_logger
+
 from .base import KVConnectorBase
 
 if TYPE_CHECKING:
     from vllm.config import VllmConfig
 
+logger = init_logger(__name__)
+
 
 class KVConnectorFactory:
-    _registry: Dict[str, Callable[[], Type[KVConnectorBase]]] = {}
+    _registry: Dict[str, Callable[[], Type[KVConnectorBaseType]]] = {}
 
     @classmethod
     def register_connector(cls, name: str, module_path: str,
@@ -19,22 +27,51 @@ class KVConnectorFactory:
         if name in cls._registry:
             raise ValueError(f"Connector '{name}' is already registered.")
 
-        def loader() -> Type[KVConnectorBase]:
+        def loader() -> Type[KVConnectorBaseType]:
             module = importlib.import_module(module_path)
             return getattr(module, class_name)
 
         cls._registry[name] = loader
 
     @classmethod
-    def create_connector(cls, rank: int, local_rank: int,
-                         config: "VllmConfig") -> KVConnectorBase:
+    def create_connector_v0(cls, rank: int, local_rank: int,
+                            config: "VllmConfig") -> KVConnectorBase:
+        if envs.VLLM_USE_V1:
+            raise ValueError("Attempting to initialize a V0 Connector, "
+                             f"but found {envs.VLLM_USE_V1=}")
+
         connector_name = config.kv_transfer_config.kv_connector
         if connector_name not in cls._registry:
             raise ValueError(f"Unsupported connector type: {connector_name}")
 
         connector_cls = cls._registry[connector_name]()
+        assert issubclass(connector_cls, KVConnectorBase)
         return connector_cls(rank, local_rank, config)
 
+    @classmethod
+    def create_connector_v1(
+        cls,
+        config: "VllmConfig",
+        role: KVConnectorRole,
+    ) -> KVConnectorBase_V1:
+        if not envs.VLLM_USE_V1:
+            raise ValueError("Attempting to initialize a V1 Connector, "
+                             f"but found {envs.VLLM_USE_V1=}")
+
+        connector_name = config.kv_transfer_config.kv_connector
+        connector_cls = cls._registry[connector_name]()
+        assert issubclass(connector_cls, KVConnectorBase_V1)
+        logger.info("Creating v1 connector with name: %s", connector_name)
+        # NOTE(Kuntai): v1 connector is explicitly separated into two roles.
+        # Scheduler connector:
+        # - Co-locate with scheduler process
+        # - Should only be used inside the Scheduler class
+        # Worker connector:
+        # - Co-locate with worker process
+        # - Should only be used inside the forward context & attention layer
+        # We build separately to enforce strict separation
+        return connector_cls(config, role)
+
 
 # Register various connectors here.
 # The registration should not be done in each individual file, as we want to
@@ -57,4 +94,9 @@ KVConnectorFactory.register_connector(
 KVConnectorFactory.register_connector(
     "MooncakeStoreConnector",
     "vllm.distributed.kv_transfer.kv_connector.mooncake_store_connector",
     "MooncakeStoreConnector")
+
+KVConnectorFactory.register_connector(
+    "SharedStorageConnector",
+    "vllm.distributed.kv_transfer.kv_connector.v1.shared_storage_connector",
+    "SharedStorageConnector")
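As a usage note, custom connectors can be registered through the same factory hook shown above. The snippet below is a hypothetical sketch: "MyConnector" and "my_pkg.my_connector" are placeholder names for your own class and module, and the class would need to subclass KVConnectorBase_V1.

# Hypothetical registration of an out-of-tree V1 connector (placeholder names).
# Afterwards, kv_transfer_config with kv_connector="MyConnector" can resolve it
# via KVConnectorFactory.create_connector_v1(config, role).
from vllm.distributed.kv_transfer.kv_connector.factory import KVConnectorFactory

KVConnectorFactory.register_connector(
    "MyConnector",          # name referenced by kv_transfer_config.kv_connector
    "my_pkg.my_connector",  # module path (placeholder)
    "MyConnector")          # class name inside that module (placeholder)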
vllm/distributed/kv_transfer/kv_connector/v1/__init__.py (new file, 8 lines)
@@ -0,0 +1,8 @@
# SPDX-License-Identifier: Apache-2.0
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
    KVConnectorBase_V1, KVConnectorRole)

__all__ = [
    "KVConnectorRole",
    "KVConnectorBase_V1",
]
209
vllm/distributed/kv_transfer/kv_connector/v1/base.py
Normal file
209
vllm/distributed/kv_transfer/kv_connector/v1/base.py
Normal file
@ -0,0 +1,209 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
"""
|
||||||
|
KVConnectorBase_V1 Class for Distributed KV Cache & Hidden State
|
||||||
|
communication in vLLM v1
|
||||||
|
|
||||||
|
The class provides the following primitives:
|
||||||
|
Scheduler-side: runs in the scheduler, binds metadata, which
|
||||||
|
is used by the worker-side to load/save KV cache.
|
||||||
|
get_num_new_matched_tokens() - get number of new tokens
|
||||||
|
that exist in the remote KV cache
|
||||||
|
update_state_after_alloc() - update KVConnector state after
|
||||||
|
temporary buffer alloc by the CacheManager.
|
||||||
|
|
||||||
|
Worker-side: runs in each worker, loads/saves KV cache to/from
|
||||||
|
the Connector based on the metadata.
|
||||||
|
start_load_kv() - starts loading all KVs (maybe async)
|
||||||
|
wait_for_layer_load() - blocks until layer i load is done
|
||||||
|
|
||||||
|
save_kv_layer() - starts saving KV for layer i (maybe async)
|
||||||
|
wait_for_save() - blocks until all saves are done
|
||||||
|
"""
|
||||||
|
|
||||||
|
import enum
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from vllm.logger import init_logger
|
||||||
|
from vllm.v1.core.sched.output import SchedulerOutput
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from vllm.attention.backends.abstract import AttentionMetadata
|
||||||
|
from vllm.config import VllmConfig
|
||||||
|
from vllm.forward_context import ForwardContext
|
||||||
|
from vllm.v1.request import Request
|
||||||
|
|
||||||
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class KVConnectorRole(enum.Enum):
|
||||||
|
# Connector running in the scheduler process
|
||||||
|
SCHEDULER = 0
|
||||||
|
|
||||||
|
# Connector running in the worker process
|
||||||
|
WORKER = 1
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class KVConnectorMetadata:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class KVConnectorBase_V1(ABC):
|
||||||
|
|
||||||
|
def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
|
||||||
|
logger.warning(
|
||||||
|
"Initializing KVConnectorBase_V1. This API is experimental and "
|
||||||
|
"subject to change in the future as we iterate the design.")
|
||||||
|
self._connector_metadata = KVConnectorMetadata()
|
||||||
|
self._vllm_config = vllm_config
|
||||||
|
self._role = role
|
||||||
|
|
||||||
|
@property
|
||||||
|
def role(self) -> KVConnectorRole:
|
||||||
|
return self._role
|
||||||
|
|
||||||
|
def bind_connector_metadata(
|
||||||
|
self, connector_metadata: KVConnectorMetadata) -> None:
|
||||||
|
"""Set the connector metadata from the scheduler.
|
||||||
|
|
||||||
|
This function should be called by the model runner every time
|
||||||
|
before the model execution. The metadata will be used for runtime
|
||||||
|
KV cache loading and saving.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
connector_metadata (dict): the connector metadata.
|
||||||
|
"""
|
||||||
|
self._connector_metadata = connector_metadata
|
||||||
|
|
||||||
|
def clear_connector_metadata(self) -> None:
|
||||||
|
"""Clear the connector metadata.
|
||||||
|
|
||||||
|
This function should be called by the model runner every time
|
||||||
|
after the model execution.
|
||||||
|
"""
|
||||||
|
self._connector_metadata = KVConnectorMetadata()
|
||||||
|
|
||||||
|
def _get_connector_metadata(self) -> KVConnectorMetadata:
|
||||||
|
"""Get the connector metadata.
|
||||||
|
|
||||||
|
This function should only be called inside the connector.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ConnectorMetadata: the connector metadata.
|
||||||
|
"""
|
||||||
|
return self._connector_metadata
|
||||||
|
|
||||||
|
# ==============================
|
||||||
|
# Worker-side methods
|
||||||
|
# ==============================
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def start_load_kv(self, forward_context: "ForwardContext",
|
||||||
|
**kwargs) -> None:
|
||||||
|
"""
|
||||||
|
Start loading the KV cache from the connector to vLLM's paged
|
||||||
|
KV buffer. This is called from the forward context before the
|
||||||
|
forward pass to enable async loading during model execution.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
forward_context (ForwardContext): the forward context.
|
||||||
|
**kwargs: additional arguments for the load operation
|
||||||
|
|
||||||
|
Note:
|
||||||
|
The number of elements in kv_caches and layer_names should be
|
||||||
|
the same.
|
||||||
|
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def wait_for_layer_load(self, layer_name: str) -> None:
|
||||||
|
"""
|
||||||
|
Block until the KV for a specific layer is loaded into vLLM's
|
||||||
|
paged buffer. This is called from within attention layer to ensure
|
||||||
|
async copying from start_load_kv is complete.
|
||||||
|
|
||||||
|
This interface will be useful for layer-by-layer pipelining.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
layer_name: the name of that layer
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def save_kv_layer(self, layer_name: str, kv_layer: torch.Tensor,
|
||||||
|
attn_metadata: "AttentionMetadata", **kwargs) -> None:
|
||||||
|
"""
|
||||||
|
Start saving a layer of KV cache from vLLM's paged buffer
|
||||||
|
to the connector. This is called from within attention layer to
|
||||||
|
enable async copying during execution.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
layer_name (str): the name of the layer.
|
||||||
|
kv_layer (torch.Tensor): the paged KV buffer of the current
|
||||||
|
layer in vLLM.
|
||||||
|
attn_metadata (AttentionMetadata): the attention metadata.
|
||||||
|
**kwargs: additional arguments for the save operation.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def wait_for_save(self):
|
||||||
|
"""
|
||||||
|
Block until all the save operations is done. This is called
|
||||||
|
as the forward context exits to ensure that the async saving
|
||||||
|
from save_kv_layer is complete before finishing the forward.
|
||||||
|
|
||||||
|
This prevents overwrites of paged KV buffer before saving done.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
# ==============================
|
||||||
|
# Scheduler-side methods
|
||||||
|
# ==============================
|
||||||
|
@abstractmethod
|
||||||
|
def get_num_new_matched_tokens(
|
||||||
|
self,
|
||||||
|
request: "Request",
|
||||||
|
num_computed_tokens: int,
|
||||||
|
) -> int:
|
||||||
|
"""
|
||||||
|
Get number of new tokens that can be loaded from the
|
||||||
|
external KV cache beyond the num_computed_tokens.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
request (Request): the request object.
|
||||||
|
num_computed_tokens (int): the number of locally
|
||||||
|
computed tokens for this request
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
the number of tokens that can be loaded from the
|
||||||
|
external KV cache beyond what is already computed.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def update_state_after_alloc(self, request: "Request",
|
||||||
|
num_external_tokens: int):
|
||||||
|
"""
|
||||||
|
Update KVConnector state after block allocation.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def build_connector_meta(
|
||||||
|
self, scheduler_output: SchedulerOutput) -> KVConnectorMetadata:
|
||||||
|
"""
|
||||||
|
Build the connector metadata for this step.
|
||||||
|
|
||||||
|
This function should NOT modify fields in the scheduler_output.
|
||||||
|
Also, calling this function will reset the state of the connector.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
scheduler_output (SchedulerOutput): the scheduler output object.
|
||||||
|
"""
|
||||||
|
pass
|
@ -0,0 +1,382 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
import hashlib
|
||||||
|
import os
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
import safetensors
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from vllm.config import VllmConfig
|
||||||
|
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
|
||||||
|
KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole)
|
||||||
|
from vllm.logger import init_logger
|
||||||
|
from vllm.v1.attention.backends.mla.common import MLACommonMetadata
|
||||||
|
from vllm.v1.core.sched.output import SchedulerOutput
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from vllm.attention.backends.abstract import AttentionMetadata
|
||||||
|
from vllm.forward_context import ForwardContext
|
||||||
|
from vllm.v1.request import Request
|
||||||
|
|
||||||
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ReqMeta:
|
||||||
|
# Request tokens
|
||||||
|
token_ids: torch.Tensor
|
||||||
|
# Slot mappings, should have the same length as token_ids
|
||||||
|
slot_mapping: torch.Tensor
|
||||||
|
# Is store or load
|
||||||
|
is_store: bool
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def make_meta(token_ids: list[int], block_ids: list[int], block_size: int,
|
||||||
|
is_store: bool) -> "ReqMeta":
|
||||||
|
valid_num_tokens = align_to_block_size(len(token_ids), block_size)
|
||||||
|
token_ids_tensor = torch.tensor(token_ids)[:valid_num_tokens]
|
||||||
|
block_ids_tensor = torch.tensor(block_ids)
|
||||||
|
num_blocks = block_ids_tensor.shape[0]
|
||||||
|
block_offsets = torch.arange(0, block_size)
|
||||||
|
slot_mapping = block_offsets.reshape((1, block_size)) + \
|
||||||
|
block_ids_tensor.reshape((num_blocks, 1)) * block_size
|
||||||
|
slot_mapping = slot_mapping.flatten()[:valid_num_tokens]
|
||||||
|
return ReqMeta(
|
||||||
|
token_ids=token_ids_tensor,
|
||||||
|
slot_mapping=slot_mapping,
|
||||||
|
is_store=is_store,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class SharedStorageConnectorMetadata(KVConnectorMetadata):
|
||||||
|
requests: list[ReqMeta]
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.requests = []
|
||||||
|
|
||||||
|
def add_request(
|
||||||
|
self,
|
||||||
|
token_ids: list[int],
|
||||||
|
block_ids: list[int],
|
||||||
|
block_size: int,
|
||||||
|
is_store: bool,
|
||||||
|
) -> None:
|
||||||
|
self.requests.append(
|
||||||
|
ReqMeta.make_meta(token_ids, block_ids, block_size, is_store))
|
||||||
|
|
||||||
|
|
||||||
|
class SharedStorageConnector(KVConnectorBase_V1):
|
||||||
|
# NOTE: This is Simple debug implementation of the KV connector.
|
||||||
|
# It save / load the KV cache to / from the disk.
|
||||||
|
# It does extra work which will overwrite the existing prefix-cache in GPU
|
||||||
|
# - to remove the overhead, need to add some "mask" in the ReqMeta class
|
||||||
|
|
||||||
|
def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
|
||||||
|
super().__init__(vllm_config=vllm_config, role=role)
|
||||||
|
self._block_size = vllm_config.cache_config.block_size
|
||||||
|
self._requests_need_load: dict[str, Request] = {}
|
||||||
|
transfer_config = vllm_config.kv_transfer_config
|
||||||
|
self._storage_path = transfer_config.get_from_extra_config(
|
||||||
|
"shared_storage_path", "/tmp")
|
||||||
|
logger.info(vllm_config.kv_transfer_config)
|
||||||
|
logger.info("Shared storage path is %s", self._storage_path)
|
||||||
|
|
||||||
|
def start_load_kv(self, forward_context: "ForwardContext",
|
||||||
|
**kwargs) -> None:
|
||||||
|
"""Start loading the KV cache from the connector buffer to vLLM's
|
||||||
|
paged KV buffer.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
forward_context (ForwardContext): the forward context.
|
||||||
|
**kwargs: additional arguments for the load operation
|
||||||
|
|
||||||
|
Note:
|
||||||
|
The number of elements in kv_caches and layer_names should be
|
||||||
|
the same.
|
||||||
|
"""
|
||||||
|
attn_metadata = forward_context.attn_metadata
|
||||||
|
|
||||||
|
def inject_kv_into_layer(
|
||||||
|
dst_kv_cache_layer: torch.Tensor,
|
||||||
|
src_kv_cache: torch.Tensor,
|
||||||
|
slot_mapping: torch.Tensor,
|
||||||
|
) -> None:
|
||||||
|
"""Inject the KV cache into the layer.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
dst_kv_cache_layer (torch.Tensor): the destination KV cache
|
||||||
|
layer. In shape [2, num_pages, page_size, xxx] if not
|
||||||
|
using MLA, [num_pages, page_size, xxx] otherwise.
|
||||||
|
src_kv_cache (torch.Tensor): the source KV cache. In shape
|
||||||
|
[2, num_tokens, xxx] if not using MLA, [num_tokens, xxx]
|
||||||
|
otherwise.
|
||||||
|
slot_mapping (torch.Tensor): the slot mapping. In shape
|
||||||
|
[num_tokens].
|
||||||
|
"""
|
||||||
|
dst_kv_cache_layer_shape = dst_kv_cache_layer.shape
|
||||||
|
if isinstance(attn_metadata, MLACommonMetadata):
|
||||||
|
num_pages = dst_kv_cache_layer_shape[0]
|
||||||
|
page_size = dst_kv_cache_layer_shape[1]
|
||||||
|
dst_kv_cache_layer = dst_kv_cache_layer.reshape(
|
||||||
|
num_pages * page_size, -1)
|
||||||
|
dst_kv_cache_layer[slot_mapping, ...] = src_kv_cache
|
||||||
|
dst_kv_cache_layer.reshape(dst_kv_cache_layer_shape)
|
||||||
|
else:
|
||||||
|
num_pages = dst_kv_cache_layer_shape[1]
|
||||||
|
page_size = dst_kv_cache_layer_shape[2]
|
||||||
|
dst_kv_cache_layer = dst_kv_cache_layer.reshape(
|
||||||
|
2, num_pages * page_size, -1)
|
||||||
|
dst_kv_cache_layer[:, slot_mapping, ...] = src_kv_cache
|
||||||
|
dst_kv_cache_layer.reshape(dst_kv_cache_layer_shape)
|
||||||
|
|
||||||
|
# Get the metadata
|
||||||
|
metadata: KVConnectorMetadata = \
|
||||||
|
self._get_connector_metadata()
|
||||||
|
assert isinstance(metadata, SharedStorageConnectorMetadata)
|
||||||
|
|
||||||
|
if metadata is None:
|
||||||
|
logger.warning(
|
||||||
|
"In connector.start_load_kv, but the connector metadata is None"
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
attn_metadata = forward_context.attn_metadata
|
||||||
|
if attn_metadata is None:
|
||||||
|
logger.warning(
|
||||||
|
"In connector.start_load_kv, but the attn_metadata is None")
|
||||||
|
return
|
||||||
|
|
||||||
|
# Load the KV for each request each layer
|
||||||
|
for request in metadata.requests:
|
||||||
|
if request.is_store:
|
||||||
|
continue
|
||||||
|
logger.info("Inject KV cache of %d tokens to the paged memory",
|
||||||
|
len(request.slot_mapping))
|
||||||
|
for layer_name in forward_context.no_compile_layers:
|
||||||
|
attn_layer = forward_context.no_compile_layers[layer_name]
|
||||||
|
kv_cache_layer = attn_layer.kv_cache[\
|
||||||
|
forward_context.virtual_engine]
|
||||||
|
|
||||||
|
filename = self._generate_filename_debug(
|
||||||
|
layer_name, request.token_ids)
|
||||||
|
kv_cache = safetensors.torch.load_file(
|
||||||
|
filename)["kv_cache"].cuda()
|
||||||
|
inject_kv_into_layer(kv_cache_layer, kv_cache,
|
||||||
|
request.slot_mapping)
|
||||||
|
|
||||||
|
def wait_for_layer_load(self, layer_name: str) -> None:
|
||||||
|
"""Blocking until the KV for a specific layer is loaded into vLLM's
|
||||||
|
paged buffer.
|
||||||
|
|
||||||
|
This interface will be useful for layer-by-layer pipelining.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
layer_name: the name of that layer
|
||||||
|
"""
|
||||||
|
return
|
||||||
|
|
||||||
|
def save_kv_layer(self, layer_name: str, kv_layer: torch.Tensor,
|
||||||
|
attn_metadata: "AttentionMetadata", **kwargs) -> None:
|
||||||
|
"""Start saving the KV cache of the layer from vLLM's paged buffer
|
||||||
|
to the connector.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
layer_name (str): the name of the layer.
|
||||||
|
kv_layer (torch.Tensor): the paged KV buffer of the current
|
||||||
|
layer in vLLM.
|
||||||
|
attn_metadata (AttentionMetadata): the attention metadata.
|
||||||
|
**kwargs: additional arguments for the save operation.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def extract_kv_from_layer(
|
||||||
|
layer: torch.Tensor,
|
||||||
|
slot_mapping: torch.Tensor,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""Extract the KV cache from the layer.
|
||||||
|
|
||||||
|
Assume the shape of the layer is (2, num_pages, page_size, xxx)
|
||||||
|
if MLA is not used, and (num_pages, page_size, xxx) otherwise.
|
||||||
|
"""
|
||||||
|
if isinstance(attn_metadata, MLACommonMetadata):
|
||||||
|
num_pages, page_size = layer.shape[0], layer.shape[1]
|
||||||
|
return layer.reshape(num_pages * page_size, -1)[slot_mapping,
|
||||||
|
...]
|
||||||
|
num_pages, page_size = layer.shape[1], layer.shape[2]
|
||||||
|
return layer.reshape(2, num_pages * page_size, -1)[:, slot_mapping,
|
||||||
|
...]
|
||||||
|
|
||||||
|
connector_metadata = self._get_connector_metadata()
|
||||||
|
assert isinstance(connector_metadata, SharedStorageConnectorMetadata)
|
||||||
|
for request in connector_metadata.requests:
|
||||||
|
if request.is_store:
|
||||||
|
filename = self._generate_filename_debug(
|
||||||
|
layer_name, request.token_ids)
|
||||||
|
kv_cache = extract_kv_from_layer(kv_layer,
|
||||||
|
request.slot_mapping)
|
||||||
|
tensors = {"kv_cache": kv_cache.detach().cpu()}
|
||||||
|
safetensors.torch.save_file(tensors, filename)
|
||||||
|
|
||||||
|
def wait_for_save(self):
|
||||||
|
return
|
||||||
|
|
||||||
|
def get_num_new_matched_tokens(
|
||||||
|
self,
|
||||||
|
request: "Request",
|
||||||
|
num_computed_tokens: int,
|
||||||
|
) -> int:
|
||||||
|
"""
|
||||||
|
Get number of new tokens that can be loaded from the
|
||||||
|
external KV cache beyond the num_computed_tokens.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
request (Request): the request object.
|
||||||
|
num_computed_tokens (int): the number of locally
|
||||||
|
computed tokens for this request
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
the number of tokens that can be loaded from the
|
||||||
|
external KV cache beyond what is already computed.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# NOTE: in this debug implementation, we assume that the prompt is
|
||||||
|
# cached_prompt + newly_generated_single_token
|
||||||
|
        # Therefore, we use prompt_token_ids[:-1] to determine the folder name

        # NOTE: in current v1 scheduler, the num_computed_tokens is aligned
        # with the block granularity. And it expects the returned blocks and
        # num_computed_tokens to also be aligned with the block granularity.
        if not self._found_match_for_request(request):
            return 0

        logger.info("External Cache Hit!")

        # Now that the first num_tokens_to_check tokens are a hit, we need to
        # prepare the metadata for the worker connector to correctly load the KV.
        num_tokens_to_check = align_to_block_size(
            len(request.prompt_token_ids) - 1, self._block_size)

        return num_tokens_to_check - num_computed_tokens

    def update_state_after_alloc(self, request: "Request",
                                 num_external_tokens: int):
        """
        Update KVConnector state after block allocation.

        If blocks were allocated, add to _requests_need_load,
        such that we load the KVs in the next forward pass.
        """
        if num_external_tokens > 0:
            self._requests_need_load[request.request_id] = request

    def build_connector_meta(
        self,
        scheduler_output: SchedulerOutput,
    ) -> KVConnectorMetadata:
        """Build the connector metadata for this step.

        This function should NOT modify any fields in the scheduler_output.
        Also, calling this function will reset the state of the connector.

        Args:
            scheduler_output (SchedulerOutput): the scheduler output object.
        """
        meta = SharedStorageConnectorMetadata()

        total_need_load = 0
        for new_req in scheduler_output.scheduled_new_reqs:
            if new_req.req_id in self._requests_need_load:
                meta.add_request(token_ids=new_req.prompt_token_ids,
                                 block_ids=new_req.block_ids,
                                 block_size=self._block_size,
                                 is_store=False)
                total_need_load += 1
            else:
                # NOTE: here we treat store and load as exclusive,
                # but a single request can have both store and load.
                # NOTE(rob): for this debug implementation, we only cache
                # the original prompt tokens.
                if not self._found_match_for_request(new_req):
                    meta.add_request(token_ids=new_req.prompt_token_ids,
                                     block_ids=new_req.block_ids,
                                     block_size=self._block_size,
                                     is_store=True)

        for cached_req in scheduler_output.scheduled_cached_reqs:
            # NOTE(rob): here we rely on the resumed requests being
            # the first N requests in the list scheduled_cached_reqs.
            if not cached_req.resumed_from_preemption:
                break
            if cached_req.req_id in self._requests_need_load:
                # NOTE(rob): cached_req_data does not have the full
                # list of token ids (only new tokens). So we look it
                # up in the actual request object.
                request = self._requests_need_load[cached_req.req_id]
                total_tokens = (len(cached_req.new_token_ids) +
                                cached_req.num_computed_tokens)
                token_ids = request.all_token_ids[:total_tokens]

                # NOTE(rob): For a resumed request, new_block_ids is all
                # of the block_ids for the request.
                block_ids = cached_req.new_block_ids

                meta.add_request(token_ids=token_ids,
                                 block_ids=block_ids,
                                 block_size=self._block_size,
                                 is_store=False)
                total_need_load += 1

        assert total_need_load == len(self._requests_need_load)
        self._requests_need_load.clear()
        return meta

    # ==============================
    # Helper functions
    # ==============================

    def _found_match_for_request(
        self,
        request: "Request",
    ) -> bool:
        """Check if the cache is hit for the request."""
        num_tokens_to_check = align_to_block_size(
            len(request.prompt_token_ids) - 1, self._block_size)
        foldername = self._generate_foldername_debug(torch.tensor(
            request.prompt_token_ids)[:num_tokens_to_check],
                                                     create_folder=False)
        return os.path.exists(foldername)

    def _generate_foldername_debug(
        self,
        input_ids: torch.Tensor,
        create_folder=False,
    ) -> str:
        """Generate a folder name based on the hash of the bytes of the
        input ids."""
        input_ids_bytes = input_ids.numpy().tobytes()
        input_ids_hash = hashlib.md5(input_ids_bytes).hexdigest()
        foldername = os.path.join(self._storage_path, input_ids_hash)
        if create_folder:
            os.makedirs(foldername, exist_ok=True)
        return foldername

    def _generate_filename_debug(
        self,
        layer_name: str,
        input_ids: torch.Tensor,
    ) -> str:
        """Generate a file name based on the layer name and the hash
        of the bytes of the input ids."""
        foldername = self._generate_foldername_debug(input_ids,
                                                     create_folder=True)
        return os.path.join(foldername, f"{layer_name}.safetensors")


def align_to_block_size(num_tokens: int, block_size) -> int:
    """Align the number of tokens to the block size."""
    return (num_tokens - 1) // block_size * block_size
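To make the block-alignment arithmetic above concrete, here is a small, self-contained sketch; the numbers are purely illustrative, and the local helper simply mirrors align_to_block_size defined above.

# Illustrative numbers only: a 48-token prompt, 16-token blocks, and a
# request whose first 16 tokens are already computed locally.
BLOCK_SIZE = 16
PROMPT_LEN = 48
NUM_COMPUTED_TOKENS = 16


def _align_to_block_size(num_tokens: int, block_size: int) -> int:
    # Mirrors align_to_block_size() above: round down to full blocks.
    return (num_tokens - 1) // block_size * block_size


# The connector hashes prompt_token_ids[:-1], hence the "- 1" before aligning.
num_tokens_to_check = _align_to_block_size(PROMPT_LEN - 1, BLOCK_SIZE)
assert num_tokens_to_check == 32

# get_num_new_matched_tokens() reports only the tokens the external cache
# adds beyond what is already computed locally.
num_external_tokens = num_tokens_to_check - NUM_COMPUTED_TOKENS
assert num_external_tokens == 16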
@@ -46,7 +46,7 @@ class KVTransferAgent:
         assert self.config.kv_transfer_config.is_kv_transfer_instance, "KV"\
             "TransferAgent should only be used when kv_connector is set."
 
-        self.connector = KVConnectorFactory.create_connector(
+        self.connector = KVConnectorFactory.create_connector_v0(
             rank, local_rank, config)
 
     def send_kv_caches_and_hidden_states(
vllm/distributed/kv_transfer/kv_transfer_state.py (new file, 70 lines)
@@ -0,0 +1,70 @@
# SPDX-License-Identifier: Apache-2.0
from typing import TYPE_CHECKING, Optional

from vllm import envs
from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBaseType
from vllm.distributed.kv_transfer.kv_connector.factory import (
    KVConnectorFactory)
from vllm.distributed.kv_transfer.kv_connector.v1 import (KVConnectorBase_V1,
                                                          KVConnectorRole)
from vllm.distributed.parallel_state import get_world_group

if TYPE_CHECKING:
    from vllm.config import VllmConfig

_KV_CONNECTOR_AGENT: Optional[KVConnectorBaseType] = None


def get_kv_transfer_group() -> KVConnectorBaseType:
    assert _KV_CONNECTOR_AGENT is not None, (
        "disaggregated KV cache transfer parallel group is not initialized")
    return _KV_CONNECTOR_AGENT


def has_kv_transfer_group() -> bool:
    return _KV_CONNECTOR_AGENT is not None


def is_v1_kv_transfer_group(
        connector: Optional[KVConnectorBaseType] = None) -> bool:
    """Check if the KV connector is the v1 connector.

    If the argument is None, it will check the global KV connector.

    Args:
        connector: The KV connector to check. If None, it will check the
            global KV connector.

    Note:
        This function will no longer be needed after the v1 KV connector
        becomes the default.
    """
    if connector is None:
        connector = _KV_CONNECTOR_AGENT

    if connector is None:
        return False

    return isinstance(connector, KVConnectorBase_V1)


def ensure_kv_transfer_initialized(vllm_config: "VllmConfig") -> None:
    """
    Initialize KV cache transfer parallel group.
    """

    global _KV_CONNECTOR_AGENT

    if vllm_config.kv_transfer_config is None:
        return

    if (vllm_config.kv_transfer_config.is_kv_transfer_instance
            and _KV_CONNECTOR_AGENT is None):
        if envs.VLLM_USE_V1:
            _KV_CONNECTOR_AGENT = KVConnectorFactory.create_connector_v1(
                config=vllm_config, role=KVConnectorRole.WORKER)
        else:
            _KV_CONNECTOR_AGENT = KVConnectorFactory.create_connector_v0(
                rank=get_world_group().rank,
                local_rank=get_world_group().local_rank,
                config=vllm_config,
            )
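A rough usage sketch of the helpers this new module exports (assuming a vLLM build that includes it); the wrapper function below is illustrative, not part of the API.

# Hedged sketch; maybe_get_v1_connector() is illustrative, not a vLLM API.
from typing import Optional

from vllm.distributed.kv_transfer import (get_kv_transfer_group,
                                          has_kv_transfer_group,
                                          is_v1_kv_transfer_group)
from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorBase_V1


def maybe_get_v1_connector() -> Optional[KVConnectorBase_V1]:
    # No kv_transfer_config, or the global agent is not initialized yet.
    if not has_kv_transfer_group():
        return None
    # v0 connectors go through the legacy KVTransferAgent path instead.
    if not is_v1_kv_transfer_group():
        return None
    return get_kv_transfer_group()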
@@ -29,15 +29,13 @@ from collections import namedtuple
 from contextlib import contextmanager, nullcontext
 from dataclasses import dataclass
 from multiprocessing import shared_memory
-from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple,
-                    Union)
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 from unittest.mock import patch
 
 import torch
 import torch.distributed
 from torch.distributed import Backend, ProcessGroup
 
-import vllm.distributed.kv_transfer.kv_transfer_agent as kv_transfer
 import vllm.envs as envs
 from vllm.distributed.device_communicators.base_device_communicator import (
     DeviceCommunicatorBase)
@@ -46,9 +44,6 @@ from vllm.logger import init_logger
 from vllm.utils import (direct_register_custom_op, resolve_obj_by_qualname,
                         supports_custom_op)
 
-if TYPE_CHECKING:
-    from vllm.config import VllmConfig
-
 
 @dataclass
 class GraphCaptureContext:
@@ -772,14 +767,6 @@ def get_pp_group() -> GroupCoordinator:
 # kept for backward compatibility
 get_pipeline_model_parallel_group = get_pp_group
 
-_KV_TRANSFER: Optional[kv_transfer.KVTransferAgent] = None
-
-
-def get_kv_transfer_group() -> kv_transfer.KVTransferAgent:
-    assert _KV_TRANSFER is not None, (
-        "disaggregated KV cache transfer parallel group is not initialized")
-    return _KV_TRANSFER
-
 
 @contextmanager
 def graph_capture(device: torch.device):
@@ -962,26 +949,6 @@ def initialize_model_parallel(
         _DP.rank_in_group, _PP.rank_in_group, _TP.rank_in_group)
 
 
-def ensure_kv_transfer_initialized(vllm_config: "VllmConfig") -> None:
-    """
-    Initialize KV cache transfer parallel group.
-    """
-
-    global _KV_TRANSFER
-
-    if vllm_config.kv_transfer_config is None:
-        return
-
-    if all([
-            vllm_config.kv_transfer_config.is_kv_transfer_instance,
-            _KV_TRANSFER is None
-    ]):
-        _KV_TRANSFER = kv_transfer.KVTransferAgent(
-            rank=get_world_group().rank,
-            local_rank=get_world_group().local_rank,
-            config=vllm_config)
-
-
 def ensure_model_parallel_initialized(
     tensor_model_parallel_size: int,
     pipeline_model_parallel_size: int,
@@ -1527,12 +1527,6 @@ class EngineArgs:
                                recommend_to_remove=False)
             return False
 
-        # No Disaggregated Prefill so far.
-        if self.kv_transfer_config != EngineArgs.kv_transfer_config:
-            _raise_or_fallback(feature_name="--kv-transfer-config",
-                               recommend_to_remove=False)
-            return False
-
         # No FlashInfer or XFormers so far.
         V1_BACKENDS = [
             "FLASH_ATTN_VLLM_V1", "FLASH_ATTN", "PALLAS", "PALLAS_VLLM_V1",
@@ -11,6 +11,10 @@ import torch.distributed as dist
 
 import vllm.envs as envs
 from vllm.config import VllmConfig
+from vllm.distributed.kv_transfer import (get_kv_transfer_group,
+                                          has_kv_transfer_group,
+                                          is_v1_kv_transfer_group)
+from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorBase_V1
 from vllm.logger import init_logger
 
 if TYPE_CHECKING:
@@ -98,6 +102,17 @@ def set_forward_context(attn_metadata: Any,
         virtual_engine=virtual_engine,
         attn_metadata=attn_metadata,
         dp_metadata=dp_metadata)
+
+    # KVConnector: trigger (possibly async) load before forward.
+    # Each attn layer will block until the reading is complete.
+    trigger_kv_transfer = (attn_metadata is not None
+                           and has_kv_transfer_group()
+                           and is_v1_kv_transfer_group())
+    if trigger_kv_transfer:
+        kv_connector = get_kv_transfer_group()
+        assert isinstance(kv_connector, KVConnectorBase_V1)
+        kv_connector.start_load_kv(_forward_context)
+
     try:
         yield
     finally:
@@ -133,4 +148,12 @@ def set_forward_context(attn_metadata: Any,
                 logger.info(("Batchsize forward time stats "
                              "(batchsize, count, median_time(ms)): %s"),
                             forward_stats)
+
+        # KVConnector: each attn layer triggers (possibly async) save.
+        # Ensure all those operations complete before forward() is done.
+        if trigger_kv_transfer:
+            kv_connector = get_kv_transfer_group()
+            assert isinstance(kv_connector, KVConnectorBase_V1)
+            kv_connector.wait_for_save()
+
         _forward_context = prev_context
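Taken together, the two hooks above bracket every forward pass: loads are kicked off before the model runs and saves are awaited afterwards. A condensed, hedged sketch of that ordering (the wrapper name is illustrative; this is not the actual context manager):

# Hedged sketch of the ordering enforced above; not the real context manager.
from typing import Any, Callable

from vllm.distributed.kv_transfer import (get_kv_transfer_group,
                                          has_kv_transfer_group,
                                          is_v1_kv_transfer_group)


def run_forward_with_kv_connector(forward: Callable[[], Any],
                                  attn_metadata: Any,
                                  forward_context: Any) -> Any:
    trigger = (attn_metadata is not None and has_kv_transfer_group()
               and is_v1_kv_transfer_group())
    if trigger:
        # Kick off (possibly async) KV loads; attention layers block on them.
        get_kv_transfer_group().start_load_kv(forward_context)
    try:
        return forward()
    finally:
        if trigger:
            # Saves are issued per attention layer; wait for all of them here.
            get_kv_transfer_group().wait_for_save()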
@@ -171,8 +171,9 @@ class KVCacheManager:
 
         Args:
             request: The request to allocate slots.
-            num_tokens: The number of tokens to allocate. Note that this does
-                not include the tokens that have already been computed.
+            num_tokens: The number of tokens to allocate, including external
+                tokens. Note that this does not include tokens that have
+                already been computed locally (i.e. new_computed_blocks).
             new_computed_blocks: A list of new computed blocks just hitting the
                 prefix caching.
             num_lookahead_tokens: The number of speculative tokens to allocate.
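A small illustrative accounting of the revised contract (the numbers are made up): locally computed tokens are excluded from num_tokens, while externally cached tokens are included.

# Illustrative accounting for allocate_slots() after this change.
block_size = 16
prompt_len = 48
num_local_computed = 16    # covered by new_computed_blocks: NOT counted
num_external_tokens = 16   # loaded via the KVConnector: counted
num_new_tokens = prompt_len - num_local_computed - num_external_tokens

num_tokens = num_new_tokens + num_external_tokens  # what the caller passes
assert num_tokens == 32
assert num_tokens % block_size == 0  # two fresh blocks in this example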
@@ -9,6 +9,8 @@ if TYPE_CHECKING:
     import numpy as np
     import numpy.typing as npt
 
+    from vllm.distributed.kv_transfer.kv_connector.v1.base import (
+        KVConnectorMetadata)
     from vllm.lora.request import LoRARequest
     from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
     from vllm.sampling_params import SamplingParams
@@ -121,3 +123,6 @@ class SchedulerOutput:
     structured_output_request_ids: dict[str, int]
     # the bitmask for the whole batch
     grammar_bitmask: Optional[npt.NDArray[np.int32]]
+
+    # KV Cache Connector metadata.
+    kv_connector_metadata: Optional[KVConnectorMetadata] = None
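Because the metadata is an ordinary field on SchedulerOutput, it is shipped to the workers together with the rest of the step. A hedged sketch of attaching it, using the shared-storage metadata class shown earlier (the import path below is an assumption based on the new connector file added in this commit):

# Hedged sketch; the import path for SharedStorageConnectorMetadata is an
# assumption, and the helper function is illustrative only.
from vllm.distributed.kv_transfer.kv_connector.v1.shared_storage_connector import (  # noqa: E501
    SharedStorageConnectorMetadata)


def attach_example_metadata(scheduler_output, prompt_token_ids, block_ids):
    meta = SharedStorageConnectorMetadata()
    # Ask the worker-side connector to load KV for these tokens/blocks.
    meta.add_request(token_ids=prompt_token_ids,
                     block_ids=block_ids,
                     block_size=16,
                     is_store=False)
    scheduler_output.kv_connector_metadata = meta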
@@ -7,8 +7,10 @@ from collections import deque
 from collections.abc import Iterable
 from typing import Optional, Union
 
-from vllm.config import (CacheConfig, LoRAConfig, ModelConfig, SchedulerConfig,
-                         SpeculativeConfig)
+from vllm.config import VllmConfig
+from vllm.distributed.kv_transfer.kv_connector.factory import (
+    KVConnectorFactory)
+from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorRole
 from vllm.logger import init_logger
 from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
 from vllm.v1.core.encoder_cache_manager import (EncoderCacheManager,
@@ -34,20 +36,17 @@ class Scheduler(SchedulerInterface):
 
     def __init__(
         self,
-        scheduler_config: SchedulerConfig,
-        model_config: ModelConfig,
-        cache_config: CacheConfig,
-        lora_config: Optional[LoRAConfig],
+        vllm_config: VllmConfig,
         kv_cache_config: KVCacheConfig,
         structured_output_manager: StructuredOutputManager,
-        speculative_config: SpeculativeConfig = None,
         mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
         include_finished_set: bool = False,
         log_stats: bool = False,
     ) -> None:
-        self.scheduler_config = scheduler_config
-        self.cache_config = cache_config
-        self.lora_config = lora_config
+        self.vllm_config = vllm_config
+        self.scheduler_config = vllm_config.scheduler_config
+        self.cache_config = vllm_config.cache_config
+        self.lora_config = vllm_config.lora_config
         self.kv_cache_config = kv_cache_config
         self.log_stats = log_stats
         self.structured_output_manager = structured_output_manager
@@ -64,11 +63,22 @@ class Scheduler(SchedulerInterface):
             self.scheduler_config.max_num_batched_tokens
         self.max_model_len = self.scheduler_config.max_model_len
 
+        # Create KVConnector for the Scheduler. Note that each Worker
+        # will have a corresponding KVConnector with Role=WORKER.
+        # The KV Connector pushes/pulls remote KVs for P/D and offloading.
+        self.connector = None
+        if self.vllm_config.kv_transfer_config is not None:
+            self.connector = KVConnectorFactory.create_connector_v1(
+                config=self.vllm_config, role=KVConnectorRole.SCHEDULER)
+
+        num_gpu_blocks = self.cache_config.num_gpu_blocks
+        assert num_gpu_blocks is not None and num_gpu_blocks > 0
+
         # Create the KV cache manager.
         self.kv_cache_manager = KVCacheManager(
             kv_cache_config=kv_cache_config,
             max_model_len=self.max_model_len,
-            enable_caching=cache_config.enable_prefix_caching,
+            enable_caching=self.cache_config.enable_prefix_caching,
             caching_hash_algo=self.cache_config.prefix_caching_hash_algo,
             log_stats=self.log_stats)
         self.block_size = self.cache_config.block_size
|
|||||||
# This can be changed when we make encoder cache for embedding caching
|
# This can be changed when we make encoder cache for embedding caching
|
||||||
# across requests.
|
# across requests.
|
||||||
encoder_compute_budget, encoder_cache_size = compute_encoder_budget(
|
encoder_compute_budget, encoder_cache_size = compute_encoder_budget(
|
||||||
model_config=model_config,
|
model_config=vllm_config.model_config,
|
||||||
scheduler_config=scheduler_config,
|
scheduler_config=vllm_config.scheduler_config,
|
||||||
mm_registry=mm_registry,
|
mm_registry=mm_registry,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -115,6 +125,7 @@ class Scheduler(SchedulerInterface):
             cache_size=encoder_cache_size)
 
         self.num_lookahead_tokens = 0
+        speculative_config = vllm_config.speculative_config
         if speculative_config and speculative_config.method == "eagle":
             self.num_lookahead_tokens = \
                 speculative_config.num_speculative_tokens
@@ -304,6 +315,16 @@ class Scheduler(SchedulerInterface):
                 # Get already-cached tokens.
                 computed_blocks, num_computed_tokens = \
                     self.kv_cache_manager.get_computed_blocks(request)
+
+                # Get externally-cached tokens if using a KVConnector.
+                num_external_tokens = (
+                    0 if self.connector is None else
+                    self.connector.get_num_new_matched_tokens(
+                        request, num_computed_tokens))
+
+                # Total computed tokens (local + external).
+                num_computed_tokens += num_external_tokens
+
                 # Number of tokens to be scheduled.
                 # We use `request.num_tokens` instead of
                 # `request.num_prompt_tokens` to consider the resumed requests,
@@ -330,11 +351,21 @@ class Scheduler(SchedulerInterface):
                     new_encoder_budget = encoder_budget
 
                 new_blocks = self.kv_cache_manager.allocate_slots(
-                    request, num_new_tokens, computed_blocks)
+                    request, num_new_tokens + num_external_tokens,
+                    computed_blocks)
                 if new_blocks is None:
                     # The request cannot be scheduled.
                     break
+
+                # KVConnector: update internal state after allocation.
+                # This information is used to determine if a load is
+                # needed for this request.
+                if self.connector is not None:
+                    self.connector.update_state_after_alloc(
+                        request,
+                        num_external_tokens,
+                    )
+
                 self.waiting.popleft()
                 if request.use_structured_output:
                     structured_output_request_ids[
@@ -443,6 +474,14 @@ class Scheduler(SchedulerInterface):
             grammar_bitmask=grammar_bitmask,
         )
 
+        # NOTE(Kuntai): this function is designed for multiple purposes:
+        # 1. Plan the KV cache store
+        # 2. Wrap up all the KV cache load / save ops into an opaque object
+        # 3. Clear the internal states of the connector
+        if self.connector is not None:
+            meta = self.connector.build_connector_meta(scheduler_output)
+            scheduler_output.kv_connector_metadata = meta
+
         # Advance the number of computed tokens for the request AFTER
         # the request is scheduled.
         # 1. The scheduler_output of the current step has to include the
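Putting the scheduler-side pieces together, the per-step call order on the connector looks roughly like this (a hedged sketch; the driver function is illustrative and the failed-allocation path is omitted):

# Hedged sketch of the per-step, scheduler-side connector calls.
def plan_kv_transfer_for_request(scheduler, request, num_computed_tokens,
                                 num_new_tokens, computed_blocks,
                                 scheduler_output):
    connector = scheduler.connector
    if connector is None:
        return

    # 1. How many extra tokens can the external cache supply?
    num_external_tokens = connector.get_num_new_matched_tokens(
        request, num_computed_tokens)

    # 2. Allocate blocks for new + external tokens (local blocks excluded).
    scheduler.kv_cache_manager.allocate_slots(
        request, num_new_tokens + num_external_tokens, computed_blocks)

    # 3. Remember whether this request needs a KV load next forward pass.
    connector.update_state_after_alloc(request, num_external_tokens)

    # 4. Package every planned load/store for the workers; this also resets
    #    the connector's internal state.
    scheduler_output.kv_connector_metadata = (
        connector.build_connector_meta(scheduler_output))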
@@ -508,6 +547,9 @@ class Scheduler(SchedulerInterface):
         If an encoder input cannot be scheduled due to cache or budget
         limitations, the method adjusts `num_new_tokens` to schedule only the
         decoder tokens up to just before the unschedulable encoder input.
+
+        Note that num_computed_tokens includes both locally cached
+        blocks and externally cached blocks (via KVConnector).
         """
         encoder_inputs_to_schedule: list[int] = []
         mm_positions = request.mm_positions
@@ -92,12 +92,8 @@ class EngineCore:
             vllm_config.scheduler_config.scheduler_cls)
 
         self.scheduler: SchedulerInterface = Scheduler(
-            scheduler_config=vllm_config.scheduler_config,
-            model_config=vllm_config.model_config,
-            cache_config=vllm_config.cache_config,
-            lora_config=vllm_config.lora_config,
+            vllm_config=vllm_config,
             kv_cache_config=kv_cache_config,
-            speculative_config=vllm_config.speculative_config,
             structured_output_manager=self.structured_output_manager,
             include_finished_set=vllm_config.parallel_config.data_parallel_size
             > 1,
@@ -13,6 +13,8 @@ import torch.nn as nn
 from vllm.attention import AttentionType, get_attn_backend
 from vllm.attention.layer import Attention
 from vllm.config import CompilationLevel, VllmConfig
+from vllm.distributed.kv_transfer import (get_kv_transfer_group,
+                                          has_kv_transfer_group)
 from vllm.distributed.parallel_state import get_pp_group, graph_capture
 from vllm.forward_context import set_forward_context
 from vllm.logger import init_logger
@@ -987,6 +989,11 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         scheduler_output: "SchedulerOutput",
         intermediate_tensors: Optional[IntermediateTensors] = None,
     ) -> Union[ModelRunnerOutput, torch.Tensor]:
+        # Update KVConnector with this step's KVConnector metadata before
+        # forward().
+        if has_kv_transfer_group():
+            get_kv_transfer_group().bind_connector_metadata(
+                scheduler_output.kv_connector_metadata)
+
         self._update_states(scheduler_output)
         if not scheduler_output.total_num_scheduled_tokens:
             # Return empty ModelRunnerOutput if there's no work to do.
@@ -1228,6 +1235,10 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         # in the next step.
         del draft_probs
 
+        # Clear KVConnector state after all KVs are generated.
+        if has_kv_transfer_group():
+            get_kv_transfer_group().clear_connector_metadata()
+
         return ModelRunnerOutput(
             req_ids=self.input_batch.req_ids,
             req_id_to_index=self.input_batch.req_id_to_index,
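On the worker side, the per-step lifecycle around execute_model() then reads roughly as follows (a hedged sketch; the wrapper is illustrative, and the load/save hooks themselves fire inside set_forward_context as shown earlier):

# Hedged sketch of the worker-side step; run_model_step() is illustrative.
from vllm.distributed.kv_transfer import (get_kv_transfer_group,
                                          has_kv_transfer_group)


def run_model_step(scheduler_output, forward):
    if has_kv_transfer_group():
        # Expose this step's connector plan before the model runs.
        get_kv_transfer_group().bind_connector_metadata(
            scheduler_output.kv_connector_metadata)
    try:
        # start_load_kv()/wait_for_save() fire inside set_forward_context().
        return forward()
    finally:
        if has_kv_transfer_group():
            get_kv_transfer_group().clear_connector_metadata()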
@@ -9,11 +9,12 @@ import torch.distributed
 import torch.nn as nn
 
 import vllm.envs as envs
-from vllm.config import ParallelConfig, VllmConfig
+from vllm.config import VllmConfig
 from vllm.device_allocator.cumem import CuMemAllocator
 from vllm.distributed import (ensure_model_parallel_initialized,
                               init_distributed_environment,
                               set_custom_all_reduce)
+from vllm.distributed.kv_transfer import ensure_kv_transfer_initialized
 from vllm.distributed.parallel_state import get_pp_group
 from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
@@ -110,7 +111,7 @@ class Worker(WorkerBase):
             raise RuntimeError(
                 f"Not support device type: {self.device_config.device}")
         # Initialize the distributed environment.
-        init_worker_distributed_environment(self.parallel_config, self.rank,
+        init_worker_distributed_environment(self.vllm_config, self.rank,
                                             self.distributed_init_method,
                                             self.local_rank)
         # Set random seed.
|
|||||||
|
|
||||||
|
|
||||||
def init_worker_distributed_environment(
|
def init_worker_distributed_environment(
|
||||||
parallel_config: ParallelConfig,
|
vllm_config: VllmConfig,
|
||||||
rank: int,
|
rank: int,
|
||||||
distributed_init_method: Optional[str] = None,
|
distributed_init_method: Optional[str] = None,
|
||||||
local_rank: int = -1,
|
local_rank: int = -1,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Initialize the distributed environment."""
|
"""Initialize the distributed environment."""
|
||||||
|
parallel_config = vllm_config.parallel_config
|
||||||
set_custom_all_reduce(not parallel_config.disable_custom_all_reduce)
|
set_custom_all_reduce(not parallel_config.disable_custom_all_reduce)
|
||||||
|
|
||||||
init_distributed_environment(parallel_config.world_size, rank,
|
init_distributed_environment(parallel_config.world_size, rank,
|
||||||
@ -299,6 +301,8 @@ def init_worker_distributed_environment(
|
|||||||
ensure_model_parallel_initialized(parallel_config.tensor_parallel_size,
|
ensure_model_parallel_initialized(parallel_config.tensor_parallel_size,
|
||||||
parallel_config.pipeline_parallel_size)
|
parallel_config.pipeline_parallel_size)
|
||||||
|
|
||||||
|
ensure_kv_transfer_initialized(vllm_config)
|
||||||
|
|
||||||
|
|
||||||
def _check_if_gpu_supports_dtype(torch_dtype: torch.dtype):
|
def _check_if_gpu_supports_dtype(torch_dtype: torch.dtype):
|
||||||
# Check if the GPU supports the dtype.
|
# Check if the GPU supports the dtype.
|
||||||
|
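A minimal sketch of what this wiring gives a worker process (assuming a vLLM build with the new kv_transfer module): once ensure_kv_transfer_initialized() has run, the global helpers report an active connector.

# Hedged sketch; mirrors the call added to init_worker_distributed_environment.
from vllm.distributed.kv_transfer import (ensure_kv_transfer_initialized,
                                          has_kv_transfer_group)


def init_kv_transfer_for_worker(vllm_config) -> bool:
    ensure_kv_transfer_initialized(vllm_config)
    # True only when a kv_transfer_config is set and this is a KV transfer
    # instance; otherwise the worker runs without a connector.
    return has_kv_transfer_group()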
@@ -23,7 +23,8 @@ from vllm.attention.backends.abstract import AttentionState
 from vllm.attention.backends.utils import CommonAttentionState
 from vllm.config import CompilationLevel, VllmConfig
 from vllm.core.scheduler import SchedulerOutputs
-from vllm.distributed import get_kv_transfer_group, get_pp_group
+from vllm.distributed import get_pp_group
+from vllm.distributed.kv_transfer import get_kv_transfer_group
 from vllm.distributed.parallel_state import (get_tensor_model_parallel_rank,
                                              graph_capture)
 from vllm.forward_context import get_forward_context, set_forward_context
@@ -10,10 +10,10 @@ import torch.distributed
 import vllm.envs as envs
 from vllm.config import VllmConfig
 from vllm.device_allocator.cumem import CuMemAllocator
-from vllm.distributed import (ensure_kv_transfer_initialized,
-                              ensure_model_parallel_initialized,
+from vllm.distributed import (ensure_model_parallel_initialized,
                               init_distributed_environment,
                               set_custom_all_reduce)
+from vllm.distributed.kv_transfer import ensure_kv_transfer_initialized
 from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
 from vllm.model_executor import set_random_seed