[Core] Deprecating block manager v1 and make block manager v2 default (#8704)

Removing block manager v1. This is the initial piece of the prefix-caching-centric design: to achieve it, we need to simplify the code path so that only the v2 block manager (which has much higher prefix-caching performance) is used. A minimal migration sketch for callers follows the change summary below.
Kuntai Du 2024-10-17 11:38:15 -05:00 committed by GitHub
parent 5eda21e773
commit 81ede99ca4
45 changed files with 206 additions and 2109 deletions
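For downstream callers, the practical effect is that `use_v2_block_manager` no longer needs to be passed anywhere, and the `VLLM_ALLOW_DEPRECATED_BLOCK_MANAGER_V1` escape hatch disappears from the test commands. A minimal migration sketch, with an illustrative model name and sampling settings not taken from this diff:

```python
from vllm import LLM, SamplingParams

# Before this change, opting into the v2 block manager looked like:
#     llm = LLM(model="facebook/opt-125m", use_v2_block_manager=True)
# Afterwards the v2 design (now SelfAttnBlockSpaceManager) is the only
# implementation, so the flag is simply dropped:
llm = LLM(model="facebook/opt-125m", enable_prefix_caching=True)

outputs = llm.generate(["Hello, my name is"],
                       SamplingParams(temperature=0.0, max_tokens=16))
print(outputs[0].outputs[0].text)
```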

View File

@ -77,8 +77,8 @@ steps:
- vllm/
- tests/basic_correctness/test_chunked_prefill
commands:
- VLLM_ATTENTION_BACKEND=XFORMERS VLLM_ALLOW_DEPRECATED_BLOCK_MANAGER_V1=1 pytest -v -s basic_correctness/test_chunked_prefill.py
- VLLM_ATTENTION_BACKEND=FLASH_ATTN VLLM_ALLOW_DEPRECATED_BLOCK_MANAGER_V1=1 pytest -v -s basic_correctness/test_chunked_prefill.py
- VLLM_ATTENTION_BACKEND=XFORMERS pytest -v -s basic_correctness/test_chunked_prefill.py
- VLLM_ATTENTION_BACKEND=FLASH_ATTN pytest -v -s basic_correctness/test_chunked_prefill.py
- label: Core Test # 10min
mirror_hardwares: [amd]
@ -88,11 +88,7 @@ steps:
- vllm/distributed
- tests/core
commands:
- VLLM_ALLOW_DEPRECATED_BLOCK_MANAGER_V1=1 pytest -v -s core/test_scheduler.py
- VLLM_ALLOW_DEPRECATED_BLOCK_MANAGER_V1=1 pytest -v -s core core/test_chunked_prefill_scheduler.py
- VLLM_ALLOW_DEPRECATED_BLOCK_MANAGER_V1=1 pytest -v -s core core/block/e2e/test_correctness.py
- VLLM_ALLOW_DEPRECATED_BLOCK_MANAGER_V1=1 pytest -v -s core core/block/e2e/test_correctness_sliding_window.py
- pytest -v -s core --ignore=core/block/e2e/test_correctness.py --ignore=core/test_scheduler.py --ignore=core/test_chunked_prefill_scheduler.py --ignore=core/block/e2e/test_correctness.py --ignore=core/block/e2e/test_correctness_sliding_window.py
- pytest -v -s core
- label: Entrypoints Test # 40min
working_dir: "/vllm-workspace/tests"
@ -192,8 +188,7 @@ steps:
- vllm/
- tests/prefix_caching
commands:
- VLLM_ALLOW_DEPRECATED_BLOCK_MANAGER_V1=1 pytest -v -s prefix_caching/test_prefix_caching.py
- pytest -v -s prefix_caching --ignore=prefix_caching/test_prefix_caching.py
- pytest -v -s prefix_caching
- label: Samplers Test # 36min
source_file_dependencies:
@ -217,8 +212,7 @@ steps:
- tests/spec_decode
commands:
- pytest -v -s spec_decode/e2e/test_multistep_correctness.py
- VLLM_ALLOW_DEPRECATED_BLOCK_MANAGER_V1=1 pytest -v -s spec_decode/e2e/test_compatibility.py
- VLLM_ATTENTION_BACKEND=FLASH_ATTN pytest -v -s spec_decode --ignore=spec_decode/e2e/test_multistep_correctness.py --ignore=spec_decode/e2e/test_compatibility.py
- VLLM_ATTENTION_BACKEND=FLASH_ATTN pytest -v -s spec_decode --ignore=spec_decode/e2e/test_multistep_correctness.py
- label: LoRA Test %N # 15min each
mirror_hardwares: [amd]
@ -405,7 +399,7 @@ steps:
- pytest -v -s ./compile/test_basic_correctness.py
- pytest -v -s ./compile/test_wrapper.py
- VLLM_TEST_SAME_HOST=1 torchrun --nproc-per-node=4 distributed/test_same_node.py | grep -q 'Same node test passed'
- TARGET_TEST_SUITE=L4 VLLM_ALLOW_DEPRECATED_BLOCK_MANAGER_V1=1 pytest basic_correctness/ -v -s -m distributed_2_gpus
- TARGET_TEST_SUITE=L4 pytest basic_correctness/ -v -s -m distributed_2_gpus
# Avoid importing model tests that cause CUDA reinitialization error
- pytest models/encoder_decoder/language/test_bart.py -v -s -m distributed_2_gpus
- pytest models/encoder_decoder/vision_language/test_broadcast.py -v -s -m distributed_2_gpus

View File

@ -38,7 +38,6 @@ def main(args: argparse.Namespace):
quantization_param_path=args.quantization_param_path,
device=args.device,
ray_workers_use_nsight=args.ray_workers_use_nsight,
use_v2_block_manager=args.use_v2_block_manager,
enable_chunked_prefill=args.enable_chunked_prefill,
download_dir=args.download_dir,
block_size=args.block_size,
@ -221,9 +220,6 @@ if __name__ == '__main__':
parser.add_argument("--enable-prefix-caching",
action='store_true',
help="Enable automatic prefix caching")
parser.add_argument('--use-v2-block-manager',
action='store_true',
default=EngineArgs.use_v2_block_manager)
parser.add_argument(
"--ray-workers-use-nsight",
action='store_true',

View File

@ -33,7 +33,6 @@ from typing import List, Optional, Tuple
from transformers import PreTrainedTokenizerBase
from vllm import LLM, SamplingParams
from vllm.engine.arg_utils import EngineArgs
from vllm.utils import FlexibleArgumentParser
try:
@ -134,7 +133,6 @@ def main(args):
tokenizer_mode='auto',
trust_remote_code=True,
enforce_eager=True,
use_v2_block_manager=args.use_v2_block_manager,
tensor_parallel_size=args.tensor_parallel_size,
enable_prefix_caching=args.enable_prefix_caching)
@ -176,10 +174,6 @@ if __name__ == "__main__":
parser.add_argument('--enable-prefix-caching',
action='store_true',
help='enable prefix caching')
parser.add_argument('--use-v2-block-manager',
action='store_true',
default=EngineArgs.use_v2_block_manager,
help='Use BlockSpaceMangerV2')
parser.add_argument('--num-prompts',
type=int,
default=1,

View File

@ -86,7 +86,6 @@ def run_vllm(
distributed_executor_backend: Optional[str],
gpu_memory_utilization: float = 0.9,
num_scheduler_steps: int = 1,
use_v2_block_manager: bool = False,
download_dir: Optional[str] = None,
load_format: str = EngineArgs.load_format,
disable_async_output_proc: bool = False,
@ -113,7 +112,6 @@ def run_vllm(
distributed_executor_backend=distributed_executor_backend,
load_format=load_format,
num_scheduler_steps=num_scheduler_steps,
use_v2_block_manager=use_v2_block_manager,
disable_async_output_proc=disable_async_output_proc,
)
@ -176,7 +174,6 @@ async def run_vllm_async(
distributed_executor_backend: Optional[str],
gpu_memory_utilization: float = 0.9,
num_scheduler_steps: int = 1,
use_v2_block_manager: bool = False,
download_dir: Optional[str] = None,
load_format: str = EngineArgs.load_format,
disable_async_output_proc: bool = False,
@ -204,7 +201,6 @@ async def run_vllm_async(
distributed_executor_backend=distributed_executor_backend,
load_format=load_format,
num_scheduler_steps=num_scheduler_steps,
use_v2_block_manager=use_v2_block_manager,
disable_async_output_proc=disable_async_output_proc,
worker_use_ray=False,
disable_log_requests=True,
@ -341,8 +337,7 @@ def main(args: argparse.Namespace):
args.enable_prefix_caching, args.enable_chunked_prefill,
args.max_num_batched_tokens, args.distributed_executor_backend,
args.gpu_memory_utilization, args.num_scheduler_steps,
args.use_v2_block_manager, args.download_dir, args.load_format,
args.disable_async_output_proc
args.download_dir, args.load_format, args.disable_async_output_proc
]
if args.async_engine:
@ -471,10 +466,6 @@ if __name__ == "__main__":
type=int,
default=1,
help="Maximum number of forward steps per scheduler call.")
parser.add_argument("--use-v2-block-manager",
action='store_true',
default=EngineArgs.use_v2_block_manager,
help="Enable block manager v2.")
parser.add_argument(
"--enable-prefix-caching",
action='store_true',

View File

@ -16,7 +16,6 @@ def main(args):
enforce_eager=True,
enable_prefix_caching=True,
tensor_parallel_size=args.tensor_parallel_size,
use_v2_block_manager=args.use_v2_block_manager,
)
sampling_params = SamplingParams(temperature=0, max_tokens=args.output_len)
@ -56,8 +55,5 @@ if __name__ == "__main__":
parser.add_argument('--enable-prefix-caching',
action='store_true',
help='enable prefix caching')
parser.add_argument('--use-v2-block-manager',
action='store_true',
help='Use BlockSpaceMangerV2')
args = parser.parse_args()
main(args)
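Prefix caching, which this refactor centers on, is now configured purely through `enable_prefix_caching`. A hedged sketch of the pattern the updated script exercises; the model name and prompts are illustrative, not taken from this diff:

```python
from vllm import LLM, SamplingParams

llm = LLM(
    model="facebook/opt-125m",   # illustrative; the script takes the model as an argument
    enforce_eager=True,
    enable_prefix_caching=True,  # no block-manager flag is involved anymore
)

prefix = "You are a helpful assistant. Answer concisely.\n\n"
sampling_params = SamplingParams(temperature=0, max_tokens=32)

# The first call populates the prefix cache; the second call shares the same
# prefix and can reuse the cached KV blocks, which is the scenario a
# prefix-caching benchmark exercises.
llm.generate([prefix + "Question: What is a KV cache?"], sampling_params)
llm.generate([prefix + "Question: What is a block manager?"], sampling_params)
```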

View File

@ -30,7 +30,6 @@ The following code configures vLLM in an offline mode to use speculative decodin
tensor_parallel_size=1,
speculative_model="facebook/opt-125m",
num_speculative_tokens=5,
use_v2_block_manager=True,
)
outputs = llm.generate(prompts, sampling_params)
@ -104,7 +103,6 @@ matching n-grams in the prompt. For more information read `this thread. <https:/
speculative_model="[ngram]",
num_speculative_tokens=5,
ngram_prompt_lookup_max=4,
use_v2_block_manager=True,
)
outputs = llm.generate(prompts, sampling_params)
@ -135,7 +133,6 @@ For more information see `this blog <https://pytorch.org/blog/hitchhikers-guide-
tensor_parallel_size=4,
speculative_model="ibm-fms/llama3-70b-accelerator",
speculative_draft_tensor_parallel_size=1,
use_v2_block_manager=True,
)
outputs = llm.generate(prompts, sampling_params)
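The documented offline speculative-decoding setup therefore no longer mentions any block-manager option. A sketch of the resulting snippet, reusing the draft-model settings visible in the hunk; the target model is an assumption, since it is not visible in this excerpt:

```python
from vllm import LLM, SamplingParams

prompts = ["The future of AI is"]
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)

llm = LLM(
    model="facebook/opt-6.7b",             # assumed target model for this sketch
    tensor_parallel_size=1,
    speculative_model="facebook/opt-125m",
    num_speculative_tokens=5,
)
outputs = llm.generate(prompts, sampling_params)
for output in outputs:
    print(output.outputs[0].text)
```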

View File

@ -50,8 +50,6 @@ if __name__ == "__main__":
llm = LLM(
model="meta-llama/Llama-2-13b-chat-hf",
speculative_model="ibm-fms/llama-13b-accelerator",
# These are currently required for MLPSpeculator decoding
use_v2_block_manager=True,
)
print("With speculation")

View File

@ -12,7 +12,7 @@ from contextlib import nullcontext
import pytest
from ..models.utils import check_logprobs_close, check_outputs_equal
from ..utils import check_deprecated_block_manager_usage, multi_gpu_test
from ..utils import multi_gpu_test
MODELS = [
"facebook/opt-125m",
@ -20,12 +20,6 @@ MODELS = [
]
@pytest.fixture(scope="module", autouse=True)
def check_deprecated_block_manager():
check_deprecated_block_manager_usage(
'tests/basic_correctness/test_chunked_prefill.py')
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [32])
@ -197,7 +191,6 @@ def test_models_with_fp8_kv_cache(
@pytest.mark.parametrize("max_tokens", [16])
@pytest.mark.parametrize("enforce_eager", [False])
@pytest.mark.parametrize("chunk_size", [30, 32])
@pytest.mark.parametrize("use_v2_block_manager", [False, True])
# NOTE: Increasing this in this suite will fail CI because we currently cannot
# reset distributed env properly. Use a value > 1 just when you test.
@pytest.mark.parametrize("tensor_parallel_size", [1])
@ -206,7 +199,6 @@ def test_with_prefix_caching(
max_tokens: int,
enforce_eager: bool,
chunk_size: int,
use_v2_block_manager: bool,
tensor_parallel_size: int,
) -> None:
"""
@ -234,7 +226,6 @@ def test_with_prefix_caching(
enable_chunked_prefill=True,
enable_prefix_caching=enable,
tensor_parallel_size=tensor_parallel_size,
use_v2_block_manager=use_v2_block_manager,
enforce_eager=enforce_eager,
max_num_seqs=max_num_seqs,
) as vllm_model:
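The combination the test sweeps over (chunked prefill with prefix caching on or off) maps directly onto public engine arguments. A minimal sketch with illustrative values for the knobs the test parametrizes:

```python
from vllm import LLM, SamplingParams

# There is no block-manager switch left to toggle: chunked prefill and prefix
# caching both run on the single (formerly v2) block manager.
llm = LLM(
    model="facebook/opt-125m",
    enable_chunked_prefill=True,
    enable_prefix_caching=True,   # the test runs with this both enabled and disabled
    max_num_batched_tokens=32,    # illustrative chunk size
    max_num_seqs=16,              # kept <= max_num_batched_tokens
    enforce_eager=True,
)
outputs = llm.generate(["San Francisco is a"],
                       SamplingParams(temperature=0.0, max_tokens=16))
```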

View File

@ -2,18 +2,11 @@ from itertools import cycle
import pytest
from tests.utils import check_deprecated_block_manager_usage
from vllm import SamplingParams
from .conftest import get_token_ids_from_llm_generator
@pytest.fixture(scope="module", autouse=True)
def check_deprecated_block_manager():
check_deprecated_block_manager_usage(
'tests/core/block/e2e/test_correctness.py')
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
@ -28,32 +21,32 @@ def check_deprecated_block_manager():
"num_gpu_blocks_override": 5 * (64 + 1),
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{
"use_v2_block_manager": False
}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [{
"use_v2_block_manager": True,
"preemption_mode": "swap"
}, {
"use_v2_block_manager": True,
"preemption_mode": "recompute"
}])
@pytest.mark.parametrize("batch_size", [10])
@pytest.mark.parametrize("seed", [1])
def test_v1_v2_greedy_equality_with_preemption(baseline_llm_generator,
test_llm_generator, batch_size):
"""Verify block manager v2 produces same outputs as block manager v1, even
when there is preemption.
def test_block_manager_with_preemption(baseline_llm_generator,
test_llm_generator, batch_size):
"""Verify block manager produces same outputs even when there is preemption.
This constructs two LLM, each with limited number of GPU blocks. The limit
is decided such that as the sequences in the batch grow, sequences must be
preempted and removed from cache.
If the output token ids are equivalent, then we have confidence that the KV
cache is not corrupted in the v2 block manager.
cache is not corrupted.
NOTE: We want a significant number of generated tokens so that any incorrect
KV mapping has time to build up error.
NOTE(Kuntai): Though we have removed block manager v1, this test is still
useful as it asserts that the behavior of block manager v2 (now called
SelfAttnBlockSpaceManager) stays the same under swapping / preemption, so we
keep this test.
"""
output_len = 1024
temperature = 0.0
@ -77,11 +70,9 @@ def test_v1_v2_greedy_equality_with_preemption(baseline_llm_generator,
temperature=temperature,
)
print('Getting token ids from block manager v1')
baseline_token_ids = get_token_ids_from_llm_generator(
baseline_llm_generator, prompts, sampling_params)
print('Getting token ids from block manager v2')
test_token_ids = get_token_ids_from_llm_generator(test_llm_generator,
prompts, sampling_params)
@ -104,9 +95,6 @@ def test_v1_v2_greedy_equality_with_preemption(baseline_llm_generator,
# skip cuda graph creation for fast test.
"enforce_eager": True,
# Lookahead scheduling only supported in v2 block manager.
"use_v2_block_manager": True,
}])
@pytest.mark.parametrize(
"per_test_common_llm_kwargs",
@ -218,26 +206,22 @@ def test_lookahead_greedy_equality_with_preemption(baseline_llm_generator,
"max_num_seqs": 10,
}])
@pytest.mark.parametrize("baseline_llm_kwargs", [
{
"use_v2_block_manager": False,
},
{},
])
@pytest.mark.parametrize("test_llm_kwargs", [
{
"use_v2_block_manager": True,
"num_lookahead_slots": 0,
},
{
"use_v2_block_manager": True,
"num_lookahead_slots": 5,
},
])
@pytest.mark.parametrize("batch_size", [4])
@pytest.mark.parametrize("seed", [1])
def test_chunked_prefill_block_manager_v2(baseline_llm_generator,
test_llm_generator, batch_size):
"""Verify that chunked prefill works with BlockManagerV2, with and without
lookahead scheduling.
def test_chunked_prefill_block_manager(baseline_llm_generator,
test_llm_generator, batch_size):
"""Verify that chunked prefill works with SelfAttnBlockSpaceManager,
with and without lookahead scheduling.
"""
output_len = 32
temperature = 0.0
@ -258,11 +242,11 @@ def test_chunked_prefill_block_manager_v2(baseline_llm_generator,
temperature=temperature,
)
print('Getting token ids with BlockManagerV1')
print('Getting token ids with BlockManager')
baseline_token_ids = get_token_ids_from_llm_generator(
baseline_llm_generator, prompts, sampling_params)
print('Getting token ids with BlockManagerV2')
print('Getting token ids with BlockManager, with lookahead slots.')
test_token_ids = get_token_ids_from_llm_generator(test_llm_generator,
prompts, sampling_params)
@ -290,32 +274,32 @@ def test_chunked_prefill_block_manager_v2(baseline_llm_generator,
"enable_prefix_caching": True,
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{
"use_v2_block_manager": False
}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [{
"use_v2_block_manager": True,
"preemption_mode": "swap"
}, {
"use_v2_block_manager": True,
"preemption_mode": "recompute"
}])
@pytest.mark.parametrize("batch_size", [10])
@pytest.mark.parametrize("seed", [1])
def test_v1_v2_greedy_equality_prefix_caching_enabled_with_preemption(
def test_block_manager_prefix_caching_enabled_with_preemption(
baseline_llm_generator, test_llm_generator, batch_size):
"""Verify block manager v2 produces same outputs as block manager v1, even
when there is preemption.
"""Verify block manager produces same outputs even when there is preemption.
This constructs two LLM, each with limited number of GPU blocks. The limit
is decided such that as the sequences in the batch grow, sequences must be
preempted and removed from cache.
If the output token ids are equivalent, then we have confidence that the KV
cache is not corrupted in the v2 block manager.
cache is not corrupted.
NOTE: We want a significant number of generated tokens so that any incorrect
KV mapping has time to build up error.
NOTE(Kuntai): Though we have removed block manager v1, this test is still
useful as it asserts that the behavior of block manager v2 (now called
SelfAttnBlockSpaceManager) stays the same under swapping / preemption, so we
keep this test.
"""
output_len = 1024
temperature = 0.0
@ -339,11 +323,11 @@ def test_v1_v2_greedy_equality_prefix_caching_enabled_with_preemption(
temperature=temperature,
)
print('Getting token ids from block manager v1')
print('Getting token ids from block manager')
baseline_token_ids = get_token_ids_from_llm_generator(
baseline_llm_generator, prompts, sampling_params)
print('Getting token ids from block manager v2')
print('Getting token ids from block manager, with preemption')
test_token_ids = get_token_ids_from_llm_generator(test_llm_generator,
prompts, sampling_params)
@ -366,9 +350,6 @@ def test_v1_v2_greedy_equality_prefix_caching_enabled_with_preemption(
# Allow only 5 sequences of ~1024 tokens in worst case.
"block_size": 16,
"num_gpu_blocks_override": 5 * (64 + 1),
# Test APC in v2 block
"use_v2_block_manager": True,
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{
@ -444,9 +425,6 @@ def test_auto_prefix_caching_with_preemption(baseline_llm_generator,
"max_model_len": 48,
"block_size": 16,
"num_gpu_blocks_override": 3,
# Test APC in v2 block
"use_v2_block_manager": True,
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{

View File

@ -3,7 +3,6 @@ from typing import List
import pytest
from tests.utils import check_deprecated_block_manager_usage
from vllm import LLM, SamplingParams
from .conftest import get_text_from_llm_generator
@ -13,12 +12,6 @@ MODEL = "bigcode/starcoder2-3b"
BLOCK_SIZE = 16
@pytest.fixture(scope="module", autouse=True)
def check_deprecated_block_manager():
check_deprecated_block_manager_usage(
'tests/core/block/e2e/test_correctness_sliding_window.py')
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
@ -31,10 +24,8 @@ def check_deprecated_block_manager():
"num_gpu_blocks_override": 100000 // BLOCK_SIZE,
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{
"use_v2_block_manager": False
}])
@pytest.mark.parametrize("test_llm_kwargs", [{"use_v2_block_manager": True}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [{}])
@pytest.mark.parametrize("batch_size", [5])
@pytest.mark.parametrize("seed", [1])
def test_sliding_window_retrival(baseline_llm_generator, test_llm_generator,
@ -55,7 +46,6 @@ def test_sliding_window_retrival(baseline_llm_generator, test_llm_generator,
prompts, answer, indices = prep_prompts(batch_size)
print('Getting token ids from block manager v1')
baseline_texts = get_text_from_llm_generator(baseline_llm_generator,
prompts,
sampling_params,
@ -91,10 +81,7 @@ def test_sliding_window_retrival(baseline_llm_generator, test_llm_generator,
"num_gpu_blocks_override": 100000 // BLOCK_SIZE,
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [{
"use_v2_block_manager": True,
"enable_chunked_prefill": True
}])
@pytest.mark.parametrize("test_llm_kwargs", [{"enable_chunked_prefill": True}])
@pytest.mark.parametrize("batch_size", [5])
@pytest.mark.parametrize("seed", [1])
def test_sliding_window_chunked_prefill(test_llm_generator, batch_size, seed):

View File

@ -2,7 +2,7 @@ import pytest
from vllm.core.block.utils import (STR_NOT_IMPL_ENC_DEC_PREFIX_CACHE,
STR_NOT_IMPL_ENC_DEC_SWA)
from vllm.core.block_manager_v2 import BlockSpaceManagerV2
from vllm.core.block_manager import SelfAttnBlockSpaceManager
from vllm.core.interfaces import AllocStatus
from vllm.sequence import Logprob, SequenceStatus
from vllm.utils import chunk_list
@ -17,7 +17,7 @@ from ..utils import (create_dummy_prompt, create_seq_group,
@pytest.mark.parametrize("watermark", [0.0, 0.5])
def test_can_allocate_seq_group(block_size: int, num_seqs_per_group: int,
num_gpu_blocks: int, watermark: float):
block_manager = BlockSpaceManagerV2(
block_manager = SelfAttnBlockSpaceManager(
block_size=block_size,
num_gpu_blocks=num_gpu_blocks,
num_cpu_blocks=1024,
@ -63,7 +63,7 @@ def test_can_allocate_seq_group_encoder_decoder(block_size: int,
num_seqs_per_group: int,
num_gpu_blocks: int,
watermark: float):
block_manager = BlockSpaceManagerV2(
block_manager = SelfAttnBlockSpaceManager(
block_size=block_size,
num_gpu_blocks=num_gpu_blocks,
num_cpu_blocks=1024,
@ -117,16 +117,16 @@ def test_can_allocate_encoder_decoder_fails_with_swa(block_size: int,
'''
SWA short for Sliding Window Attention.
At time of writing block manager v2 does not support SWA.
At time of writing block manager does not support SWA.
However even when SWA is implemented for block manager v2,
However even when SWA is implemented for block manager,
there will still most likely be a separate workstream required
to enable SWA for encoder/decoder models.
Therefore this test enforces that one of the following cases
holds true:
1. Block manager v2 does not support SWA at all (true at time of writing)
2. Block manager v2 fails with NotImplementError when SWA is enabled
1. Block manager does not support SWA at all (true at time of writing)
2. Block manager fails with NotImplementedError when SWA is enabled
AND a SequenceGroup with an encoder sequence (i.e. in support of an
encoder/decoder model) is passed into can_allocate() as an argument
@ -135,7 +135,7 @@ def test_can_allocate_encoder_decoder_fails_with_swa(block_size: int,
'''
with pytest.raises((NotImplementedError, AssertionError)) as exc_info:
block_manager = BlockSpaceManagerV2(
block_manager = SelfAttnBlockSpaceManager(
block_size=block_size,
num_gpu_blocks=num_gpu_blocks,
num_cpu_blocks=1024,
@ -158,7 +158,7 @@ def test_can_allocate_encoder_decoder_fails_with_swa(block_size: int,
block_manager.can_allocate(seq_group)
# Assert that either
# 1. Block manager v2 constructor fails with assertion that sliding window
# 1. Block manager constructor fails with assertion that sliding window
# is not yet supported (most likely near-term outcome at time of
# writing), or
# 2. can_allocate() fails with NotImplementedError due to combination of
@ -177,7 +177,7 @@ def test_can_allocate_encoder_decoder_fails_with_prefix_cache(
block_size: int, num_seqs_per_group: int, num_gpu_blocks: int,
watermark: float):
block_manager = BlockSpaceManagerV2(
block_manager = SelfAttnBlockSpaceManager(
block_size=block_size,
num_gpu_blocks=num_gpu_blocks,
num_cpu_blocks=1024,
@ -217,7 +217,7 @@ def test_append_slots(block_size, prompt_len, num_slots_to_append,
num_gpu_blocks = 1024
watermark = 0.1
block_manager = BlockSpaceManagerV2(
block_manager = SelfAttnBlockSpaceManager(
block_size=block_size,
num_gpu_blocks=num_gpu_blocks,
num_cpu_blocks=0,
@ -269,14 +269,15 @@ def test_swap(block_size, num_cpu_blocks, num_gpu_blocks, num_lookahead_slots,
"""Verify blocks number on src/desc device is correct after swapping in/out
sequence group (not missing or extra blocks).
"""
block_manager = BlockSpaceManagerV2(block_size,
num_cpu_blocks,
num_gpu_blocks,
watermark=0,
enable_caching=enable_caching)
block_manager = SelfAttnBlockSpaceManager(block_size,
num_cpu_blocks,
num_gpu_blocks,
watermark=0,
enable_caching=enable_caching)
prompt, seq_group = create_dummy_prompt("1", prompt_length=block_size - 1)
prompt.status = SequenceStatus.WAITING
block_manager.allocate(seq_group)
# Emulate a forward pass by appending a single token.
# The block manager then knows how many unprocessed
# tokens will be written in the next forward pass.
@ -321,11 +322,11 @@ def test_can_swap(block_size, num_gpu_blocks, num_lookahead_slots,
can be swapped in/out.
"""
num_cpu_blocks = num_gpu_blocks
block_manager = BlockSpaceManagerV2(block_size,
num_cpu_blocks,
num_gpu_blocks,
watermark=0,
enable_caching=enable_caching)
block_manager = SelfAttnBlockSpaceManager(block_size,
num_cpu_blocks,
num_gpu_blocks,
watermark=0,
enable_caching=enable_caching)
prompt, seq_group = create_dummy_prompt(
"1", prompt_length=(num_gpu_blocks - 1) * block_size - 1)
prompt.status = SequenceStatus.WAITING
@ -382,11 +383,11 @@ def test_swap_in_infeasible(num_lookahead_slots, enable_caching):
block_size = 8
num_cpu_blocks = 1
num_gpu_blocks = 1
block_manager = BlockSpaceManagerV2(block_size,
num_cpu_blocks,
num_gpu_blocks,
watermark=0,
enable_caching=enable_caching)
block_manager = SelfAttnBlockSpaceManager(block_size,
num_cpu_blocks,
num_gpu_blocks,
watermark=0,
enable_caching=enable_caching)
prompt_length = block_size - 3
assert prompt_length > 0
prompt, seq_group = create_dummy_prompt("1", prompt_length=prompt_length)
@ -434,7 +435,7 @@ def test_sliding_window(block_size, prompt_len, num_slots_to_append,
num_gpu_blocks = 1024
watermark = 0.1
block_manager = BlockSpaceManagerV2(
block_manager = SelfAttnBlockSpaceManager(
block_size=block_size,
num_gpu_blocks=num_gpu_blocks,
num_cpu_blocks=0,
@ -474,7 +475,7 @@ def test_sliding_window(block_size, prompt_len, num_slots_to_append,
seq.data.update_num_computed_tokens(prompt_len)
check_used(num_blocks(prompt_len))
# this is how we compute it in BlockSpaceManagerV2.__init__
# this is how we compute it in SelfAttnBlockSpaceManager.__init__
sliding_blocks = (sliding_window // block_size) + 2
# plus one block for null block
sliding_blocks += 1
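For reference, the renamed class keeps the constructor signature these tests exercise. A standalone sketch of instantiating it directly; the numbers are illustrative:

```python
from vllm.core.block_manager import SelfAttnBlockSpaceManager

# BlockSpaceManagerV2 is now SelfAttnBlockSpaceManager; the constructor
# arguments used throughout this test file are unchanged.
block_manager = SelfAttnBlockSpaceManager(
    block_size=16,
    num_gpu_blocks=1024,
    num_cpu_blocks=256,
    watermark=0.01,
    enable_caching=True,   # prefix caching, the focus of this refactor
)
print(block_manager.get_num_free_gpu_blocks())
```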

View File

@ -1,637 +0,0 @@
import time
from collections import defaultdict
from typing import List
import pytest
from vllm import SamplingParams
from vllm.block import PhysicalTokenBlock
from vllm.core.block.utils import (STR_NOT_IMPL_ENC_DEC_PREFIX_CACHE,
STR_NOT_IMPL_ENC_DEC_SWA)
from vllm.core.block_manager_v1 import (BlockSpaceManagerV1,
UncachedBlockAllocator)
from vllm.core.interfaces import AllocStatus
from vllm.sequence import Logprob, Sequence, SequenceGroup, SequenceStatus
from vllm.utils import Device
from .utils import create_dummy_prompt, create_dummy_prompt_encoder_decoder
def test_block_allocator_allocate():
block_size = 4
num_cpu_blocks = 4
cpu_allocator = UncachedBlockAllocator(Device.CPU, block_size,
num_cpu_blocks)
# Allocate all available cpu blocks.
num_free = num_cpu_blocks
assert cpu_allocator.get_num_free_blocks() == num_free
for _ in range(num_cpu_blocks):
block = cpu_allocator.allocate()
num_free -= 1
assert block not in cpu_allocator.free_blocks
assert cpu_allocator.get_num_free_blocks() == num_free
with pytest.raises(ValueError):
cpu_allocator.allocate()
def test_block_allocator_free():
block_size = 4
num_cpu_blocks = 4
cpu_allocator = UncachedBlockAllocator(Device.CPU, block_size,
num_cpu_blocks)
# Allocate all available cpu blocks.
blocks: List[PhysicalTokenBlock] = []
for _ in range(num_cpu_blocks):
block = cpu_allocator.allocate()
blocks.append(block)
assert block not in cpu_allocator.free_blocks
# Free all allocated cpu blocks.
num_free = 0
assert cpu_allocator.get_num_free_blocks() == num_free
for block in blocks:
cpu_allocator.free(block)
num_free += 1
assert block in cpu_allocator.free_blocks
assert cpu_allocator.get_num_free_blocks() == num_free
with pytest.raises(ValueError):
cpu_allocator.free(block)
def test_allocate():
block_size = 4
num_cpu_blocks = 4
num_gpu_blocks = 4
block_manager = BlockSpaceManagerV1(block_size,
num_cpu_blocks,
num_gpu_blocks,
watermark=0)
# Allocate same sequence group to all available gpu blocks.
for i in range(num_gpu_blocks):
_, seq_group = create_dummy_prompt(str(i), block_size)
assert block_manager.can_allocate(seq_group) == AllocStatus.OK
block_manager.allocate(seq_group)
assert block_manager.can_allocate(seq_group) != AllocStatus.OK
# Allocate same sequence group to all available gpu blocks.
# Use watermark to reserve one gpu block.
block_manager = BlockSpaceManagerV1(block_size,
num_cpu_blocks,
num_gpu_blocks,
watermark=1 / num_gpu_blocks)
for i in range(num_gpu_blocks - 1):
_, seq_group = create_dummy_prompt(str(i), block_size)
assert block_manager.can_allocate(seq_group) == AllocStatus.OK
block_manager.allocate(seq_group)
assert block_manager.can_allocate(seq_group) != AllocStatus.OK
def test_allocate_encoder_decoder():
block_size = 4
num_cpu_blocks = 4
num_gpu_blocks = 4
block_req_per_seq_group = 2
block_manager = BlockSpaceManagerV1(block_size,
num_cpu_blocks,
num_gpu_blocks,
watermark=0)
# Allocate same sequence group to all available gpu blocks.
for i in range(num_gpu_blocks // block_req_per_seq_group):
_, _, seq_group = create_dummy_prompt_encoder_decoder(
str(i),
decoder_prompt_length=block_size,
encoder_prompt_length=block_size)
assert block_manager.can_allocate(seq_group) == AllocStatus.OK
block_manager.allocate(seq_group)
assert block_manager.can_allocate(seq_group) != AllocStatus.OK
# Allocate same sequence group to all available gpu blocks.
# Use watermark to reserve one gpu block.
block_manager = BlockSpaceManagerV1(block_size,
num_cpu_blocks,
num_gpu_blocks,
watermark=1 / num_gpu_blocks)
for i in range((num_gpu_blocks - 1) // block_req_per_seq_group):
_, _, seq_group = create_dummy_prompt_encoder_decoder(
str(i),
decoder_prompt_length=block_size,
encoder_prompt_length=block_size)
assert block_manager.can_allocate(seq_group) == AllocStatus.OK
block_manager.allocate(seq_group)
assert block_manager.can_allocate(seq_group) != AllocStatus.OK
def test_allocate_encoder_decoder_fails_with_swa():
# SWA short for sliding window attention
block_size = 4
num_cpu_blocks = 4
num_gpu_blocks = 4
block_manager = BlockSpaceManagerV1(block_size,
num_cpu_blocks,
num_gpu_blocks,
watermark=0,
sliding_window=5) # swa
# Allocate same sequence group to all available gpu blocks.
_, _, seq_group = create_dummy_prompt_encoder_decoder(
"0",
decoder_prompt_length=block_size,
encoder_prompt_length=block_size)
# Assert that can_allocate() fails due to SWA
with pytest.raises(NotImplementedError) as exc_info:
block_manager.can_allocate(seq_group)
assert str(exc_info.value) == STR_NOT_IMPL_ENC_DEC_SWA
# Assert that allocate() fails due to SWA
with pytest.raises(NotImplementedError) as exc_info:
block_manager.allocate(seq_group)
assert str(exc_info.value) == STR_NOT_IMPL_ENC_DEC_SWA
def test_allocate_encoder_decoder_fails_with_prefix_caching():
block_size = 4
num_cpu_blocks = 4
num_gpu_blocks = 4
block_manager = BlockSpaceManagerV1(block_size,
num_cpu_blocks,
num_gpu_blocks,
watermark=0,
enable_caching=True) # Prefix cache
# Allocate same sequence group to all available gpu blocks.
_, _, seq_group = create_dummy_prompt_encoder_decoder(
"0",
decoder_prompt_length=block_size,
encoder_prompt_length=block_size)
# Assert that can_allocate() fails due to prefix caching
with pytest.raises(NotImplementedError) as exc_info:
block_manager.can_allocate(seq_group)
assert str(exc_info.value) == STR_NOT_IMPL_ENC_DEC_PREFIX_CACHE
# Assert that allocate() fails due to prefix caching
with pytest.raises(NotImplementedError) as exc_info:
block_manager.allocate(seq_group)
assert str(exc_info.value) == STR_NOT_IMPL_ENC_DEC_PREFIX_CACHE
def test_append_slot_single_seq():
block_size = 4
num_cpu_blocks = 4
num_gpu_blocks = 4
block_manager = BlockSpaceManagerV1(block_size,
num_cpu_blocks,
num_gpu_blocks,
watermark=0)
# Allocate single seq to gpu block.
prompt, seq_group = create_dummy_prompt("1", block_size)
block_manager.allocate(seq_group)
# Nothing to append. Sequence has no new logical blocks.
assert block_manager.can_append_slots(seq_group)
before_blocks = block_manager.get_num_free_gpu_blocks()
assert not block_manager.append_slots(prompt)
after_blocks = block_manager.get_num_free_gpu_blocks()
assert before_blocks == after_blocks
# Add block_size number of new tokens and append slot.
for i in range(block_size):
token_id = i + 5
prompt.append_token_id(token_id, {token_id: Logprob(0.0)})
assert block_manager.can_append_slots(seq_group)
before_blocks = block_manager.get_num_free_gpu_blocks()
assert not block_manager.append_slots(prompt)
after_blocks = block_manager.get_num_free_gpu_blocks()
assert before_blocks - after_blocks == 1
def test_append_slot_cow():
block_size = 4
num_cpu_blocks = 4
num_gpu_blocks = 4
block_manager = BlockSpaceManagerV1(block_size=block_size,
num_cpu_blocks=num_cpu_blocks,
num_gpu_blocks=num_gpu_blocks,
watermark=0)
# Allocate prompt to gpu block. There is one slot left in the block.
prompt = Sequence(seq_id=1,
inputs={
"prompt": "one two three",
"prompt_token_ids": [1, 2, 3],
},
block_size=block_size)
# Fork the sequence, such that a COW will be required when we append a new
# token id.
child = prompt.fork(new_seq_id=2)
# Allocate space for the sequence group.
seq_group = SequenceGroup(request_id="1",
seqs=[prompt, child],
arrival_time=time.time(),
sampling_params=SamplingParams())
block_manager.allocate(seq_group)
# Fork and append a new token id. We expect a COW to be scheduled.
token_id = 4
child.append_token_id(token_id, {token_id: Logprob(0.0)})
block_manager.fork(prompt, child)
assert block_manager.can_append_slots(seq_group)
before_blocks = block_manager.get_num_free_gpu_blocks()
cows = block_manager.append_slots(child)
assert cows
dict_cows = defaultdict(list)
for src_block, dst_block in cows:
dict_cows[src_block].append(dst_block)
for src_block, dst_blocks in dict_cows.items():
assert src_block not in dst_blocks
after_blocks = block_manager.get_num_free_gpu_blocks()
assert before_blocks - after_blocks == 1
def test_fork():
block_size = 4
num_cpu_blocks = 4
num_gpu_blocks = 4
block_manager = BlockSpaceManagerV1(block_size,
num_cpu_blocks,
num_gpu_blocks,
watermark=0)
prompt, seq_group = create_dummy_prompt("1",
block_size - 1,
block_size=block_size)
block_manager.allocate(seq_group)
# Fork prompt and copy block tables.
child = prompt.fork(2)
block_manager.fork(prompt, child)
assert block_manager.get_block_table(
prompt) == block_manager.get_block_table(child)
token_id = 4
# Append token to child. Block is shared so copy on write occurs.
child.append_token_id(token_id, {token_id: Logprob(0.0)})
block_manager.append_slots(child)
assert block_manager.get_block_table(
prompt) != block_manager.get_block_table(child)
def test_swap():
block_size = 4
num_cpu_blocks = 4
num_gpu_blocks = 4
block_manager = BlockSpaceManagerV1(block_size,
num_cpu_blocks,
num_gpu_blocks,
watermark=0)
prompt, seq_group = create_dummy_prompt("1", prompt_length=block_size - 1)
prompt.status = SequenceStatus.WAITING
block_manager.allocate(seq_group)
# Emulate a forward pass by appending a single token.
# The block manager then knows how many unprocessed
# tokens will be written in the next forward pass.
token_id = 0
prompt.status = SequenceStatus.RUNNING
prompt.append_token_id(token_id, {token_id: Logprob(0.0)})
# Swap seq group from GPU -> CPU.
gpu_blocks = block_manager.get_block_table(prompt)
assert block_manager.can_swap_out(seq_group)
before_cpu_blocks = block_manager.get_num_free_cpu_blocks()
before_gpu_blocks = block_manager.get_num_free_gpu_blocks()
mapping = block_manager.swap_out(seq_group)
assert [x[0] for x in mapping] == gpu_blocks
after_cpu_blocks = block_manager.get_num_free_cpu_blocks()
after_gpu_blocks = block_manager.get_num_free_gpu_blocks()
assert before_cpu_blocks == after_cpu_blocks + len(gpu_blocks)
assert before_gpu_blocks + len(gpu_blocks) == after_gpu_blocks
prompt.status = SequenceStatus.SWAPPED
# Swap seq group from CPU -> GPU.
cpu_blocks = block_manager.get_block_table(prompt)
assert block_manager.can_swap_in(seq_group) == AllocStatus.OK
before_cpu_blocks = block_manager.get_num_free_cpu_blocks()
before_gpu_blocks = block_manager.get_num_free_gpu_blocks()
mapping = block_manager.swap_in(seq_group)
assert [x[0] for x in mapping] == cpu_blocks
after_cpu_blocks = block_manager.get_num_free_cpu_blocks()
after_gpu_blocks = block_manager.get_num_free_gpu_blocks()
assert before_cpu_blocks + len(cpu_blocks) == after_cpu_blocks
assert before_gpu_blocks == after_gpu_blocks + len(cpu_blocks)
def test_swap_encoder_decoder():
block_size = 4
num_cpu_blocks = 4
num_gpu_blocks = 4
block_manager = BlockSpaceManagerV1(block_size,
num_cpu_blocks,
num_gpu_blocks,
watermark=0)
decoder_prompt, encoder_prompt, seq_group = \
create_dummy_prompt_encoder_decoder(
"1",
decoder_prompt_length=block_size,
encoder_prompt_length=block_size)
decoder_prompt.status = SequenceStatus.WAITING
encoder_prompt.status = SequenceStatus.WAITING
block_manager.allocate(seq_group)
# Emulate a forward pass by appending a single token.
# The block manager then knows how many unprocessed
# tokens will be written in the next forward pass.
token_id = 0
decoder_prompt.status = SequenceStatus.RUNNING
decoder_prompt.append_token_id(token_id, {token_id: Logprob(0.0)})
# Swap encoder/decoder seq group from GPU -> CPU.
decoder_gpu_blocks = block_manager.get_block_table(decoder_prompt)
cross_gpu_blocks = block_manager.get_cross_block_table(seq_group)
gpu_blocks = decoder_gpu_blocks + cross_gpu_blocks
assert block_manager.can_swap_out(seq_group)
before_cpu_blocks = block_manager.get_num_free_cpu_blocks()
before_gpu_blocks = block_manager.get_num_free_gpu_blocks()
mapping = block_manager.swap_out(seq_group)
assert [x[0] for x in mapping] == gpu_blocks
#assert list(mapping.keys()) == gpu_blocks
after_cpu_blocks = block_manager.get_num_free_cpu_blocks()
after_gpu_blocks = block_manager.get_num_free_gpu_blocks()
assert before_cpu_blocks == after_cpu_blocks + len(gpu_blocks)
assert before_gpu_blocks + len(gpu_blocks) == after_gpu_blocks
decoder_prompt.status = SequenceStatus.SWAPPED
# Swap encoder/decoder seq group from CPU -> GPU.
decoder_cpu_blocks = block_manager.get_block_table(decoder_prompt)
cross_cpu_blocks = block_manager.get_cross_block_table(seq_group)
cpu_blocks = decoder_cpu_blocks + cross_cpu_blocks
assert block_manager.can_swap_in(seq_group) == AllocStatus.OK
before_cpu_blocks = block_manager.get_num_free_cpu_blocks()
before_gpu_blocks = block_manager.get_num_free_gpu_blocks()
mapping = block_manager.swap_in(seq_group)
assert [x[0] for x in mapping] == cpu_blocks
after_cpu_blocks = block_manager.get_num_free_cpu_blocks()
after_gpu_blocks = block_manager.get_num_free_gpu_blocks()
assert before_cpu_blocks + len(cpu_blocks) == after_cpu_blocks
assert before_gpu_blocks == after_gpu_blocks + len(cpu_blocks)
def test_free():
block_size = 4
num_cpu_blocks = 4
num_gpu_blocks = 4
block_manager = BlockSpaceManagerV1(block_size,
num_cpu_blocks,
num_gpu_blocks,
watermark=0)
prompt, seq_group = create_dummy_prompt("1", block_size)
block_manager.allocate(seq_group)
# Free allocated seq.
prompt_blocks = len(block_manager.get_block_table(prompt))
before_blocks = block_manager.get_num_free_gpu_blocks()
block_manager.free(prompt)
after_blocks = block_manager.get_num_free_gpu_blocks()
assert after_blocks == before_blocks + prompt_blocks
# Block table for freed seq is deleted.
with pytest.raises(KeyError):
block_manager.get_block_table(prompt)
def test_free_encoder_decoder():
block_size = 4
num_cpu_blocks = 4
num_gpu_blocks = 4
block_manager = BlockSpaceManagerV1(block_size,
num_cpu_blocks,
num_gpu_blocks,
watermark=0)
decoder_prompt, encoder_prompt, seq_group = \
create_dummy_prompt_encoder_decoder(
"1",
decoder_prompt_length=block_size,
encoder_prompt_length=block_size)
block_manager.allocate(seq_group)
# Free allocated seq.
decoder_prompt_blocks = len(block_manager.get_block_table(decoder_prompt))
encoder_prompt_blocks = len(block_manager.get_cross_block_table(seq_group))
prompt_blocks = decoder_prompt_blocks + encoder_prompt_blocks
before_blocks = block_manager.get_num_free_gpu_blocks()
block_manager.free(decoder_prompt)
block_manager.free_cross(seq_group)
after_blocks = block_manager.get_num_free_gpu_blocks()
assert after_blocks == before_blocks + prompt_blocks
# Block table for freed encoder & decoder seq's are deleted.
with pytest.raises(KeyError):
block_manager.get_block_table(decoder_prompt)
# Block table for freed encoder & decoder seq's are deleted.
with pytest.raises(KeyError):
block_manager.get_block_table(encoder_prompt)
def test_reset():
block_size = 4
num_cpu_blocks = 4
num_gpu_blocks = 4
block_manager = BlockSpaceManagerV1(block_size,
num_cpu_blocks,
num_gpu_blocks,
watermark=0)
# Allocate same seq group on all available gpu blocks.
original_blocks = block_manager.get_num_free_gpu_blocks()
for i in range(num_gpu_blocks):
_, seq_group = create_dummy_prompt(str(i), block_size)
block_manager.allocate(seq_group)
assert block_manager.get_num_free_gpu_blocks() == 0
# Resetting block manager frees all allocated blocks.
block_manager.reset()
assert block_manager.get_num_free_gpu_blocks() == original_blocks
def test_reset_encoder_decoder():
block_size = 4
num_cpu_blocks = 4
num_gpu_blocks = 4
block_req_per_seq_group = 2
block_manager = BlockSpaceManagerV1(block_size,
num_cpu_blocks,
num_gpu_blocks,
watermark=0)
# Allocate same seq group on all available gpu blocks.
original_blocks = block_manager.get_num_free_gpu_blocks()
for i in range(num_gpu_blocks // block_req_per_seq_group):
_, _, seq_group = create_dummy_prompt_encoder_decoder(
f"{i}",
decoder_prompt_length=block_size,
encoder_prompt_length=block_size)
block_manager.allocate(seq_group)
assert block_manager.get_num_free_gpu_blocks() == 0
# Resetting block manager frees all allocated blocks.
block_manager.reset()
assert block_manager.get_num_free_gpu_blocks() == original_blocks
def test_sliding_window_multi_seq():
"""
Tests that memory allocation and deallocation is handled
correctly with multiple sequences that exceed the sliding
window's capacity.
"""
block_size = 1
num_cpu_blocks = 8
num_gpu_blocks = 8
sliding_window = 2
block_manager = BlockSpaceManagerV1(block_size,
num_cpu_blocks,
num_gpu_blocks,
sliding_window=sliding_window,
watermark=0)
assert block_manager.get_num_free_gpu_blocks() == num_gpu_blocks
parent = Sequence(seq_id=1,
inputs={
"prompt": "one two three",
"prompt_token_ids": [0, 1, 2],
},
block_size=block_size)
seq_group = SequenceGroup(request_id="1",
seqs=[parent],
arrival_time=time.time(),
sampling_params=SamplingParams(),
lora_request=None)
block_manager.allocate(seq_group)
# assert the number of blocks allocated is correct
# the parent seq has len 3, but since sliding_window is 2,
# we will use at most 2 blocks
assert block_manager.get_num_free_gpu_blocks(
) == num_gpu_blocks - sliding_window
# Fork prompt and copy block tables.
child = parent.fork(2)
block_manager.fork(parent, child)
# assert the number of blocks allocated is correct
# forking does not increase memory consumption
assert block_manager.get_num_free_gpu_blocks(
) == num_gpu_blocks - sliding_window
# assert both parent and child share all blocks
assert block_manager.get_block_table(
parent) == block_manager.get_block_table(child)
token_id = 4
# Append token to child. Block is shared so copy on write occurs.
child.append_token_id(token_id, {token_id: Logprob(0.0)})
block_manager.append_slots(child)
# assert the number of blocks allocated is correct
# we will use now one block more. Each seq will use 2 blocks,
# but only one can be shared
assert block_manager.get_num_free_gpu_blocks(
) == num_gpu_blocks - sliding_window - 1
token_id = 5
parent.append_token_id(token_id, {token_id: Logprob(0.0)})
block_manager.append_slots(parent)
# assert the number of blocks allocated is correct
# no change, because both sequences are still just sharing one block
assert block_manager.get_num_free_gpu_blocks(
) == num_gpu_blocks - sliding_window - 1
block_table_parent = block_manager.get_block_table(parent)
block_table_child = block_manager.get_block_table(child)
assert block_table_parent != block_table_child
# assert both blocks are sharing the second-last block
assert block_table_parent[-2] == block_table_child[-2]
# now let's clean up...
block_manager.free(parent)
# assert the number of blocks allocated is correct
# We have freed one seq, reducing the ref count of two blocks by one.
# One of the two was only used by the parent seq, so this is now free.
# The child seq still consumes sliding_window blocks
assert block_manager.get_num_free_gpu_blocks(
) == num_gpu_blocks - sliding_window
# free all blocks
block_manager.free(child)
# assert all blocks are free now
assert block_manager.get_num_free_gpu_blocks() == num_gpu_blocks
def test_mark_blocks_as_computed_with_prefix_cache_and_chunked_prefill():
"""When prefix cache and chunked prefill are enabled, the block manager
should only mark a chunk of blocks as computed instead of all blocks.
"""
block_size = 4
num_cpu_blocks = 0
num_gpu_blocks = 16
block_manager = BlockSpaceManagerV1(block_size,
num_gpu_blocks,
num_cpu_blocks,
watermark=0,
enable_caching=True)
# Set prompt size to have num_gpu_blocks - 1 full blocks.
prompt_length = block_size * num_gpu_blocks - 1
# Allocate (reserve) all blocks.
_, seq_group = create_dummy_prompt("0",
prompt_length,
block_size=block_size)
block_manager.allocate(seq_group)
assert seq_group.seqs[0].n_blocks == num_gpu_blocks
# 1st chunk: Compute 2 and half blocks. Should mark 2 blocks as computed.
token_chunk_size = int(block_size * 2.5)
block_manager.mark_blocks_as_computed(seq_group, token_chunk_size)
computed_blocks = block_manager.get_all_computed_blocks(seq_group.seqs[0])
assert len(computed_blocks) == 2
# Actual computed tokens.
seq_group.seqs[0].data.update_num_computed_tokens(token_chunk_size)
# 2nd chunk: Complete 3rd block and additional 4 blocks.
token_chunk_size = int(block_size * 4.5)
block_manager.mark_blocks_as_computed(seq_group, token_chunk_size)
computed_blocks = block_manager.get_all_computed_blocks(seq_group.seqs[0])
assert len(computed_blocks) == 7

View File

@ -8,7 +8,6 @@ from vllm.core.interfaces import AllocStatus
from vllm.core.scheduler import Scheduler
from vllm.sequence import Logprob, SequenceGroup
from ..utils import check_deprecated_block_manager_usage
from .utils import create_dummy_prompt
@ -28,25 +27,16 @@ def schedule_and_update_computed_tokens(scheduler):
return metas, out
@pytest.fixture(scope="module", autouse=True)
def check_deprecated_block_manager():
check_deprecated_block_manager_usage(
'tests/core/test_chunked_prefill_scheduler.py')
@pytest.mark.parametrize('use_v2_block_manager', [True, False])
def test_simple(use_v2_block_manager: bool):
def test_simple():
"""Verify basic scheduling works."""
block_size = 4
num_seq_group = 4
max_model_len = 16
max_num_batched_tokens = 64
scheduler_config = SchedulerConfig(
max_num_batched_tokens,
num_seq_group,
max_model_len,
enable_chunked_prefill=True,
use_v2_block_manager=use_v2_block_manager)
scheduler_config = SchedulerConfig(max_num_batched_tokens,
num_seq_group,
max_model_len,
enable_chunked_prefill=True)
cache_config = CacheConfig(block_size, 1.0, 1, "auto")
cache_config.num_cpu_blocks = 8
cache_config.num_gpu_blocks = 8
@ -81,8 +71,7 @@ def test_simple(use_v2_block_manager: bool):
assert len(seq_group_meta) == num_seq_group
@pytest.mark.parametrize('use_v2_block_manager', [True, False])
def test_chunk(use_v2_block_manager: bool):
def test_chunk():
"""Verify prefills are chunked properly."""
block_size = 4
max_seqs = 60
@ -93,7 +82,7 @@ def test_chunk(use_v2_block_manager: bool):
max_seqs,
max_model_len,
enable_chunked_prefill=True,
use_v2_block_manager=use_v2_block_manager)
)
cache_config = CacheConfig(block_size, 1.0, 1, "auto")
cache_config.num_cpu_blocks = 32
cache_config.num_gpu_blocks = 32
@ -131,8 +120,7 @@ def test_chunk(use_v2_block_manager: bool):
assert out.num_batched_tokens == 57
@pytest.mark.parametrize('use_v2_block_manager', [True, False])
def test_complex(use_v2_block_manager: bool):
def test_complex():
block_size = 4
max_seqs = 60
max_model_len = 80
@ -142,7 +130,7 @@ def test_complex(use_v2_block_manager: bool):
max_seqs,
max_model_len,
enable_chunked_prefill=True,
use_v2_block_manager=use_v2_block_manager)
)
cache_config = CacheConfig(block_size, 1.0, 1, "auto")
cache_config.num_cpu_blocks = 64
cache_config.num_gpu_blocks = 64
@ -201,8 +189,7 @@ def test_complex(use_v2_block_manager: bool):
assert running[2].is_prefill()
@pytest.mark.parametrize('use_v2_block_manager', [True, False])
def test_maximal_decoding(use_v2_block_manager: bool):
def test_maximal_decoding():
"""Verify decoding requests are prioritized."""
block_size = 4
max_seqs = 2
@ -213,7 +200,7 @@ def test_maximal_decoding(use_v2_block_manager: bool):
max_seqs,
max_model_len,
enable_chunked_prefill=True,
use_v2_block_manager=use_v2_block_manager)
)
cache_config = CacheConfig(block_size, 1.0, 1, "auto")
cache_config.num_cpu_blocks = 8
cache_config.num_gpu_blocks = 8
@ -295,8 +282,7 @@ def test_maximal_decoding(use_v2_block_manager: bool):
assert out.num_batched_tokens == 2
@pytest.mark.parametrize('use_v2_block_manager', [True, False])
def test_prompt_limit(use_v2_block_manager: bool):
def test_prompt_limit():
"""Verify max_num_batched_tokens < max_model_len is possible."""
block_size = 4
max_seqs = 32
@ -307,7 +293,7 @@ def test_prompt_limit(use_v2_block_manager: bool):
max_seqs,
max_model_len,
enable_chunked_prefill=True,
use_v2_block_manager=use_v2_block_manager)
)
cache_config = CacheConfig(block_size, 1.0, 1, "auto")
cache_config.num_cpu_blocks = 16
cache_config.num_gpu_blocks = 16
@ -330,8 +316,7 @@ def test_prompt_limit(use_v2_block_manager: bool):
assert out.num_batched_tokens == 32
@pytest.mark.parametrize('use_v2_block_manager', [True, False])
def test_prompt_limit_exceed(use_v2_block_manager: bool):
def test_prompt_limit_exceed():
block_size = 4
max_seqs = 64
max_model_len = 32
@ -356,8 +341,7 @@ def test_prompt_limit_exceed(use_v2_block_manager: bool):
assert out.ignored_seq_groups[0] == seq_group
@pytest.mark.parametrize('use_v2_block_manager', [True, False])
def test_swap(use_v2_block_manager: bool):
def test_swap():
"""Verify swapping works with chunked prefill requests"""
block_size = 4
max_seqs = 30
@ -368,7 +352,7 @@ def test_swap(use_v2_block_manager: bool):
max_seqs,
max_model_len,
enable_chunked_prefill=True,
use_v2_block_manager=use_v2_block_manager)
)
cache_config = CacheConfig(block_size, 1.0, 1, "auto")
cache_config.num_cpu_blocks = 16
cache_config.num_gpu_blocks = 16
@ -414,8 +398,7 @@ def test_swap(use_v2_block_manager: bool):
assert out.blocks_to_swap_out == []
@pytest.mark.parametrize('use_v2_block_manager', [True, False])
def test_running_prefill_prioritized_over_swap(use_v2_block_manager: bool):
def test_running_prefill_prioritized_over_swap():
block_size = 4
max_seqs = 30
max_model_len = 200
@ -425,7 +408,7 @@ def test_running_prefill_prioritized_over_swap(use_v2_block_manager: bool):
max_seqs,
max_model_len,
enable_chunked_prefill=True,
use_v2_block_manager=use_v2_block_manager)
)
cache_config = CacheConfig(block_size, 1.0, 1, "auto")
cache_config.num_cpu_blocks = 32
cache_config.num_gpu_blocks = 32
@ -508,8 +491,7 @@ def test_running_prefill_prioritized_over_swap(use_v2_block_manager: bool):
assert out.blocks_to_swap_out == []
@pytest.mark.parametrize('use_v2_block_manager', [True, False])
def test_chunked_prefill_preempt(use_v2_block_manager: bool):
def test_chunked_prefill_preempt():
"""Verify preempt works with chunked prefill requests"""
block_size = 4
max_seqs = 30
@ -520,7 +502,7 @@ def test_chunked_prefill_preempt(use_v2_block_manager: bool):
max_seqs,
max_model_len,
enable_chunked_prefill=True,
use_v2_block_manager=use_v2_block_manager)
)
cache_config = CacheConfig(block_size, 1.0, 1, "auto")
cache_config.num_cpu_blocks = 16
cache_config.num_gpu_blocks = 16
@ -575,8 +557,7 @@ def test_chunked_prefill_preempt(use_v2_block_manager: bool):
assert out.num_batched_tokens == max_num_batched_tokens
@pytest.mark.parametrize('use_v2_block_manager', [True, False])
def test_chunked_prefill_max_seqs(use_v2_block_manager: bool):
def test_chunked_prefill_max_seqs():
block_size = 4
max_seqs = 2
max_model_len = 80
@ -586,7 +567,7 @@ def test_chunked_prefill_max_seqs(use_v2_block_manager: bool):
max_seqs,
max_model_len,
enable_chunked_prefill=True,
use_v2_block_manager=use_v2_block_manager)
)
cache_config = CacheConfig(block_size, 1.0, 1, "auto")
cache_config.num_cpu_blocks = 128
cache_config.num_gpu_blocks = 128
@ -629,8 +610,7 @@ def test_chunked_prefill_max_seqs(use_v2_block_manager: bool):
assert not running[1].is_prefill()
@pytest.mark.parametrize('use_v2_block_manager', [True, False])
def test_perfix_caching(use_v2_block_manager: bool):
def test_perfix_caching():
"""Verify allocating full blocks when prefix caching is enabled."""
block_size = 4
max_seqs = 10
@ -641,7 +621,7 @@ def test_perfix_caching(use_v2_block_manager: bool):
max_seqs,
max_model_len,
enable_chunked_prefill=True,
use_v2_block_manager=use_v2_block_manager)
)
cache_config = CacheConfig(block_size,
1.0,
1,
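The chunked-prefill scheduler tests now build their configs without a block-manager argument. A minimal sketch of the pattern used throughout this file; block counts and sizes are illustrative:

```python
from vllm.config import CacheConfig, SchedulerConfig
from vllm.core.scheduler import Scheduler

# SchedulerConfig no longer takes use_v2_block_manager; the positional
# arguments are (max_num_batched_tokens, max_num_seqs, max_model_len).
scheduler_config = SchedulerConfig(64, 4, 16, enable_chunked_prefill=True)

# CacheConfig(block_size, gpu_memory_utilization, swap_space_gb, cache_dtype)
cache_config = CacheConfig(4, 1.0, 1, "auto")
cache_config.num_cpu_blocks = 8
cache_config.num_gpu_blocks = 8

scheduler = Scheduler(scheduler_config, cache_config, None)  # None: no LoRA config
```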

View File

@ -31,7 +31,6 @@ def test_num_computed_tokens_update(num_scheduler_steps: int,
# Make a vllm engine
runner = VllmRunner(model_name=MODEL,
gpu_memory_utilization=0.7,
use_v2_block_manager=True,
num_scheduler_steps=num_scheduler_steps,
enable_chunked_prefill=enable_chunked_prefill,
enforce_eager=enforce_eager)

View File

@ -3,7 +3,7 @@ from collections import deque
from typing import List, Set, Tuple
from unittest.mock import MagicMock
import pytest
import pytest # noqa
from torch import Use # noqa
from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig
@ -12,23 +12,18 @@ from vllm.core.scheduler import Scheduler, SchedulingBudget
from vllm.lora.request import LoRARequest
from vllm.sequence import SequenceGroup, SequenceStatus
from ..utils import check_deprecated_block_manager_usage
from .utils import (append_new_token, append_new_token_seq_group,
create_dummy_prompt, get_sequence_groups,
schedule_and_update_computed_tokens)
@pytest.fixture(scope="module", autouse=True)
def check_deprecated_block_manager():
check_deprecated_block_manager_usage(
"tests/core/test_chunked_prefill_scheduler.py")
@pytest.mark.parametrize('use_v2_block_manager', [True, False])
def test_scheduler_add_seq_group(use_v2_block_manager: bool):
def test_scheduler_add_seq_group():
block_size = 4
scheduler_config = SchedulerConfig(
100, 64, 1, use_v2_block_manager=use_v2_block_manager)
100,
64,
1,
)
cache_config = CacheConfig(block_size, 1.0, 1, cache_dtype="auto")
cache_config.num_cpu_blocks = 4
cache_config.num_gpu_blocks = 4
@ -44,11 +39,13 @@ def test_scheduler_add_seq_group(use_v2_block_manager: bool):
assert scheduler.get_num_unfinished_seq_groups() == i + 1
@pytest.mark.parametrize('use_v2_block_manager', [True, False])
def test_scheduler_abort_seq_group(use_v2_block_manager: bool):
def test_scheduler_abort_seq_group():
block_size = 4
scheduler_config = SchedulerConfig(
100, 64, 1, use_v2_block_manager=use_v2_block_manager)
100,
64,
1,
)
cache_config = CacheConfig(block_size, 1.0, 1, "auto")
cache_config.num_cpu_blocks = 4
cache_config.num_gpu_blocks = 4
@ -68,8 +65,7 @@ def test_scheduler_abort_seq_group(use_v2_block_manager: bool):
assert scheduler.get_num_unfinished_seq_groups() == 0
@pytest.mark.parametrize('use_v2_block_manager', [True, False])
def test_scheduler_schedule_simple(use_v2_block_manager: bool):
def test_scheduler_schedule_simple():
block_size = 4
num_seq_group = 4
max_model_len = 16
@ -77,7 +73,7 @@ def test_scheduler_schedule_simple(use_v2_block_manager: bool):
64,
num_seq_group,
max_model_len,
use_v2_block_manager=use_v2_block_manager)
)
cache_config = CacheConfig(block_size, 1.0, 1, "auto")
cache_config.num_cpu_blocks = 8
cache_config.num_gpu_blocks = 8
@ -112,8 +108,7 @@ def test_scheduler_schedule_simple(use_v2_block_manager: bool):
append_new_token(out, 1)
@pytest.mark.parametrize('use_v2_block_manager', [True, False])
def test_scheduler_prefill_prioritized(use_v2_block_manager: bool):
def test_scheduler_prefill_prioritized():
"""Verify running batched tokens are not applied to prefill requests."""
block_size = 4
max_model_len = 30
@ -122,7 +117,7 @@ def test_scheduler_prefill_prioritized(use_v2_block_manager: bool):
max_batched_num_tokens,
2,
max_model_len,
use_v2_block_manager=use_v2_block_manager)
)
cache_config = CacheConfig(block_size, 1.0, 1, "auto")
cache_config.num_cpu_blocks = 16
cache_config.num_gpu_blocks = 16
@ -146,12 +141,14 @@ def test_scheduler_prefill_prioritized(use_v2_block_manager: bool):
assert get_sequence_groups(out) == [seq_group_b]
@pytest.mark.parametrize('use_v2_block_manager', [True, False])
def test_scheduler_schedule_preempt_abort(use_v2_block_manager: bool):
def test_scheduler_schedule_preempt_abort():
block_size = 4
max_model_len = 16
scheduler_config = SchedulerConfig(
64, 2, max_model_len, use_v2_block_manager=use_v2_block_manager)
64,
2,
max_model_len,
)
cache_config = CacheConfig(block_size, 1.0, 1, "auto")
cache_config.num_cpu_blocks = 2
cache_config.num_gpu_blocks = 2
@ -201,8 +198,7 @@ def test_scheduler_schedule_preempt_abort(use_v2_block_manager: bool):
assert scheduler.get_num_unfinished_seq_groups() == 1
@pytest.mark.parametrize('use_v2_block_manager', [True, False])
def test_scheduler_max_seqs(use_v2_block_manager: bool):
def test_scheduler_max_seqs():
block_size = 4
num_seq_group = 4
max_seq_group = 2
@ -211,7 +207,7 @@ def test_scheduler_max_seqs(use_v2_block_manager: bool):
64,
max_seq_group,
max_model_len,
use_v2_block_manager=use_v2_block_manager)
)
cache_config = CacheConfig(block_size, 1.0, 1, "auto")
cache_config.num_cpu_blocks = 8
cache_config.num_gpu_blocks = 8
@ -249,15 +245,14 @@ def test_scheduler_max_seqs(use_v2_block_manager: bool):
assert set(get_sequence_groups(out)) == set([all_seq_groups[1]])
@pytest.mark.parametrize('use_v2_block_manager', [True, False])
def test_scheduler_delay_factor(use_v2_block_manager: bool):
def test_scheduler_delay_factor():
block_size = 4
scheduler_config = SchedulerConfig(
100,
64,
16,
delay_factor=0.5,
use_v2_block_manager=use_v2_block_manager)
)
cache_config = CacheConfig(block_size, 1.0, 1, "auto")
cache_config.num_cpu_blocks = 8
cache_config.num_gpu_blocks = 8
@ -294,12 +289,10 @@ def test_scheduler_delay_factor(use_v2_block_manager: bool):
append_new_token(out, 1)
@pytest.mark.parametrize('use_v2_block_manager', [True, False])
def test_swapped_out_prioritized(use_v2_block_manager: bool):
def test_swapped_out_prioritized():
block_size = 4
scheduler = initialize_scheduler(max_num_seqs=6,
block_size=block_size,
use_v2_block_manager=use_v2_block_manager,
num_cpu_blocks=64,
num_gpu_blocks=64)
# best_of=2 * 3 == 6 sequences.
@ -351,7 +344,6 @@ def initialize_scheduler(
max_token_budget=1000,
max_model_len=1000,
lora_config=None,
use_v2_block_manager=False,
block_size=4,
num_cpu_blocks=8,
num_gpu_blocks=8,
@ -361,7 +353,7 @@ def initialize_scheduler(
max_token_budget,
max_num_seqs,
max_model_len,
use_v2_block_manager=use_v2_block_manager)
)
cache_config = CacheConfig(block_size, 1.0, 1, "auto")
cache_config.num_cpu_blocks = num_cpu_blocks
cache_config.num_gpu_blocks = num_gpu_blocks
@ -386,15 +378,12 @@ def add_token_budget(budget: SchedulingBudget,
budget.add_num_seqs(mock_seq_group.request_id, num_curr_seqs)
@pytest.mark.parametrize('use_v2_block_manager', [True, False])
def test_prefill_schedule_max_prompt_len(use_v2_block_manager: bool):
def test_prefill_schedule_max_prompt_len():
"""
Test that a prompt longer than max_prompt_len is aborted.
"""
block_size = 4
scheduler = initialize_scheduler(max_model_len=30,
use_v2_block_manager=use_v2_block_manager,
block_size=block_size)
scheduler = initialize_scheduler(max_model_len=30, block_size=block_size)
_, seq_group = create_dummy_prompt("0",
prompt_length=60,
block_size=block_size)
@ -409,14 +398,12 @@ def test_prefill_schedule_max_prompt_len(use_v2_block_manager: bool):
assert len(remaining_waiting) == 0
@pytest.mark.parametrize('use_v2_block_manager', [True, False])
def test_prefill_schedule_token_budget(use_v2_block_manager: bool):
def test_prefill_schedule_token_budget():
"""
Test that the token budget is respected.
"""
block_size = 4
scheduler = initialize_scheduler(use_v2_block_manager=use_v2_block_manager,
block_size=block_size,
scheduler = initialize_scheduler(block_size=block_size,
num_cpu_blocks=64,
num_gpu_blocks=64)
budget = create_token_budget(token_budget=0)
@ -446,8 +433,7 @@ def test_prefill_schedule_token_budget(use_v2_block_manager: bool):
assert len(remaining_waiting) == 1
# Test when current_batched_tokens respected.
scheduler = initialize_scheduler(use_v2_block_manager=use_v2_block_manager,
block_size=block_size,
scheduler = initialize_scheduler(block_size=block_size,
num_cpu_blocks=16,
num_gpu_blocks=16)
budget = create_token_budget(token_budget=60)
@ -474,14 +460,12 @@ def test_prefill_schedule_token_budget(use_v2_block_manager: bool):
assert len(remaining_waiting) == 0
@pytest.mark.parametrize('use_v2_block_manager', [True, False])
def test_prefill_schedule_max_seqs(use_v2_block_manager: bool):
def test_prefill_schedule_max_seqs():
"""
Test that the maximum number of sequences is respected.
"""
block_size = 4
scheduler = initialize_scheduler(use_v2_block_manager=use_v2_block_manager,
block_size=block_size,
scheduler = initialize_scheduler(block_size=block_size,
num_cpu_blocks=64,
num_gpu_blocks=64)
budget = create_token_budget(max_num_seqs=2)
@ -515,15 +499,13 @@ def test_prefill_schedule_max_seqs(use_v2_block_manager: bool):
assert len(remaining_waiting) == 1
@pytest.mark.parametrize('use_v2_block_manager', [True, False])
def test_prefill_schedule_max_lora(use_v2_block_manager: bool):
def test_prefill_schedule_max_lora():
"""
Test max lora is respected and prioritized.
"""
block_size = 4
lora_config = LoRAConfig(max_lora_rank=8, max_loras=1)
scheduler = initialize_scheduler(lora_config=lora_config,
use_v2_block_manager=use_v2_block_manager,
block_size=block_size,
num_cpu_blocks=64,
num_gpu_blocks=64)
@ -570,14 +552,12 @@ def test_prefill_schedule_max_lora(use_v2_block_manager: bool):
assert budget.num_batched_tokens == 60
@pytest.mark.parametrize('use_v2_block_manager', [True, False])
def test_prefill_schedule_no_block_manager_capacity(use_v2_block_manager):
def test_prefill_schedule_no_block_manager_capacity():
"""
Test that a sequence cannot be scheduled when the block manager has no capacity.
"""
block_size = 4
scheduler = initialize_scheduler(use_v2_block_manager=use_v2_block_manager,
block_size=block_size,
scheduler = initialize_scheduler(block_size=block_size,
num_gpu_blocks=128,
num_cpu_blocks=128)
budget = create_token_budget()
@ -614,14 +594,12 @@ def test_prefill_schedule_no_block_manager_capacity(use_v2_block_manager):
assert len(remaining_waiting) == 0
@pytest.mark.parametrize('use_v2_block_manager', [True, False])
def test_decode_schedule_preempted(use_v2_block_manager: bool):
def test_decode_schedule_preempted():
"""
Test that decodes which cannot be scheduled are preempted.
"""
block_size = 4
scheduler = initialize_scheduler(use_v2_block_manager=use_v2_block_manager,
block_size=block_size,
scheduler = initialize_scheduler(block_size=block_size,
num_cpu_blocks=64,
num_gpu_blocks=64)
curr_loras = None
@ -660,14 +638,12 @@ def test_decode_schedule_preempted(use_v2_block_manager: bool):
assert output.blocks_to_copy == []
@pytest.mark.parametrize('use_v2_block_manager', [True, False])
def test_decode_swap_beam_search(use_v2_block_manager: bool):
def test_decode_swap_beam_search():
"""
Test that best_of > 1 swaps out blocks.
"""
block_size = 4
scheduler = initialize_scheduler(use_v2_block_manager=use_v2_block_manager,
block_size=block_size,
scheduler = initialize_scheduler(block_size=block_size,
num_gpu_blocks=64,
num_cpu_blocks=64)
curr_loras = None
@ -716,14 +692,12 @@ def test_decode_swap_beam_search(use_v2_block_manager: bool):
assert output.blocks_to_copy == []
@pytest.mark.parametrize('use_v2_block_manager', [True, False])
def test_schedule_decode_blocks_to_copy_update(use_v2_block_manager: bool):
def test_schedule_decode_blocks_to_copy_update():
"""
Verify blocks_to_copy is updated.
"""
block_size = 4
scheduler = initialize_scheduler(use_v2_block_manager=use_v2_block_manager,
block_size=4,
scheduler = initialize_scheduler(block_size=4,
num_cpu_blocks=16,
num_gpu_blocks=16)
_, seq_group = create_dummy_prompt("1",
@ -754,11 +728,9 @@ def test_schedule_decode_blocks_to_copy_update(use_v2_block_manager: bool):
assert output.blocks_to_copy == [(2, 3)]
@pytest.mark.parametrize('use_v2_block_manager', [True, False])
def test_schedule_swapped_simple(use_v2_block_manager: bool):
def test_schedule_swapped_simple():
block_size = 4
scheduler = initialize_scheduler(use_v2_block_manager=use_v2_block_manager,
block_size=block_size)
scheduler = initialize_scheduler(block_size=block_size)
curr_loras = None
blocks_to_swap_out: List[Tuple[int, int]] = []
_, seq_group = create_dummy_prompt("1",
@ -785,11 +757,9 @@ def test_schedule_swapped_simple(use_v2_block_manager: bool):
assert blocks_to_swap_out == blocks_to_swap_in_reverse
@pytest.mark.parametrize('use_v2_block_manager', [True, False])
def test_schedule_swapped_max_token_budget(use_v2_block_manager: bool):
def test_schedule_swapped_max_token_budget():
block_size = 4
scheduler = initialize_scheduler(use_v2_block_manager=use_v2_block_manager,
block_size=block_size,
scheduler = initialize_scheduler(block_size=block_size,
num_cpu_blocks=32,
num_gpu_blocks=32)
curr_loras = None
@ -822,11 +792,9 @@ def test_schedule_swapped_max_token_budget(use_v2_block_manager: bool):
assert len(output.prefill_seq_groups) == 0
@pytest.mark.parametrize('use_v2_block_manager', [True, False])
def test_schedule_swapped_max_seqs(use_v2_block_manager: bool):
def test_schedule_swapped_max_seqs():
block_size = 4
scheduler = initialize_scheduler(use_v2_block_manager=use_v2_block_manager,
block_size=block_size,
scheduler = initialize_scheduler(block_size=block_size,
num_cpu_blocks=64,
num_gpu_blocks=64)
curr_loras = None
@ -859,12 +827,10 @@ def test_schedule_swapped_max_seqs(use_v2_block_manager: bool):
assert len(output.prefill_seq_groups) == 0
@pytest.mark.parametrize('use_v2_block_manager', [True, False])
def test_schedule_swapped_max_loras(use_v2_block_manager: bool):
def test_schedule_swapped_max_loras():
block_size = 4
lora_config = LoRAConfig(max_lora_rank=8, max_loras=1)
scheduler = initialize_scheduler(lora_config=lora_config,
use_v2_block_manager=use_v2_block_manager,
block_size=block_size,
num_cpu_blocks=32,
num_gpu_blocks=32)
@ -894,11 +860,9 @@ def test_schedule_swapped_max_loras(use_v2_block_manager: bool):
assert len(curr_loras) == 1
@pytest.mark.parametrize('use_v2_block_manager', [True, False])
def test_schedule_swapped_cannot_swap_in(use_v2_block_manager: bool):
def test_schedule_swapped_cannot_swap_in():
block_size = 4
scheduler = initialize_scheduler(use_v2_block_manager=use_v2_block_manager,
block_size=block_size,
scheduler = initialize_scheduler(block_size=block_size,
num_cpu_blocks=32,
num_gpu_blocks=32)
curr_loras = None
@ -927,11 +891,9 @@ def test_schedule_swapped_cannot_swap_in(use_v2_block_manager: bool):
assert len(output.prefill_seq_groups) == 0
@pytest.mark.parametrize('use_v2_block_manager', [True, False])
def test_infeasible_swap(use_v2_block_manager: bool):
def test_infeasible_swap():
block_size = 4
scheduler = initialize_scheduler(use_v2_block_manager=use_v2_block_manager,
block_size=block_size,
scheduler = initialize_scheduler(block_size=block_size,
num_cpu_blocks=32,
num_gpu_blocks=32)
curr_loras = None
@ -961,11 +923,9 @@ def test_infeasible_swap(use_v2_block_manager: bool):
assert len(output.prefill_seq_groups) == 0
@pytest.mark.parametrize('use_v2_block_manager', [True, False])
def test_schedule_swapped_blocks_to_copy(use_v2_block_manager: bool):
def test_schedule_swapped_blocks_to_copy():
block_size = 4
scheduler = initialize_scheduler(use_v2_block_manager=use_v2_block_manager,
block_size=block_size,
scheduler = initialize_scheduler(block_size=block_size,
num_cpu_blocks=32,
num_gpu_blocks=32)
curr_loras = None

View File

@ -185,13 +185,14 @@ def test_metric_spec_decode(
) -> None:
k = 5
with vllm_runner(model,
dtype=dtype,
disable_log_stats=False,
gpu_memory_utilization=0.4,
speculative_model=model,
num_speculative_tokens=k,
use_v2_block_manager=True) as vllm_model:
with vllm_runner(
model,
dtype=dtype,
disable_log_stats=False,
gpu_memory_utilization=0.4,
speculative_model=model,
num_speculative_tokens=k,
) as vllm_model:
# Force log interval to be 0 to catch all metrics.
stat_logger = vllm_model.model.llm_engine.stat_loggers['prometheus']
@ -242,7 +243,6 @@ def test_metric_spec_decode_interval(
gpu_memory_utilization=0.4,
speculative_model=model,
num_speculative_tokens=k,
use_v2_block_manager=True,
enforce_eager=True)
engine = LLMEngine.from_engine_args(engine_args)

View File

@ -17,7 +17,6 @@ NUM_PROMPTS = [10]
DEFAULT_SERVER_ARGS: List[str] = [
"--disable-log-requests",
"--use-v2-block-manager",
"--worker-use-ray",
"--gpu-memory-utilization",
"0.85",

View File

@ -76,7 +76,6 @@ def test_multi_step_llm(
enforce_eager=enforce_eager,
gpu_memory_utilization=0.7,
tensor_parallel_size=tp_size,
use_v2_block_manager=True,
enable_chunked_prefill=enable_chunked_prefill,
num_scheduler_steps=num_scheduler_steps,
) as vllm_model:
@ -169,7 +168,6 @@ def test_multi_step_llm_w_prompt_logprobs(
enforce_eager=enforce_eager,
gpu_memory_utilization=0.7,
tensor_parallel_size=tp_size,
use_v2_block_manager=True,
num_scheduler_steps=num_scheduler_steps,
) as vllm_model:
vllm_outputs = vllm_model.generate_greedy_logprobs(
@ -305,7 +303,6 @@ def test_multi_step_llm_chunked_prefill_prefix_cache(
enforce_eager=enforce_eager,
gpu_memory_utilization=0.7,
tensor_parallel_size=tp_size,
use_v2_block_manager=True,
num_scheduler_steps=num_scheduler_steps,
max_model_len=48,
max_num_batched_tokens=48,
@ -324,7 +321,6 @@ def test_multi_step_llm_chunked_prefill_prefix_cache(
enforce_eager=enforce_eager,
gpu_memory_utilization=0.7,
tensor_parallel_size=tp_size,
use_v2_block_manager=True,
enable_chunked_prefill=True,
enable_prefix_caching=True,
num_scheduler_steps=num_scheduler_steps,

View File

@ -2,15 +2,9 @@
Run `pytest tests/prefix_caching/test_prefix_caching.py`.
"""
from typing import List
import pytest
from tests.kernels.utils import override_backend_env_variable
from tests.utils import check_deprecated_block_manager_usage
from vllm.block import PhysicalTokenBlock
from vllm.core.block_manager_v1 import CachedBlockAllocator
from vllm.utils import Device
from ..models.utils import check_outputs_equal
@ -19,92 +13,11 @@ MODELS = [
]
@pytest.fixture(scope="module", autouse=True)
def check_deprecated_block_manager():
check_deprecated_block_manager_usage(
'tests/prefix_caching/test_prefix_caching.py')
@pytest.mark.parametrize("block_size", [16])
@pytest.mark.parametrize("num_blocks", [16])
def test_block_allocator(
block_size: int,
num_blocks: int,
):
block_hash = 1
block_allocator = CachedBlockAllocator(Device.CPU, block_size, num_blocks)
# Allocate two PhysicalTokenBlocks with the same hash and check
# that they are the same PhysicalTokenBlock
first_block = block_allocator.allocate(block_hash, 0)
second_block = block_allocator.allocate(block_hash, 0)
assert (first_block == second_block)
assert (second_block.ref_count == 2)
# Check metric: 1 hit of 2 queries
assert block_allocator.get_prefix_cache_hit_rate() == 0.5
# Free the first_block and confirm that the ref_count is correctly
# decremented on the second block
block_allocator.free(first_block)
assert (second_block.ref_count == 1)
# Free the second block
block_allocator.free(second_block)
# Reallocate the first block and confirm that, even after the block
# had its ref_count go to 0, we still get the same block back
first_block = block_allocator.allocate(block_hash, 0)
assert (first_block == second_block)
assert (first_block.block_hash == block_hash)
# Allocate one more time to get 3/4 hit rate for easy checking
block_allocator.allocate(block_hash, 0)
assert block_allocator.get_prefix_cache_hit_rate() == 0.75
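A quick tally behind the 0.75 assertion above, using the hit/miss accounting from CachedBlockAllocator.allocate (shown later in this diff): four hashed allocations are counted as queries, the first misses, and the remaining three hit, since a block revived from the evictor also registers as a hit.
# Worked arithmetic only; counts correspond to the four allocate() calls above.
queries = 4   # every allocate() with a block hash records one query
hits = 3      # second allocate, the post-free reallocate, and the final allocate
assert hits / queries == 0.75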
@pytest.mark.parametrize("num_blocks", [16])
def test_eviction(num_blocks: int, ):
block_size = 16
block_allocator = CachedBlockAllocator(Device.CPU, block_size, num_blocks)
blocks: List[PhysicalTokenBlock] = []
for i in range(num_blocks):
# use i as the block_hash
blocks.append(block_allocator.allocate(i, 0))
# Free all blocks
for block in blocks:
block_allocator.free(block)
# Allocate a new block and confirm that it's the first block freed.
# i.e. the least recently used block
new_block_hash = block_size
new_block = block_allocator.allocate(new_block_hash, 0)
assert (new_block == blocks[0])
assert (new_block.block_hash == new_block_hash)
# Reallocate the second in blocks to remove it from the free list
realloc_block_hash = 1
realloc_block = block_allocator.allocate(realloc_block_hash, 0)
assert (realloc_block == blocks[realloc_block_hash])
assert (realloc_block.block_hash == realloc_block_hash)
# Allocate a new block and confirm that it's not the realloc_block,
# since the realloc_block shouldn't be in the free list
new_block_hash = block_size + 1
new_block = block_allocator.allocate(new_block_hash, 0)
assert (realloc_block != new_block)
assert (new_block.block_hash == new_block_hash)
assert (new_block.block_number == 2)
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("backend", ["FLASH_ATTN", "FLASHINFER", "XFORMERS"])
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [5])
@pytest.mark.parametrize("cached_position", [0, 1])
@pytest.mark.parametrize("use_v2_block_manager", [False, True])
def test_mixed_requests(
hf_runner,
vllm_runner,
@ -114,7 +27,6 @@ def test_mixed_requests(
dtype: str,
max_tokens: int,
cached_position: int,
use_v2_block_manager: bool,
monkeypatch,
) -> None:
"""
@ -132,7 +44,6 @@ def test_mixed_requests(
model,
dtype=dtype,
enable_prefix_caching=True,
use_v2_block_manager=use_v2_block_manager,
) as vllm_model:
# Run the first prompt so the cache is populated
vllm_outputs = vllm_model.generate_greedy([cached_prompt], max_tokens)

View File

@ -1,27 +1,15 @@
import pytest
from tests.utils import check_deprecated_block_manager_usage
from vllm import SamplingParams
from .conftest import get_output_from_llm_generator
@pytest.fixture(scope="module", autouse=True)
def check_deprecated_block_manager():
check_deprecated_block_manager_usage(
'tests/spec_decode/e2e/test_compatibility.py')
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
"model": "JackFram/llama-68m",
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
# Required for spec decode.
"use_v2_block_manager": True
}])
@pytest.mark.parametrize("common_llm_kwargs", [{
"model": "JackFram/llama-68m",
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [
{
"enable_chunked_prefill": True,
@ -51,16 +39,11 @@ def test_spec_decode_xfail_chunked_prefill(test_llm_generator):
sampling_params)
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
"model": "meta-llama/Llama-2-7b-chat-hf",
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
# Required for spec decode.
"use_v2_block_manager": True
}])
@pytest.mark.parametrize("common_llm_kwargs", [{
"model": "meta-llama/Llama-2-7b-chat-hf",
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
}])
@pytest.mark.parametrize(
"per_test_common_llm_kwargs",
[
@ -101,34 +84,3 @@ def test_spec_decode_xfail_spec_max_model_len(test_llm_generator):
with pytest.raises(ValueError, match="cannot be larger than"):
get_output_from_llm_generator(test_llm_generator, prompts,
sampling_params)
@pytest.mark.parametrize("common_llm_kwargs", [{
"model": "JackFram/llama-68m",
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
"use_v2_block_manager": False,
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [{}])
@pytest.mark.parametrize("seed", [1])
def test_spec_decode_xfail_block_manager_v1(test_llm_generator):
"""Verify that speculative decoding with block manager v1 fails.
"""
output_len = 128
temperature = 0.0
prompts = [
"Hello, my name is",
]
sampling_params = SamplingParams(
max_tokens=output_len,
ignore_eos=True,
temperature=temperature,
)
with pytest.raises(ValueError,
match="Speculative decoding requires usage of the V2"):
get_output_from_llm_generator(test_llm_generator, prompts,
sampling_params)

View File

@ -43,9 +43,6 @@ PRECISION = "float32"
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# Required for spec decode.
"use_v2_block_manager": True,
# Print spec metrics.
"disable_log_stats": False,
@ -86,9 +83,6 @@ def test_eagle_e2e_greedy_correctness(vllm_runner, common_llm_kwargs,
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# Required for spec decode.
"use_v2_block_manager": True,
# Print spec metrics.
"disable_log_stats": False,
@ -143,9 +137,6 @@ def test_eagle_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs,
[{
"enforce_eager": False,
# Required for spec decode.
"use_v2_block_manager": True,
# Print spec metrics.
"disable_log_stats": False,
@ -191,9 +182,6 @@ def test_eagle_e2e_greedy_correctness_cuda_graph(
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# Required for spec decode.
"use_v2_block_manager": True,
# Precision
"dtype": PRECISION,
@ -235,9 +223,6 @@ def test_eagle_e2e_greedy_correctness_with_preemption(
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# Required for spec decode.
"use_v2_block_manager": True,
# Precision
"dtype": PRECISION,
@ -283,9 +268,6 @@ def test_eagle_different_k(vllm_runner, common_llm_kwargs,
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# Required for spec decode.
"use_v2_block_manager": True,
# Precision
"dtype": PRECISION,

View File

@ -12,8 +12,6 @@ MAIN_MODEL = "JackFram/llama-68m"
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
# Required for spec decode.
"use_v2_block_manager": True,
# Verify equality when cuda graphs allowed.
"enforce_eager": False,
@ -57,9 +55,6 @@ def test_spec_decode_cuda_graph(vllm_runner, common_llm_kwargs,
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# Required for spec decode.
"use_v2_block_manager": True,
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [
{
@ -111,9 +106,6 @@ def test_speculative_model_quantization_config(vllm_runner, common_llm_kwargs,
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# Required for spec decode.
"use_v2_block_manager": True,
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 3,
}])

View File

@ -17,9 +17,6 @@ from .conftest import run_equality_correctness_test_tp
[[
# Skip cuda graph recording for fast test.
"--enforce-eager",
# Required for spec decode.
"--use-v2-block-manager",
"--tensor-parallel-size",
"2"
]])
@ -74,9 +71,6 @@ def test_target_model_tp_gt_1(common_llm_kwargs, per_test_common_llm_kwargs,
[[
# Skip cuda graph recording for fast test.
"--enforce-eager",
# Required for spec decode.
"--use_v2_block_manager",
"--tensor_parallel_size",
"2",

View File

@ -19,9 +19,6 @@ SPEC_MODEL = "JackFram/llama-68m"
[[
# Skip cuda graph recording for fast test.
"--enforce_eager",
# Required for spec decode.
"--use-v2-block-manager",
"--tensor-parallel-size",
"4",
]])
@ -71,9 +68,6 @@ def test_draft_model_tp_lt_target_model_tp4(common_llm_kwargs,
# Skip cuda graph recording for fast test.
"--enforce-eager",
# Required for spec decode.
"--use-v2-block-manager",
"--tensor-parallel-size",
"4",
]])

View File

@ -14,9 +14,6 @@ from .conftest import run_equality_correctness_test
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# Required for spec decode.
"use_v2_block_manager": True,
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@ -67,9 +64,6 @@ def test_logprobs_equality(vllm_runner, common_llm_kwargs,
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# Required for spec decode.
"use_v2_block_manager": True
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@ -119,9 +113,6 @@ def test_logprobs_different_k(vllm_runner, common_llm_kwargs,
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# Required for spec decode.
"use_v2_block_manager": True
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@ -173,9 +164,6 @@ def test_logprobs_when_skip_speculation(vllm_runner, common_llm_kwargs,
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# Required for spec decode.
"use_v2_block_manager": True
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@ -251,8 +239,6 @@ def test_logprobs_temp_1(vllm_runner, common_llm_kwargs,
"model_name": "JackFram/llama-160m",
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# Required for spec decode.
"use_v2_block_manager": True,
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])

View File

@ -45,9 +45,6 @@ PRECISION = "float32"
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# Required for spec decode.
"use_v2_block_manager": True,
# Print spec metrics.
"disable_log_stats": False,
@ -93,9 +90,6 @@ def test_medusa_e2e_greedy_correctness(vllm_runner, common_llm_kwargs,
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# Required for spec decode.
"use_v2_block_manager": True,
# Print spec metrics.
"disable_log_stats": False,
@ -151,9 +145,6 @@ def test_medusa_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs,
[{
"enforce_eager": False,
# Required for spec decode.
"use_v2_block_manager": True,
# Print spec metrics.
"disable_log_stats": False,
@ -204,9 +195,6 @@ def test_medusa_e2e_greedy_correctness_cuda_graph(
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# Required for spec decode.
"use_v2_block_manager": True,
# Precision
"dtype": PRECISION,
@ -253,9 +241,6 @@ def test_medusa_e2e_greedy_correctness_with_preemption(
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# Required for spec decode.
"use_v2_block_manager": True,
# Precision
"dtype": PRECISION,
@ -306,9 +291,6 @@ def test_medusa_different_k(vllm_runner, common_llm_kwargs,
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# Required for spec decode.
"use_v2_block_manager": True,
# Precision
"dtype": PRECISION,
@ -356,9 +338,6 @@ def test_medusa_disable_queue(vllm_runner, common_llm_kwargs,
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# Required for spec decode.
"use_v2_block_manager": True,
# Precision
"dtype": PRECISION,

View File

@ -47,9 +47,6 @@ PRECISION = "float32"
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# Required for spec decode.
"use_v2_block_manager": True,
# Print spec metrics.
"disable_log_stats": False,
@ -94,9 +91,6 @@ def test_mlp_e2e_greedy_correctness(vllm_runner, common_llm_kwargs,
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# Required for spec decode.
"use_v2_block_manager": True,
# Print spec metrics.
"disable_log_stats": False,
@ -149,9 +143,6 @@ def test_mlp_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs,
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# Required for spec decode.
"use_v2_block_manager": True,
# Print spec metrics.
"disable_log_stats": False,
@ -195,9 +186,6 @@ def test_mlp_e2e_acceptance_rate(vllm_runner, common_llm_kwargs,
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# Required for spec decode.
"use_v2_block_manager": True,
# Print spec metrics.
"disable_log_stats": False,
@ -258,9 +246,6 @@ def test_mlp_e2e_seeded_correctness(vllm_runner, common_llm_kwargs,
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# Required for spec decode.
"use_v2_block_manager": True,
# Precision
"dtype": PRECISION,
@ -311,9 +296,6 @@ def test_mlp_e2e_greedy_correctness_with_preemption(
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# Required for spec decode.
"use_v2_block_manager": True,
# Precision
"dtype": PRECISION,
@ -366,9 +348,6 @@ def test_mlp_e2e_greedy_correctness_with_padding(
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# Required for spec decode.
"use_v2_block_manager": True,
# Precision
"dtype": PRECISION,
@ -419,9 +398,6 @@ def test_mlp_different_k(vllm_runner, common_llm_kwargs,
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# Required for spec decode.
"use_v2_block_manager": True,
# Precision
"dtype": PRECISION,
@ -469,9 +445,6 @@ def test_mlp_disable_queue(vllm_runner, common_llm_kwargs,
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# Required for spec decode.
"use_v2_block_manager": True,
"speculative_model": SPEC_MODEL,
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])

View File

@ -55,9 +55,6 @@ from .conftest import (get_output_from_llm_generator,
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# Required for spec decode.
"use_v2_block_manager": True,
}])
@pytest.mark.parametrize(
"per_test_common_llm_kwargs",
@ -124,9 +121,6 @@ def test_spec_decode_e2e_with_detokenization(test_llm_generator,
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# Required for spec decode.
"use_v2_block_manager": True,
# Print spec metrics.
"disable_log_stats": False,
}])
@ -190,9 +184,6 @@ def test_spec_decode_e2e_greedy_correctness_tiny_model_bs1(
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# Required for spec decode.
"use_v2_block_manager": True,
# Print spec metrics.
"disable_log_stats": False,
}])
@ -246,9 +237,6 @@ def test_spec_decode_e2e_greedy_correctness_tiny_model_large_bs(
[{
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# Required for spec decode.
"use_v2_block_manager": True
}])
@pytest.mark.parametrize(
"per_test_common_llm_kwargs",
@ -303,9 +291,6 @@ def test_spec_decode_e2e_greedy_correctness_tiny_model_large_bs_diff_output_len(
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# Required for spec decode.
"use_v2_block_manager": True,
# Print spec metrics.
"disable_log_stats": False,
}])
@ -353,9 +338,6 @@ def test_spec_decode_e2e_greedy_correctness_real_model_bs1(
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# Required for spec decode.
"use_v2_block_manager": True,
# Print spec metrics.
"disable_log_stats": False,
}])
@ -404,9 +386,6 @@ def test_spec_decode_e2e_greedy_correctness_real_model_large_bs(
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# Required for spec decode.
"use_v2_block_manager": True
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [
{
@ -454,9 +433,6 @@ def test_spec_decode_e2e_greedy_correctness_with_preemption(
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# Required for spec decode.
"use_v2_block_manager": True
}])
@pytest.mark.parametrize(
"per_test_common_llm_kwargs",
@ -514,9 +490,6 @@ def test_spec_decode_different_block_size(vllm_runner, common_llm_kwargs,
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# Required for spec decode.
"use_v2_block_manager": True
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@ -570,9 +543,6 @@ def test_skip_speculation(vllm_runner, common_llm_kwargs,
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# Required for spec decode.
"use_v2_block_manager": True
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@ -611,9 +581,6 @@ def test_disable_speculation(vllm_runner, common_llm_kwargs,
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# Required for spec decode.
"use_v2_block_manager": True
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@ -660,9 +627,6 @@ def test_many_k(vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs,
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# Required for spec decode.
"use_v2_block_manager": True
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])

View File

@ -35,9 +35,6 @@ from .conftest import run_equality_correctness_test
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# Required for spec decode.
"use_v2_block_manager": True,
# Print spec metrics.
"disable_log_stats": False,
}])
@ -82,9 +79,6 @@ def test_ngram_e2e_greedy_correctness(vllm_runner, common_llm_kwargs,
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# Required for spec decode.
"use_v2_block_manager": True,
# Print spec metrics.
"disable_log_stats": False,
}])
@ -145,9 +139,6 @@ def test_ngram_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs,
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# Required for spec decode.
"use_v2_block_manager": True
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [
{
@ -195,9 +186,6 @@ def test_ngram_e2e_greedy_correctness_with_preemption(
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# Required for spec decode.
"use_v2_block_manager": True
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@ -254,9 +242,6 @@ def test_ngram_different_k(vllm_runner, common_llm_kwargs,
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# Required for spec decode.
"use_v2_block_manager": True
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@ -303,7 +288,6 @@ def test_ngram_disable_queue(vllm_runner, common_llm_kwargs,
"enforce_eager": True,
# Required for spec decode.
"use_v2_block_manager": True,
"speculative_model": "[ngram]",
"num_speculative_tokens": 5,
"ngram_prompt_lookup_max": 3,

View File

@ -17,9 +17,6 @@ SPEC_MODEL = "JackFram/llama-160m"
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# Required for spec decode.
"use_v2_block_manager": True,
# speculative model
"speculative_model": "JackFram/llama-160m",

View File

@ -678,12 +678,3 @@ def get_client_text_logprob_generations(
return [(text_generations, text,
(None if x.logprobs is None else x.logprobs.top_logprobs))
for completion in completions for x in completion.choices]
def check_deprecated_block_manager_usage(test_name: str):
assert envs.VLLM_ALLOW_DEPRECATED_BLOCK_MANAGER_V1 is True, (
f"To allow the use of deprecated BlockSpaceManagerV1, set the "
f"environment variable VLLM_ALLOW_DEPRECATED_BLOCK_MANAGER_V1=1. "
f"You can run the tests with: "
f"`VLLM_ALLOW_DEPRECATED_BLOCK_MANAGER_V1=1 pytest {test_name}`" #noqa
)

View File

@ -305,8 +305,6 @@ class FlashAttentionMetadataBuilder(
self.runner = input_builder.runner
self.sliding_window = input_builder.sliding_window
self.block_size = input_builder.block_size
self.use_v2_block_manager = (
input_builder.scheduler_config.use_v2_block_manager)
def _add_seq_group(
self, inter_data: "ModelInputForGPUBuilder.InterDataForSeqGroup",
@ -355,9 +353,9 @@ class FlashAttentionMetadataBuilder(
# Compute slot mapping.
is_profile_run = is_block_tables_empty(block_tables)
start_idx = compute_slot_mapping_start_idx(
is_prompt, query_len, context_len, self.sliding_window,
self.use_v2_block_manager)
start_idx = compute_slot_mapping_start_idx(is_prompt, query_len,
context_len,
self.sliding_window)
compute_slot_mapping(is_profile_run, self.slot_mapping, seq_id,
seq_len, context_len, start_idx,
self.block_size, inter_data.block_tables)

View File

@ -475,8 +475,6 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
self.sliding_window = input_builder.sliding_window
self.block_size = input_builder.block_size
self.use_v2_block_manager = (
input_builder.scheduler_config.use_v2_block_manager)
# Please follow https://docs.flashinfer.ai/tutorials/kv_layout.html#page-layout
# for the precise definition of the following fields.
@ -542,9 +540,9 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
is_profile_run = is_block_tables_empty(block_tables)
# Compute slot mapping.
start_idx = compute_slot_mapping_start_idx(
is_prompt, query_len, context_len, self.sliding_window,
self.use_v2_block_manager)
start_idx = compute_slot_mapping_start_idx(is_prompt, query_len,
context_len,
self.sliding_window)
compute_slot_mapping(is_profile_run, self.slot_mapping, seq_id,
seq_len, context_len, start_idx,
self.block_size, inter_data.block_tables)

View File

@ -38,18 +38,12 @@ def is_block_tables_empty(block_tables: Union[None, Dict]):
def compute_slot_mapping_start_idx(is_prompt: bool, query_len: int,
context_len: int, sliding_window: int,
use_v2_block_manager: bool):
context_len: int, sliding_window: int):
"""
Compute the start index of slot mapping.
"""
start_idx = 0
if is_prompt and sliding_window is not None:
assert use_v2_block_manager or context_len == 0, (
"Prefix caching is currently not supported with "
"sliding window attention in V1 block manager")
# When prefill, we use it to not write slots to kv cache
# to save memory.
start_idx = max(0, query_len - sliding_window)
return start_idx
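For illustration, a minimal sketch of the sliding-window arithmetic above, with arbitrary numbers that are not taken from any test:
# Hypothetical prefill: a 4096-token prompt with a 1024-token sliding window
# only needs KV-cache slots for the last 1024 positions, so slot writing
# starts at index 3072.
query_len, sliding_window = 4096, 1024
start_idx = max(0, query_len - sliding_window)
assert start_idx == 3072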
@ -138,8 +132,6 @@ class CommonMetadataBuilder(AttentionMetadataBuilder[TAttentionMetadata]):
self.sliding_window = input_builder.sliding_window
self.block_size = input_builder.block_size
self.use_v2_block_manager = (
input_builder.scheduler_config.use_v2_block_manager)
def _add_seq_group(
self, inter_data: "ModelInputForGPUBuilder.InterDataForSeqGroup",
@ -180,9 +172,9 @@ class CommonMetadataBuilder(AttentionMetadataBuilder[TAttentionMetadata]):
# Compute slot mapping.
is_profile_run = is_block_tables_empty(block_tables)
start_idx = compute_slot_mapping_start_idx(
is_prompt, query_len, context_len, self.sliding_window,
self.use_v2_block_manager)
start_idx = compute_slot_mapping_start_idx(is_prompt, query_len,
context_len,
self.sliding_window)
compute_slot_mapping(is_profile_run, self.slot_mapping, seq_id,
seq_len, context_len, start_idx,
self.block_size, inter_data.block_tables)

vllm/commit_id.py Normal file
View File

@ -0,0 +1 @@
__commit__ = "93ec62b8556e279d2c050bdc1c3247831bd39466"

View File

@ -949,7 +949,6 @@ class SchedulerConfig:
iteration.
max_model_len: Maximum length of a sequence (including prompt
and generated text).
use_v2_block_manager: Whether to use the BlockSpaceManagerV2 or not.
num_lookahead_slots: The number of slots to allocate per sequence per
step, beyond the known token ids. This is used in speculative
decoding to store KV activations of tokens which may or may not be
@ -976,7 +975,6 @@ class SchedulerConfig:
max_num_batched_tokens: Optional[int],
max_num_seqs: int,
max_model_len: int,
use_v2_block_manager: bool = True,
num_lookahead_slots: int = 0,
delay_factor: float = 0.0,
enable_chunked_prefill: bool = False,
@ -1026,7 +1024,6 @@ class SchedulerConfig:
self.max_num_seqs = max_num_seqs
self.max_model_len = max_model_len
self.use_v2_block_manager = use_v2_block_manager
self.num_lookahead_slots = num_lookahead_slots
self.delay_factor = delay_factor
self.chunked_prefill_enabled = enable_chunked_prefill
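A minimal construction sketch consistent with the signature above and with the updated call sites in tests/core/test_scheduler.py; the values are arbitrary:
# Arbitrary values; mirrors the call sites earlier in this diff, which no
# longer pass use_v2_block_manager.
scheduler_config = SchedulerConfig(
    max_num_batched_tokens=100,
    max_num_seqs=64,
    max_model_len=16,
    delay_factor=0.5,
)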
@ -1067,18 +1064,6 @@ class SchedulerConfig:
f"({self.num_scheduler_steps}) must be greater than or "
"equal to 1.")
if (not self.use_v2_block_manager \
and not envs.VLLM_ALLOW_DEPRECATED_BLOCK_MANAGER_V1):
raise ValueError(
"The use of BlockSpaceManagerV1 is deprecated and will "
"be removed in a future release. Please switch to "
"BlockSpaceManagerV2 by setting --use-v2-block-manager to "
"True. If you wish to suppress this error temporarily, "
"you can set the environment variable "
"`VLLM_ALLOW_DEPRECATED_BLOCK_MANAGER_V1=1. If your use "
"case is not supported in BlockSpaceManagerV2, please "
"file an issue with detailed information.")
@property
def is_multi_step(self) -> bool:
return self.num_scheduler_steps > 1
@ -1137,7 +1122,6 @@ class SpeculativeConfig:
speculative_disable_mqa_scorer: Optional[bool],
speculative_max_model_len: Optional[int],
enable_chunked_prefill: bool,
use_v2_block_manager: bool,
disable_log_stats: bool,
speculative_disable_by_batch_size: Optional[int],
ngram_prompt_lookup_max: Optional[int],
@ -1178,9 +1162,6 @@ class SpeculativeConfig:
enable_chunked_prefill (bool): Whether vLLM is configured to use
chunked prefill or not. Used for raising an error since its not
yet compatible with spec decode.
use_v2_block_manager (bool): Whether vLLM is configured to use the
v2 block manager or not. Used for raising an error since the v2
block manager is required with spec decode.
speculative_disable_by_batch_size (Optional[int]): Disable
speculative decoding for new incoming requests when the number
of enqueue requests is larger than this value, if provided.
@ -1231,11 +1212,6 @@ class SpeculativeConfig:
"Speculative decoding and chunked prefill are "
f"currently mutually exclusive ({enable_chunked_prefill=}).")
if not use_v2_block_manager:
raise ValueError(
"Speculative decoding requires usage of the V2 "
"block manager. Enable it with --use-v2-block-manager.")
# TODO: The user should be able to specify revision/max model len
# for the draft model. It is not currently supported.
draft_revision = None

View File

@ -4,28 +4,6 @@ from vllm.utils import (STR_NOT_IMPL_ENC_DEC_PREFIX_CACHE,
STR_NOT_IMPL_ENC_DEC_SWA)
def _get_block_mgr_sliding_window_attr(block_mgr):
'''
BlockManagerV1 and BlockManagerV2 have slightly different
members related to sliding window attention (SWA). This
function extracts the appropriate member to use for determining
whether SWA is enabled.
Arguments:
* block_mgr: BlockManagerV1 or BlockManagerV2 instance
'''
if hasattr(block_mgr, 'block_sliding_window'):
return block_mgr.block_sliding_window
if hasattr(block_mgr, 'max_block_sliding_window'):
return block_mgr.max_block_sliding_window
raise AttributeError("Block manager instance has neither " + \
"block_sliding_window nor " + \
"max_block_sliding_window attributes.")
def check_no_caching_or_swa_for_blockmgr_encdec(
block_mgr, seq_group: SequenceGroup) -> None:
'''
@ -41,7 +19,7 @@ def check_no_caching_or_swa_for_blockmgr_encdec(
'''
if seq_group.is_encoder_decoder():
if _get_block_mgr_sliding_window_attr(block_mgr) is not None:
if block_mgr.max_block_sliding_window is not None:
raise NotImplementedError(STR_NOT_IMPL_ENC_DEC_SWA)
if block_mgr.enable_caching:

View File

@ -17,7 +17,7 @@ SeqId = int
EncoderSeqId = str
class BlockSpaceManagerV2(BlockSpaceManager):
class SelfAttnBlockSpaceManager(BlockSpaceManager):
"""BlockSpaceManager which manages the allocation of KV cache.
It owns responsibility for allocation, swapping, allocating memory for

View File

@ -1,743 +0,0 @@
"""A block manager that manages token blocks."""
import math
from abc import ABC, abstractmethod
from itertools import count, takewhile
from os.path import commonprefix
from typing import Dict, List, Optional
from typing import Sequence as GenericSequence
from typing import Set, Tuple
from vllm.block import BlockTable, PhysicalTokenBlock
from vllm.core.block.common import CacheMetricData
from vllm.core.block.utils import check_no_caching_or_swa_for_blockmgr_encdec
from vllm.core.evictor_v1 import EvictionPolicy, Evictor, make_evictor
from vllm.core.interfaces import AllocStatus, BlockSpaceManager
from vllm.logger import init_logger
from vllm.sequence import Sequence, SequenceGroup, SequenceStatus
from vllm.utils import Device
logger = init_logger(__name__)
class BlockAllocatorBase(ABC):
"""Manages free physical token blocks for a device.
The allocator maintains a list of free blocks and allocates a block when
requested. When a block is freed, its reference count is decremented. If
the reference count becomes zero, the block is added back to the free list.
"""
@abstractmethod
def __init__(self,
device: Device,
block_size: int,
num_blocks: int,
eviction_policy: EvictionPolicy = EvictionPolicy.LRU):
pass
@abstractmethod
def allocate(self,
block_hash: Optional[int] = None,
num_hashed_tokens: int = 0) -> PhysicalTokenBlock:
pass
@abstractmethod
def free(self, block: PhysicalTokenBlock) -> None:
pass
@abstractmethod
def get_num_free_blocks(self) -> int:
pass
@abstractmethod
def get_num_total_blocks(self) -> int:
pass
@abstractmethod
def contains_block(self, block_hash: int) -> bool:
pass
@abstractmethod
def update_hash(self, block_hash: int, block: PhysicalTokenBlock):
pass
@abstractmethod
def get_prefix_cache_hit_rate(self) -> float:
"""Prefix cache hit rate. -1 means not supported or disabled."""
pass
class CachedBlockAllocator(BlockAllocatorBase):
"""Manages free physical token blocks for a device.
The allocator maintains a list of free blocks and allocates a block when
requested. When a block is freed, its reference count is decremented. If
the reference count becomes zero, the block is added back to the free list.
"""
def __init__(self,
device: Device,
block_size: int,
num_blocks: int,
eviction_policy: EvictionPolicy = EvictionPolicy.LRU) -> None:
self.device = device
self.block_size = block_size
self.num_blocks = num_blocks
self.current_num_blocks = 0
self.cached_blocks: Dict[int, PhysicalTokenBlock] = {}
self.evictor: Evictor = make_evictor(eviction_policy)
self.default_hash_ctr = count()
self.cache_metric_data = CacheMetricData()
def allocate_block(self, block_hash: int,
num_hashed_tokens: int) -> PhysicalTokenBlock:
if self.current_num_blocks == self.num_blocks:
block = self.evictor.evict()
block.block_hash = block_hash
block.num_hashed_tokens = num_hashed_tokens
return block
block = PhysicalTokenBlock(device=self.device,
block_number=self.current_num_blocks,
block_size=self.block_size,
block_hash=block_hash,
num_hashed_tokens=num_hashed_tokens)
self.current_num_blocks += 1
return block
def allocate(self,
block_hash: Optional[int] = None,
num_hashed_tokens: int = 0) -> PhysicalTokenBlock:
if block_hash is None:
block_hash = next(self.default_hash_ctr)
if block_hash in self.evictor:
assert block_hash not in self.cached_blocks
block = self.evictor.remove(block_hash)
assert block.ref_count == 0
self.cached_blocks[block_hash] = block
if block_hash in self.cached_blocks:
self.cache_metric_data.query(hit=True)
else:
self.cache_metric_data.query(hit=False)
self.cached_blocks[block_hash] = self.allocate_block(
block_hash, num_hashed_tokens)
block = self.cached_blocks[block_hash]
assert block.block_hash == block_hash
block.ref_count += 1
return block
def free(self, block: PhysicalTokenBlock) -> None:
if block.ref_count == 0:
raise ValueError(f"Double free! {block} is already freed.")
block.ref_count -= 1
if block.ref_count == 0:
assert block.block_hash not in self.evictor
self.evictor.add(block)
# Remove the block from the cached_blocks
del self.cached_blocks[block.block_hash]
def get_num_free_blocks(self) -> int:
return (self.num_blocks - self.current_num_blocks +
self.evictor.num_blocks)
def get_num_total_blocks(self) -> int:
return self.num_blocks
def contains_block(self, block_hash: int) -> bool:
return block_hash in self.cached_blocks or block_hash in self.evictor
def update_hash(self, block_hash: int, block: PhysicalTokenBlock):
# Update the hash of block and the cached_blocks dictionary.
assert not self.contains_block(block_hash)
old_hash = block.block_hash
block.block_hash = block_hash
del self.cached_blocks[old_hash]
self.cached_blocks[block_hash] = block
def get_prefix_cache_hit_rate(self) -> float:
return self.cache_metric_data.get_hit_rate()
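A small arithmetic sketch of the free-block accounting in get_num_free_blocks above, with hypothetical counts; blocks parked in the evictor are freed but still reusable, so they count as free:
# Hypothetical counts: 16 blocks total, 10 ever handed out, 3 of those
# currently in the evictor (ref_count == 0, reusable).
num_blocks, current_num_blocks, evictor_blocks = 16, 10, 3
assert num_blocks - current_num_blocks + evictor_blocks == 9  # free blocks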
class UncachedBlockAllocator(BlockAllocatorBase):
"""Manages free physical token blocks for a device.
The allocator maintains a list of free blocks and allocates a block when
requested. When a block is freed, its reference count is decremented. If
the reference count becomes zero, the block is added back to the free list.
"""
def __init__(
self,
device: Device,
block_size: int,
num_blocks: int,
) -> None:
self.device = device
self.block_size = block_size
self.num_blocks = num_blocks
# Initialize the free blocks.
self.free_blocks: List[PhysicalTokenBlock] = []
for i in range(num_blocks):
block = PhysicalTokenBlock(device=device,
block_number=i,
block_size=block_size,
block_hash=-1,
num_hashed_tokens=0)
self.free_blocks.append(block)
def allocate(self,
block_hash: Optional[int] = None,
num_hashed_tokens: int = 0) -> PhysicalTokenBlock:
if not self.free_blocks:
raise ValueError("Out of memory! No free blocks are available.")
block = self.free_blocks.pop()
block.ref_count = 1
return block
def free(self, block: PhysicalTokenBlock) -> None:
if block.ref_count == 0:
raise ValueError(f"Double free! {block} is already freed.")
block.ref_count -= 1
if block.ref_count == 0:
self.free_blocks.append(block)
def get_num_free_blocks(self) -> int:
return len(self.free_blocks)
def get_num_total_blocks(self) -> int:
return self.num_blocks
def contains_block(self, block_hash: int) -> bool:
raise NotImplementedError(
"Invalid codepath for uncached block allocator.")
def update_hash(self, block_hash: int, block: PhysicalTokenBlock):
raise NotImplementedError(
"Invalid codepath for uncached block allocator.")
def get_prefix_cache_hit_rate(self) -> float:
return -1
class BlockSpaceManagerV1(BlockSpaceManager):
"""Manages the mapping between logical and physical token blocks."""
def __init__(
self,
block_size: int,
num_gpu_blocks: int,
num_cpu_blocks: int,
watermark: float = 0.01,
sliding_window: Optional[int] = None,
enable_caching: bool = False,
) -> None:
self.block_size = block_size
self.num_total_gpu_blocks = num_gpu_blocks
self.num_total_cpu_blocks = num_cpu_blocks
if enable_caching and sliding_window is not None:
raise NotImplementedError(
"Sliding window is not allowed with prefix caching enabled!")
self.block_sliding_window = None
if sliding_window is not None:
# Round up to nearest block size to regularize sliding window
# allocation sizes.
self.block_sliding_window = math.ceil(sliding_window / block_size)
self.watermark = watermark
assert watermark >= 0.0
self.enable_caching = enable_caching
self.watermark_blocks = int(watermark * num_gpu_blocks)
if self.enable_caching:
logger.info("Automatic prefix caching is enabled.")
self.gpu_allocator: BlockAllocatorBase = CachedBlockAllocator(
Device.GPU, block_size, num_gpu_blocks)
self.cpu_allocator: BlockAllocatorBase = CachedBlockAllocator(
Device.CPU, block_size, num_cpu_blocks)
else:
self.gpu_allocator = UncachedBlockAllocator(
Device.GPU, block_size, num_gpu_blocks)
self.cpu_allocator = UncachedBlockAllocator(
Device.CPU, block_size, num_cpu_blocks)
# Mapping: seq_id -> BlockTable.
self.block_tables: Dict[int, BlockTable] = {}
# Mapping: req_id -> BlockTable
# Note that each SequenceGroup has a unique
# request ID
self.cross_block_tables: Dict[str, BlockTable] = {}
def _get_seq_num_required_blocks(self, seq: Optional[Sequence]) -> int:
return 0 if seq is None else seq.n_blocks
def can_allocate(self,
seq_group: SequenceGroup,
num_lookahead_slots: int = 0) -> AllocStatus:
# FIXME(woosuk): Here we assume that all sequences in the group share
# the same prompt. This may not be true for preempted sequences.
assert (num_lookahead_slots == 0
), "lookahead allocation not supported in BlockSpaceManagerV1"
check_no_caching_or_swa_for_blockmgr_encdec(self, seq_group)
self_num_required_blocks = self._get_seq_num_required_blocks(
seq_group.get_seqs(status=SequenceStatus.WAITING)[0])
cross_num_required_blocks = self._get_seq_num_required_blocks(
seq_group.get_encoder_seq())
num_required_blocks = self_num_required_blocks + \
cross_num_required_blocks
if self.block_sliding_window is not None:
num_required_blocks = min(num_required_blocks,
self.block_sliding_window)
num_free_gpu_blocks = self.gpu_allocator.get_num_free_blocks()
# Use watermark to avoid frequent cache eviction.
if (self.num_total_gpu_blocks - num_required_blocks <
self.watermark_blocks):
return AllocStatus.NEVER
if num_free_gpu_blocks - num_required_blocks >= self.watermark_blocks:
return AllocStatus.OK
else:
return AllocStatus.LATER
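To make the watermark check above concrete, a toy calculation with made-up numbers:
# Made-up numbers: 1000 GPU blocks and the default watermark of 0.01.
watermark_blocks = int(0.01 * 1000)   # == 10
assert 1000 - 995 < watermark_blocks  # 995 blocks required -> AllocStatus.NEVER
assert 620 - 600 >= watermark_blocks  # 620 free, 600 required -> AllocStatus.OK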
def _allocate_sequence(self, \
seq: Optional[Sequence], \
ref_count: int, \
is_encoder_decoder: bool = True) -> BlockTable:
# Allocate new physical token blocks that will store the prompt tokens.
num_prompt_blocks = self._get_seq_num_required_blocks(seq)
block_table: BlockTable = BlockTable()
assert seq is not None
for logical_idx in range(num_prompt_blocks):
if (self.block_sliding_window is not None
and logical_idx >= self.block_sliding_window):
block = block_table[logical_idx % self.block_sliding_window]
# Set the reference counts of the token blocks.
block.ref_count = ref_count
elif not is_encoder_decoder and self.enable_caching:
block = self.gpu_allocator.allocate(
seq.hash_of_block(logical_idx),
seq.num_hashed_tokens_of_block(logical_idx))
else:
block = self.gpu_allocator.allocate()
# Set the reference counts of the token blocks.
block.ref_count = ref_count
block_table.append(block)
return block_table
def allocate(self, seq_group: SequenceGroup) -> None:
is_encoder_decoder = seq_group.is_encoder_decoder()
check_no_caching_or_swa_for_blockmgr_encdec(self, seq_group)
# Allocate decoder sequences
#
# NOTE: Here we assume that all sequences in the group have the same
# decoder prompt.
wait_seqs = seq_group.get_seqs(status=SequenceStatus.WAITING)
seq = wait_seqs[0]
block_table: BlockTable = \
self._allocate_sequence(seq,
seq_group.num_seqs(),
is_encoder_decoder)
# Assign the self-attention block tables for each sequence.
if len(wait_seqs) == 1:
self.block_tables[seq.seq_id] = block_table
else:
for seq in wait_seqs:
self.block_tables[seq.seq_id] = block_table.copy()
# Allocate encoder sequence
if is_encoder_decoder:
# A SequenceGroup has only a single encoder sequence (at most),
# thus allocate with a ref count of 1
block_table = self._allocate_sequence(seq_group.get_encoder_seq(),
1, is_encoder_decoder)
# Assign the cross-attention block table for the SequenceGroup.
self.cross_block_tables[seq_group.request_id] = block_table
def can_append_slots(self,
seq_group: SequenceGroup,
num_lookahead_slots: int = 0) -> bool:
assert (num_lookahead_slots == 0
), "lookahead allocation not supported in BlockSpaceManagerV1"
# Simple heuristic: If there is at least one free block
# for each sequence, we can append.
num_free_gpu_blocks = self.gpu_allocator.get_num_free_blocks()
num_seqs = seq_group.num_seqs(status=SequenceStatus.RUNNING)
return num_seqs <= num_free_gpu_blocks
def _promote_last_block(
self,
seq: Sequence,
last_block: PhysicalTokenBlock,
) -> PhysicalTokenBlock:
assert self.enable_caching
# Compute a new hash for the block so that it can be shared by other
# Sequences
new_hash = seq.hash_of_block(seq.n_blocks - 1)
# if new_hash is already in the cached table, then free last_block
# and return the cached version
if self.gpu_allocator.contains_block(new_hash):
self.gpu_allocator.free(last_block)
return self.gpu_allocator.allocate(new_hash)
else:
self.gpu_allocator.update_hash(new_hash, last_block)
return last_block
def _is_last_block_full(
self,
seq: Sequence,
) -> bool:
token_ids_len = seq.data.get_len()
return token_ids_len > 0 and token_ids_len % seq.block_size == 0
def _maybe_promote_last_block(
self,
seq: Sequence,
last_block: PhysicalTokenBlock,
) -> PhysicalTokenBlock:
if self._is_last_block_full(seq):
return self._promote_last_block(seq, last_block)
else:
return last_block
def _allocate_last_physical_block(
self,
seq: Sequence,
) -> PhysicalTokenBlock:
# Called before a new block is appended.
# This is in charge of allocating a new physical block (to be appended).
# The block hash is None if the last block is not full; otherwise it is
# set to the content hash.
if not self.enable_caching:
return self.gpu_allocator.allocate()
block_hash: Optional[int] = None
n_blocks = seq.n_blocks
if (self._is_last_block_full(seq)):
block_hash = seq.hash_of_block(n_blocks - 1)
num_hashed_tokens = seq.num_hashed_tokens_of_block(n_blocks - 1)
# num_hashed_tokens is used to compute future hashes
# (e.g. in the hashing function, it is used to ask the sequence for
# prefix tokens)
new_block = self.gpu_allocator.allocate(block_hash, num_hashed_tokens)
# If the block_hash is None, then the block is not full.
# If the block is not full, then we expect it to have a refcount of 1.
if block_hash is None:
assert new_block.ref_count == 1
return new_block
def append_slots(
self,
seq: Sequence,
num_lookahead_slots: int = 0,
) -> List[Tuple[int, int]]:
"""Allocate a physical slot for a new token."""
n_blocks = seq.n_blocks
block_table = self.block_tables[seq.seq_id]
# If we need to allocate a new physical block
if len(block_table) < n_blocks:
# Currently this code only supports adding one physical block
assert len(block_table) == n_blocks - 1
if (self.block_sliding_window
and len(block_table) >= self.block_sliding_window):
# reuse a block
block_table.append(block_table[len(block_table) %
self.block_sliding_window])
else:
# The sequence has a new logical block.
# Allocate a new physical block.
new_block = self._allocate_last_physical_block(seq)
block_table.append(new_block)
return []
# We want to append the token to the last physical block.
last_block = block_table[-1]
assert last_block.device == Device.GPU
if last_block.ref_count == 1:
# Not shared with other sequences. Appendable.
if self.enable_caching:
# If the last block is now complete, we may reuse an old block
# to save memory.
maybe_new_block = self._maybe_promote_last_block(
seq, last_block)
block_table[-1] = maybe_new_block
return []
else:
# The last block is shared with other sequences.
# Copy on Write: Allocate a new block and copy the tokens.
new_block = self._allocate_last_physical_block(seq)
block_table[-1] = new_block
self.gpu_allocator.free(last_block)
return [(last_block.block_number, new_block.block_number)]
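The (src, dst) pair returned here is what surfaces as blocks_to_copy in the scheduler output, as asserted by test_schedule_decode_blocks_to_copy_update earlier in this diff; a toy sketch of the copy-on-write decision with made-up block numbers:
# A last block shared by two sequences (ref_count == 2) cannot be appended
# to in place, so a fresh block is allocated and the caller is told to copy
# block 2 into block 3.
last_block_ref_count, src_block, dst_block = 2, 2, 3
blocks_to_copy = [(src_block, dst_block)] if last_block_ref_count > 1 else []
assert blocks_to_copy == [(2, 3)]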
def fork(self, parent_seq: Sequence, child_seq: Sequence) -> None:
# NOTE: fork does not allocate a new physical block.
# Thus, it is always safe from OOM.
if parent_seq.seq_id not in self.block_tables:
# Parent sequence has either been freed or never existed.
return
src_block_table = self.block_tables[parent_seq.seq_id]
self.block_tables[child_seq.seq_id] = src_block_table.copy()
# When using a sliding window, blocks will be eventually reused.
# In this case the block tables will contain repeated blocks.
# When forking, we must make sure that each block's `ref_count`
# is only incremented by one, so we deduplicate them by wrapping
# them in a set.
for block in set(src_block_table):
block.ref_count += 1
def _get_physical_blocks(
self, seq_group: SequenceGroup) -> List[PhysicalTokenBlock]:
# NOTE: Here, we assume that the physical blocks are only shared by
# the sequences in the same group.
request_id = seq_group.request_id
blocks: Set[PhysicalTokenBlock] = set()
for seq in seq_group.get_seqs():
if seq.is_finished():
continue
blocks.update(self.block_tables[seq.seq_id])
# Cross-attention blocks
if seq_group.is_encoder_decoder():
blocks.update(self.cross_block_tables[request_id])
return list(blocks)
def can_swap_in(self,
seq_group: SequenceGroup,
num_lookahead_slots: int = 0) -> AllocStatus:
assert (num_lookahead_slots == 0
), "BlockSpaceManagerV1 does not support lookahead allocation"
blocks = self._get_physical_blocks(seq_group)
num_swapped_seqs = seq_group.num_seqs(status=SequenceStatus.SWAPPED)
if seq_group.is_encoder_decoder():
num_swapped_seqs += 1
num_free_blocks = self.gpu_allocator.get_num_free_blocks()
# NOTE: Conservatively, we assume that every sequence will allocate
# at least one free block right after the swap-in.
# NOTE: This should match the logic in can_append_slot().
num_required_blocks = len(blocks) + num_swapped_seqs
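# e.g. 10 physical blocks used by the group plus 2 swapped sequences
# means 12 GPU blocks must be available (with headroom above the
# watermark) before swap-in can proceed.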
if self.gpu_allocator.get_num_total_blocks() < num_required_blocks:
return AllocStatus.NEVER
elif num_free_blocks - num_required_blocks >= self.watermark_blocks:
return AllocStatus.OK
else:
return AllocStatus.LATER
def _swap_block_table(
self, block_table: BlockTable, src_allocator: BlockAllocatorBase,
dest_allocator: BlockAllocatorBase,
mapping: Dict[PhysicalTokenBlock,
PhysicalTokenBlock]) -> BlockTable:
new_block_table: BlockTable = BlockTable()
for from_block in block_table:
if from_block in mapping:
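# This block was already remapped for another sequence in the group;
# share the destination block and bump its refcount.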
to_block = mapping[from_block]
to_block.ref_count += 1
else:
to_block = dest_allocator.allocate(
from_block.block_hash, from_block.num_hashed_tokens)
mapping[from_block] = to_block
new_block_table.append(to_block)
# Free the source block now that it has been swapped to the destination.
src_allocator.free(from_block)
return new_block_table
def swap_in(self, seq_group: SequenceGroup) -> List[Tuple[int, int]]:
request_id = seq_group.request_id
# CPU block -> GPU block.
# A dict gives O(1) lookups for `if cpu_block in mapping`.
mapping: Dict[PhysicalTokenBlock, PhysicalTokenBlock] = {}
for seq in seq_group.get_seqs(status=SequenceStatus.SWAPPED):
self.block_tables[seq.seq_id] = \
self._swap_block_table(self.block_tables[seq.seq_id],
self.cpu_allocator, self.gpu_allocator,
mapping)
if seq_group.is_encoder_decoder():
self.cross_block_tables[request_id] = \
self._swap_block_table(self.cross_block_tables[request_id],
self.cpu_allocator,
self.gpu_allocator,
mapping)
return [(cpu_block.block_number, gpu_block.block_number)
for cpu_block, gpu_block in mapping.items()]
def can_swap_out(self, seq_group: SequenceGroup) -> bool:
blocks = self._get_physical_blocks(seq_group)
return len(blocks) <= self.cpu_allocator.get_num_free_blocks()
def swap_out(self, seq_group: SequenceGroup) -> List[Tuple[int, int]]:
request_id = seq_group.request_id
# GPU block -> CPU block.
# A dict gives O(1) lookups for `if gpu_block in mapping`.
mapping: Dict[PhysicalTokenBlock, PhysicalTokenBlock] = {}
for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
self.block_tables[seq.seq_id] = \
self._swap_block_table(self.block_tables[seq.seq_id],
self.gpu_allocator, self.cpu_allocator,
mapping)
if seq_group.is_encoder_decoder():
self.cross_block_tables[request_id] = \
self._swap_block_table(self.cross_block_tables[request_id],
self.gpu_allocator,
self.cpu_allocator,
mapping)
return [(cpu_block.block_number, gpu_block.block_number)
for cpu_block, gpu_block in mapping.items()]
def _free_block_table(self, block_table: BlockTable) -> None:
# When using a sliding window, each seq will only use up
# to `self.block_sliding_window` blocks. When freeing
# the block table, we must make sure not to free a block
# more than once. If no sliding window is used, there is no
# block reuse in the block table, so we must free all blocks.
blocks_to_free = (block_table[-self.block_sliding_window:]
if self.block_sliding_window is not None else
block_table)
for block in set(blocks_to_free):
if block.device == Device.GPU:
self.gpu_allocator.free(block)
else:
self.cpu_allocator.free(block)
def free(self, seq: Sequence) -> None:
if seq.seq_id not in self.block_tables:
# Already freed or hasn't been scheduled yet.
return
block_table = self.block_tables[seq.seq_id]
self._free_block_table(block_table)
del self.block_tables[seq.seq_id]
def free_cross(self, seq_group: SequenceGroup) -> None:
if seq_group.request_id not in self.cross_block_tables:
# Already freed or hasn't been scheduled yet.
return
block_table = self.cross_block_tables[seq_group.request_id]
self._free_block_table(block_table)
del self.cross_block_tables[seq_group.request_id]
def reset(self) -> None:
# Free decoder block tables
for block_table in self.block_tables.values():
self._free_block_table(block_table)
self.block_tables.clear()
# Free cross-attention block tables
for block_table in self.cross_block_tables.values():
self._free_block_table(block_table)
self.cross_block_tables.clear()
def get_block_table(self, seq: Sequence) -> List[int]:
return self.block_tables[seq.seq_id].ids()
def get_cross_block_table(self, seq_group: SequenceGroup) -> List[int]:
block_table = self.cross_block_tables[seq_group.request_id]
return [block.block_number for block in block_table]
def get_num_free_gpu_blocks(self) -> int:
return self.gpu_allocator.get_num_free_blocks()
def get_num_free_cpu_blocks(self) -> int:
return self.cpu_allocator.get_num_free_blocks()
def access_all_blocks_in_seq(
self,
seq: Sequence,
access_time: float,
) -> None:
if self.enable_caching:
# Update the last accessed time of all the blocks accessed
# in this step.
block_table = self.block_tables[seq.seq_id]
for block in block_table:
block.last_accessed = access_time
def compute_full_blocks_in_seq(self, seq: Sequence, token_chunk_size: int):
if seq.seq_id not in self.block_tables:
return
# When chunked prefill is enabled, the computed full blocks
# should be calculated based on the number of computed tokens.
max_computed_tokens = (seq.data.get_num_computed_tokens() +
token_chunk_size)
computed_full_blocks = max_computed_tokens // self.block_size
block_table = self.block_tables[seq.seq_id]
if computed_full_blocks == 0:
return
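# Walk backwards and stop at the first block already marked computed;
# everything before it was marked in an earlier step.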
for i in reversed(range(computed_full_blocks)):
if block_table[i].computed:
break
block_table[i].computed = True
def get_all_computed_blocks(self, seq: Sequence) -> List[int]:
if seq.seq_id not in self.block_tables:
return []
block_table = self.block_tables[seq.seq_id]
# NOTE: We exclude the last block to avoid the case where the
# entire prompt is cached. This would cause erroneous behavior
# in the model runner.
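# takewhile keeps only the leading run of computed blocks, so a
# contiguous prefix of block numbers is returned.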
return [
b.block_number
for b in takewhile(lambda b: b.computed, block_table[:-1])
]
def get_common_computed_block_ids(
self, seqs: List[Sequence]) -> GenericSequence[int]:
"""Return the block ids that are common for a given sequence group.
Used in prefill (can skip prefill of some blocks).
"""
# Can return non-empty result only with prefix caching enabled.
if not self.enable_caching:
return []
ids_list = [self.get_all_computed_blocks(seq) for seq in seqs]
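# commonprefix returns the longest prefix shared by all the
# per-sequence computed-block id lists.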
return commonprefix([ids for ids in ids_list if ids != []])
def mark_blocks_as_computed(self, seq_group: SequenceGroup,
token_chunk_size: int):
if self.enable_caching:
for seq in seq_group.get_seqs():
self.compute_full_blocks_in_seq(seq, token_chunk_size)
def get_prefix_cache_hit_rate(self, device: Device) -> float:
if device == Device.GPU:
return self.gpu_allocator.get_prefix_cache_hit_rate()
if device == Device.CPU:
return self.cpu_allocator.get_prefix_cache_hit_rate()
raise ValueError(f"Invalid device: {device}")

View File

@ -28,13 +28,9 @@ class BlockSpaceManager(ABC):
def get_block_space_manager_class(version: str):
version = version.lower()
if version == "v1":
from vllm.core.block_manager_v1 import BlockSpaceManagerV1
return BlockSpaceManagerV1
if version == "v2":
from vllm.core.block_manager_v2 import BlockSpaceManagerV2
return BlockSpaceManagerV2
if version == "selfattn":
from vllm.core.block_manager import SelfAttnBlockSpaceManager
return SelfAttnBlockSpaceManager
if version == "placeholder":
from vllm.core.placeholder_block_space_manager import (

View File

@ -312,9 +312,7 @@ class Scheduler:
# LoRAs. This should be improved in the future.
self.lora_config = lora_config
version = "v1"
if self.scheduler_config.use_v2_block_manager:
version = "v2"
version = "selfattn"
if (self.scheduler_config.embedding_mode
or self.cache_config.is_attention_free):
version = "placeholder"

View File

@ -373,12 +373,13 @@ class EngineArgs:
action='store_true',
help='Disables sliding window, '
'capping to sliding window size')
parser.add_argument(
'--use-v2-block-manager',
default=EngineArgs.use_v2_block_manager,
action='store_true',
help='Use BlockSpaceManagerV2. By default this is set to True. '
'Set to False to use BlockSpaceManagerV1')
parser.add_argument('--use-v2-block-manager',
action='store_true',
help='[DEPRECATED] block manager v1 has been '
'removed and SelfAttnBlockSpaceManager (i.e. '
'block manager v2) is now the default. '
'Setting this flag to True or False'
' has no effect on vLLM behavior.')
parser.add_argument(
'--num-lookahead-slots',
type=int,
@ -969,12 +970,6 @@ class EngineArgs:
"in low performance due to small KV cache space. Consider "
"setting --max-model-len to a smaller value.", max_model_len)
if self.num_scheduler_steps > 1 and not self.use_v2_block_manager:
self.use_v2_block_manager = True
logger.warning(
"Enabled BlockSpaceManagerV2 because it is "
"required for multi-step (--num-scheduler-steps > 1)")
speculative_config = SpeculativeConfig.maybe_create_spec_config(
target_model_config=model_config,
target_parallel_config=parallel_config,
@ -990,7 +985,6 @@ class EngineArgs:
speculative_disable_by_batch_size,
speculative_max_model_len=self.speculative_max_model_len,
enable_chunked_prefill=self.enable_chunked_prefill,
use_v2_block_manager=self.use_v2_block_manager,
disable_log_stats=self.disable_log_stats,
ngram_prompt_lookup_max=self.ngram_prompt_lookup_max,
ngram_prompt_lookup_min=self.ngram_prompt_lookup_min,
@ -1021,11 +1015,20 @@ class EngineArgs:
if speculative_config is None \
else speculative_config.num_lookahead_slots
if not self.use_v2_block_manager:
logger.warning(
"[DEPRECATED] Block manager v1 has been removed, "
"and setting --use-v2-block-manager to True or False has "
"no effect on vLLM behavior. Please remove "
"--use-v2-block-manager in your engine argument. "
"If your use case is not supported by "
"SelfAttnBlockSpaceManager (i.e. block manager v2),"
" please file an issue with detailed information.")
scheduler_config = SchedulerConfig(
max_num_batched_tokens=self.max_num_batched_tokens,
max_num_seqs=self.max_num_seqs,
max_model_len=model_config.max_model_len,
use_v2_block_manager=self.use_v2_block_manager,
num_lookahead_slots=num_lookahead_slots,
delay_factor=self.scheduler_delay_factor,
enable_chunked_prefill=self.enable_chunked_prefill,
@ -1081,13 +1084,6 @@ class EngineArgs:
or "all" in detailed_trace_modules,
)
if (model_config.get_sliding_window() is not None
and scheduler_config.chunked_prefill_enabled
and not scheduler_config.use_v2_block_manager):
raise ValueError(
"Chunked prefill is not supported with sliding window. "
"Set --disable-sliding-window to disable sliding window.")
return EngineConfig(
model_config=model_config,
cache_config=cache_config,

View File

@ -247,7 +247,7 @@ class LLMEngine:
"enforce_eager=%s, kv_cache_dtype=%s, "
"quantization_param_path=%s, device_config=%s, "
"decoding_config=%r, observability_config=%r, "
"seed=%d, served_model_name=%s, use_v2_block_manager=%s, "
"seed=%d, served_model_name=%s, "
"num_scheduler_steps=%d, chunked_prefill_enabled=%s "
"multi_step_stream_outputs=%s, enable_prefix_caching=%s, "
"use_async_output_proc=%s, use_cached_outputs=%s, "
@ -280,7 +280,6 @@ class LLMEngine:
observability_config,
model_config.seed,
model_config.served_model_name,
scheduler_config.use_v2_block_manager,
scheduler_config.num_scheduler_steps,
scheduler_config.chunked_prefill_enabled,
scheduler_config.multi_step_stream_outputs,

View File

@ -64,7 +64,6 @@ if TYPE_CHECKING:
VLLM_USE_TRITON_AWQ: bool = False
VLLM_ALLOW_RUNTIME_LORA_UPDATING: bool = False
VLLM_SKIP_P2P_CHECK: bool = False
VLLM_ALLOW_DEPRECATED_BLOCK_MANAGER_V1: bool = False
VLLM_TORCH_COMPILE_LEVEL: int = 0
VLLM_DISABLED_KERNELS: List[str] = []
@ -427,11 +426,6 @@ environment_variables: Dict[str, Callable[[], Any]] = {
"VLLM_SKIP_P2P_CHECK":
lambda: os.getenv("VLLM_SKIP_P2P_CHECK", "0") == "1",
# If set, allow the use of the deprecated block manager V1
"VLLM_ALLOW_DEPRECATED_BLOCK_MANAGER_V1":
lambda: os.environ.get("VLLM_ALLOW_DEPRECATED_BLOCK_MANAGER_V1", "0"
) == "1",
# List of quantization kernels that should be disabled, used for testing
# and performance comparisons. Currently only affects MPLinearKernel
# selection

View File

@ -574,17 +574,12 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
# paged attn. We can remove it if we make the paged attn kernel
# properly handle sliding window attn.
curr_sliding_window_block = self.sliding_window_blocks
if self.scheduler_config.use_v2_block_manager:
# number of elements in last block
suff_len = inter_data.seq_lens[seq_idx] % self.block_size
sliding_seq_len = min(
inter_data.seq_lens[seq_idx],
self.block_aligned_sliding_window + suff_len)
if suff_len > 0:
curr_sliding_window_block += 1
else:
sliding_seq_len = min(inter_data.seq_lens[seq_idx],
self.sliding_window)
# number of elements in last block
suff_len = inter_data.seq_lens[seq_idx] % self.block_size
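# e.g. block_size=16 and seq_len=35 give suff_len=3, so the window
# spans one extra, partially-filled block.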
sliding_seq_len = min(inter_data.seq_lens[seq_idx],
self.block_aligned_sliding_window + suff_len)
if suff_len > 0:
curr_sliding_window_block += 1
inter_data.curr_sliding_window_blocks[
seq_idx] = curr_sliding_window_block