[Core] Deprecating block manager v1 and make block manager v2 default (#8704)
This removes block manager v1 and makes block manager v2 the default. It is the first piece of the prefix-caching-centric design: to get there, we need to simplify the code path so that only the v2 block manager (which has much higher prefix-caching performance) is used.
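For downstream users, the practical effect (as the example, benchmark, and doc changes in this diff show) is that the `use_v2_block_manager` flag can simply be dropped. A minimal sketch against the vLLM Python API, with the model name chosen only for illustration:

```python
from vllm import LLM, SamplingParams

# Before this change, the v2 block manager had to be requested explicitly:
#   llm = LLM(model="facebook/opt-125m", use_v2_block_manager=True)
# Now it is the default, so the flag is no longer needed:
llm = LLM(model="facebook/opt-125m", enable_prefix_caching=True)

sampling_params = SamplingParams(temperature=0.0, max_tokens=16)
outputs = llm.generate(["Hello, my name is"], sampling_params)
print(outputs[0].outputs[0].text)

# CI temporarily keeps the deprecated v1 code paths runnable by setting
# VLLM_ALLOW_DEPRECATED_BLOCK_MANAGER_V1=1, as the pipeline changes below show.
```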
parent: 5eda21e773
commit: 81ede99ca4
@@ -77,8 +77,8 @@ steps:
- vllm/
- tests/basic_correctness/test_chunked_prefill
commands:
- VLLM_ATTENTION_BACKEND=XFORMERS VLLM_ALLOW_DEPRECATED_BLOCK_MANAGER_V1=1 pytest -v -s basic_correctness/test_chunked_prefill.py
- VLLM_ATTENTION_BACKEND=FLASH_ATTN VLLM_ALLOW_DEPRECATED_BLOCK_MANAGER_V1=1 pytest -v -s basic_correctness/test_chunked_prefill.py
- VLLM_ATTENTION_BACKEND=XFORMERS pytest -v -s basic_correctness/test_chunked_prefill.py
- VLLM_ATTENTION_BACKEND=FLASH_ATTN pytest -v -s basic_correctness/test_chunked_prefill.py

- label: Core Test # 10min
mirror_hardwares: [amd]
@@ -88,11 +88,7 @@ steps:
- vllm/distributed
- tests/core
commands:
- VLLM_ALLOW_DEPRECATED_BLOCK_MANAGER_V1=1 pytest -v -s core/test_scheduler.py
- VLLM_ALLOW_DEPRECATED_BLOCK_MANAGER_V1=1 pytest -v -s core core/test_chunked_prefill_scheduler.py
- VLLM_ALLOW_DEPRECATED_BLOCK_MANAGER_V1=1 pytest -v -s core core/block/e2e/test_correctness.py
- VLLM_ALLOW_DEPRECATED_BLOCK_MANAGER_V1=1 pytest -v -s core core/block/e2e/test_correctness_sliding_window.py
- pytest -v -s core --ignore=core/block/e2e/test_correctness.py --ignore=core/test_scheduler.py --ignore=core/test_chunked_prefill_scheduler.py --ignore=core/block/e2e/test_correctness.py --ignore=core/block/e2e/test_correctness_sliding_window.py
- pytest -v -s core

- label: Entrypoints Test # 40min
working_dir: "/vllm-workspace/tests"
@@ -192,8 +188,7 @@ steps:
- vllm/
- tests/prefix_caching
commands:
- VLLM_ALLOW_DEPRECATED_BLOCK_MANAGER_V1=1 pytest -v -s prefix_caching/test_prefix_caching.py
- pytest -v -s prefix_caching --ignore=prefix_caching/test_prefix_caching.py
- pytest -v -s prefix_caching

- label: Samplers Test # 36min
source_file_dependencies:
@@ -217,8 +212,7 @@ steps:
- tests/spec_decode
commands:
- pytest -v -s spec_decode/e2e/test_multistep_correctness.py
- VLLM_ALLOW_DEPRECATED_BLOCK_MANAGER_V1=1 pytest -v -s spec_decode/e2e/test_compatibility.py
- VLLM_ATTENTION_BACKEND=FLASH_ATTN pytest -v -s spec_decode --ignore=spec_decode/e2e/test_multistep_correctness.py --ignore=spec_decode/e2e/test_compatibility.py
- VLLM_ATTENTION_BACKEND=FLASH_ATTN pytest -v -s spec_decode --ignore=spec_decode/e2e/test_multistep_correctness.py

- label: LoRA Test %N # 15min each
mirror_hardwares: [amd]
@@ -405,7 +399,7 @@ steps:
- pytest -v -s ./compile/test_basic_correctness.py
- pytest -v -s ./compile/test_wrapper.py
- VLLM_TEST_SAME_HOST=1 torchrun --nproc-per-node=4 distributed/test_same_node.py | grep -q 'Same node test passed'
- TARGET_TEST_SUITE=L4 VLLM_ALLOW_DEPRECATED_BLOCK_MANAGER_V1=1 pytest basic_correctness/ -v -s -m distributed_2_gpus
- TARGET_TEST_SUITE=L4 pytest basic_correctness/ -v -s -m distributed_2_gpus
# Avoid importing model tests that cause CUDA reinitialization error
- pytest models/encoder_decoder/language/test_bart.py -v -s -m distributed_2_gpus
- pytest models/encoder_decoder/vision_language/test_broadcast.py -v -s -m distributed_2_gpus
@@ -38,7 +38,6 @@ def main(args: argparse.Namespace):
quantization_param_path=args.quantization_param_path,
device=args.device,
ray_workers_use_nsight=args.ray_workers_use_nsight,
use_v2_block_manager=args.use_v2_block_manager,
enable_chunked_prefill=args.enable_chunked_prefill,
download_dir=args.download_dir,
block_size=args.block_size,
@@ -221,9 +220,6 @@ if __name__ == '__main__':
parser.add_argument("--enable-prefix-caching",
action='store_true',
help="Enable automatic prefix caching")
parser.add_argument('--use-v2-block-manager',
action='store_true',
default=EngineArgs.use_v2_block_manager)
parser.add_argument(
"--ray-workers-use-nsight",
action='store_true',
@ -33,7 +33,6 @@ from typing import List, Optional, Tuple
|
||||
from transformers import PreTrainedTokenizerBase
|
||||
|
||||
from vllm import LLM, SamplingParams
|
||||
from vllm.engine.arg_utils import EngineArgs
|
||||
from vllm.utils import FlexibleArgumentParser
|
||||
|
||||
try:
|
||||
@ -134,7 +133,6 @@ def main(args):
|
||||
tokenizer_mode='auto',
|
||||
trust_remote_code=True,
|
||||
enforce_eager=True,
|
||||
use_v2_block_manager=args.use_v2_block_manager,
|
||||
tensor_parallel_size=args.tensor_parallel_size,
|
||||
enable_prefix_caching=args.enable_prefix_caching)
|
||||
|
||||
@ -176,10 +174,6 @@ if __name__ == "__main__":
|
||||
parser.add_argument('--enable-prefix-caching',
|
||||
action='store_true',
|
||||
help='enable prefix caching')
|
||||
parser.add_argument('--use-v2-block-manager',
|
||||
action='store_true',
|
||||
default=EngineArgs.use_v2_block_manager,
|
||||
help='Use BlockSpaceMangerV2')
|
||||
parser.add_argument('--num-prompts',
|
||||
type=int,
|
||||
default=1,
|
||||
|
@ -86,7 +86,6 @@ def run_vllm(
|
||||
distributed_executor_backend: Optional[str],
|
||||
gpu_memory_utilization: float = 0.9,
|
||||
num_scheduler_steps: int = 1,
|
||||
use_v2_block_manager: bool = False,
|
||||
download_dir: Optional[str] = None,
|
||||
load_format: str = EngineArgs.load_format,
|
||||
disable_async_output_proc: bool = False,
|
||||
@ -113,7 +112,6 @@ def run_vllm(
|
||||
distributed_executor_backend=distributed_executor_backend,
|
||||
load_format=load_format,
|
||||
num_scheduler_steps=num_scheduler_steps,
|
||||
use_v2_block_manager=use_v2_block_manager,
|
||||
disable_async_output_proc=disable_async_output_proc,
|
||||
)
|
||||
|
||||
@ -176,7 +174,6 @@ async def run_vllm_async(
|
||||
distributed_executor_backend: Optional[str],
|
||||
gpu_memory_utilization: float = 0.9,
|
||||
num_scheduler_steps: int = 1,
|
||||
use_v2_block_manager: bool = False,
|
||||
download_dir: Optional[str] = None,
|
||||
load_format: str = EngineArgs.load_format,
|
||||
disable_async_output_proc: bool = False,
|
||||
@ -204,7 +201,6 @@ async def run_vllm_async(
|
||||
distributed_executor_backend=distributed_executor_backend,
|
||||
load_format=load_format,
|
||||
num_scheduler_steps=num_scheduler_steps,
|
||||
use_v2_block_manager=use_v2_block_manager,
|
||||
disable_async_output_proc=disable_async_output_proc,
|
||||
worker_use_ray=False,
|
||||
disable_log_requests=True,
|
||||
@ -341,8 +337,7 @@ def main(args: argparse.Namespace):
|
||||
args.enable_prefix_caching, args.enable_chunked_prefill,
|
||||
args.max_num_batched_tokens, args.distributed_executor_backend,
|
||||
args.gpu_memory_utilization, args.num_scheduler_steps,
|
||||
args.use_v2_block_manager, args.download_dir, args.load_format,
|
||||
args.disable_async_output_proc
|
||||
args.download_dir, args.load_format, args.disable_async_output_proc
|
||||
]
|
||||
|
||||
if args.async_engine:
|
||||
@ -471,10 +466,6 @@ if __name__ == "__main__":
|
||||
type=int,
|
||||
default=1,
|
||||
help="Maximum number of forward steps per scheduler call.")
|
||||
parser.add_argument("--use-v2-block-manager",
|
||||
action='store_true',
|
||||
default=EngineArgs.use_v2_block_manager,
|
||||
help="Enable block manager v2.")
|
||||
parser.add_argument(
|
||||
"--enable-prefix-caching",
|
||||
action='store_true',
|
||||
|
@ -16,7 +16,6 @@ def main(args):
|
||||
enforce_eager=True,
|
||||
enable_prefix_caching=True,
|
||||
tensor_parallel_size=args.tensor_parallel_size,
|
||||
use_v2_block_manager=args.use_v2_block_manager,
|
||||
)
|
||||
|
||||
sampling_params = SamplingParams(temperature=0, max_tokens=args.output_len)
|
||||
@ -56,8 +55,5 @@ if __name__ == "__main__":
|
||||
parser.add_argument('--enable-prefix-caching',
|
||||
action='store_true',
|
||||
help='enable prefix caching')
|
||||
parser.add_argument('--use-v2-block-manager',
|
||||
action='store_true',
|
||||
help='Use BlockSpaceMangerV2')
|
||||
args = parser.parse_args()
|
||||
main(args)
|
||||
|
@@ -30,7 +30,6 @@ The following code configures vLLM in an offline mode to use speculative decodin
tensor_parallel_size=1,
speculative_model="facebook/opt-125m",
num_speculative_tokens=5,
use_v2_block_manager=True,
)
outputs = llm.generate(prompts, sampling_params)

@@ -104,7 +103,6 @@ matching n-grams in the prompt. For more information read `this thread. <https:/
speculative_model="[ngram]",
num_speculative_tokens=5,
ngram_prompt_lookup_max=4,
use_v2_block_manager=True,
)
outputs = llm.generate(prompts, sampling_params)

@@ -135,7 +133,6 @@ For more information see `this blog <https://pytorch.org/blog/hitchhikers-guide-
tensor_parallel_size=4,
speculative_model="ibm-fms/llama3-70b-accelerator",
speculative_draft_tensor_parallel_size=1,
use_v2_block_manager=True,
)
outputs = llm.generate(prompts, sampling_params)

@@ -50,8 +50,6 @@ if __name__ == "__main__":
llm = LLM(
model="meta-llama/Llama-2-13b-chat-hf",
speculative_model="ibm-fms/llama-13b-accelerator",
# These are currently required for MLPSpeculator decoding
use_v2_block_manager=True,
)

print("With speculation")
@ -12,7 +12,7 @@ from contextlib import nullcontext
|
||||
import pytest
|
||||
|
||||
from ..models.utils import check_logprobs_close, check_outputs_equal
|
||||
from ..utils import check_deprecated_block_manager_usage, multi_gpu_test
|
||||
from ..utils import multi_gpu_test
|
||||
|
||||
MODELS = [
|
||||
"facebook/opt-125m",
|
||||
@ -20,12 +20,6 @@ MODELS = [
|
||||
]
|
||||
|
||||
|
||||
@pytest.fixture(scope="module", autouse=True)
|
||||
def check_deprecated_block_manager():
|
||||
check_deprecated_block_manager_usage(
|
||||
'tests/basic_correctness/test_chunked_prefill.py')
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model", MODELS)
|
||||
@pytest.mark.parametrize("dtype", ["half"])
|
||||
@pytest.mark.parametrize("max_tokens", [32])
|
||||
@ -197,7 +191,6 @@ def test_models_with_fp8_kv_cache(
|
||||
@pytest.mark.parametrize("max_tokens", [16])
|
||||
@pytest.mark.parametrize("enforce_eager", [False])
|
||||
@pytest.mark.parametrize("chunk_size", [30, 32])
|
||||
@pytest.mark.parametrize("use_v2_block_manager", [False, True])
|
||||
# NOTE: Increasing this in this suite will fail CI because we currently cannot
|
||||
# reset distributed env properly. Use a value > 1 just when you test.
|
||||
@pytest.mark.parametrize("tensor_parallel_size", [1])
|
||||
@ -206,7 +199,6 @@ def test_with_prefix_caching(
|
||||
max_tokens: int,
|
||||
enforce_eager: bool,
|
||||
chunk_size: int,
|
||||
use_v2_block_manager: bool,
|
||||
tensor_parallel_size: int,
|
||||
) -> None:
|
||||
"""
|
||||
@ -234,7 +226,6 @@ def test_with_prefix_caching(
|
||||
enable_chunked_prefill=True,
|
||||
enable_prefix_caching=enable,
|
||||
tensor_parallel_size=tensor_parallel_size,
|
||||
use_v2_block_manager=use_v2_block_manager,
|
||||
enforce_eager=enforce_eager,
|
||||
max_num_seqs=max_num_seqs,
|
||||
) as vllm_model:
|
||||
|
@ -2,18 +2,11 @@ from itertools import cycle
|
||||
|
||||
import pytest
|
||||
|
||||
from tests.utils import check_deprecated_block_manager_usage
|
||||
from vllm import SamplingParams
|
||||
|
||||
from .conftest import get_token_ids_from_llm_generator
|
||||
|
||||
|
||||
@pytest.fixture(scope="module", autouse=True)
|
||||
def check_deprecated_block_manager():
|
||||
check_deprecated_block_manager_usage(
|
||||
'tests/core/block/e2e/test_correctness.py')
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"common_llm_kwargs",
|
||||
[{
|
||||
@@ -28,32 +21,32 @@ def check_deprecated_block_manager():
"num_gpu_blocks_override": 5 * (64 + 1),
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{
"use_v2_block_manager": False
}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [{
"use_v2_block_manager": True,
"preemption_mode": "swap"
}, {
"use_v2_block_manager": True,
"preemption_mode": "recompute"
}])
@pytest.mark.parametrize("batch_size", [10])
@pytest.mark.parametrize("seed", [1])
def test_v1_v2_greedy_equality_with_preemption(baseline_llm_generator,
def test_block_manager_with_preemption(baseline_llm_generator,
test_llm_generator, batch_size):
"""Verify block manager v2 produces same outputs as block manager v1, even
when there is preemption.
"""Verify block manager produces same outputs even when there is preemption.

This constructs two LLM, each with limited number of GPU blocks. The limit
is decided such that as the sequences in the batch grow, sequences must be
preempted and removed from cache.

If the output token ids are equivalent, then we have confidence that the KV
cache is not corrupted in the v2 block manager.
cache is not corrupted.

NOTE: We want a significant number of generated tokens so that any incorrect
KV mapping has time to build up error.

NOTE(Kuntai): Though we have removed block manager v1, this test is still
useful as it asserts the behavior of block manager v2 (now it is called
SelfAttnBlockSpaceManager) is the same when swapping / preemption, so we
keep this test.
"""
output_len = 1024
temperature = 0.0
@ -77,11 +70,9 @@ def test_v1_v2_greedy_equality_with_preemption(baseline_llm_generator,
|
||||
temperature=temperature,
|
||||
)
|
||||
|
||||
print('Getting token ids from block manager v1')
|
||||
baseline_token_ids = get_token_ids_from_llm_generator(
|
||||
baseline_llm_generator, prompts, sampling_params)
|
||||
|
||||
print('Getting token ids from block manager v2')
|
||||
test_token_ids = get_token_ids_from_llm_generator(test_llm_generator,
|
||||
prompts, sampling_params)
|
||||
|
||||
@ -104,9 +95,6 @@ def test_v1_v2_greedy_equality_with_preemption(baseline_llm_generator,
|
||||
|
||||
# skip cuda graph creation for fast test.
|
||||
"enforce_eager": True,
|
||||
|
||||
# Lookahead scheduling only supported in v2 block manager.
|
||||
"use_v2_block_manager": True,
|
||||
}])
|
||||
@pytest.mark.parametrize(
|
||||
"per_test_common_llm_kwargs",
|
||||
@ -218,26 +206,22 @@ def test_lookahead_greedy_equality_with_preemption(baseline_llm_generator,
|
||||
"max_num_seqs": 10,
|
||||
}])
|
||||
@pytest.mark.parametrize("baseline_llm_kwargs", [
|
||||
{
|
||||
"use_v2_block_manager": False,
|
||||
},
|
||||
{},
|
||||
])
|
||||
@pytest.mark.parametrize("test_llm_kwargs", [
|
||||
{
|
||||
"use_v2_block_manager": True,
|
||||
"num_lookahead_slots": 0,
|
||||
},
|
||||
{
|
||||
"use_v2_block_manager": True,
|
||||
"num_lookahead_slots": 5,
|
||||
},
|
||||
])
|
||||
@pytest.mark.parametrize("batch_size", [4])
|
||||
@pytest.mark.parametrize("seed", [1])
|
||||
def test_chunked_prefill_block_manager_v2(baseline_llm_generator,
|
||||
def test_chunked_prefill_block_manager(baseline_llm_generator,
|
||||
test_llm_generator, batch_size):
|
||||
"""Verify that chunked prefill works with BlockManagerV2, with and without
|
||||
lookahead scheduling.
|
||||
"""Verify that chunked prefill works with SelfAttnBlockSpaceManager,
|
||||
with and without lookahead scheduling.
|
||||
"""
|
||||
output_len = 32
|
||||
temperature = 0.0
|
||||
@ -258,11 +242,11 @@ def test_chunked_prefill_block_manager_v2(baseline_llm_generator,
|
||||
temperature=temperature,
|
||||
)
|
||||
|
||||
print('Getting token ids with BlockManagerV1')
|
||||
print('Getting token ids with BlockManager')
|
||||
baseline_token_ids = get_token_ids_from_llm_generator(
|
||||
baseline_llm_generator, prompts, sampling_params)
|
||||
|
||||
print('Getting token ids with BlockManagerV2')
|
||||
print('Getting token ids with BlockManager, with lookahead slots.')
|
||||
test_token_ids = get_token_ids_from_llm_generator(test_llm_generator,
|
||||
prompts, sampling_params)
|
||||
|
||||
@ -290,32 +274,32 @@ def test_chunked_prefill_block_manager_v2(baseline_llm_generator,
|
||||
"enable_prefix_caching": True,
|
||||
}])
|
||||
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{
|
||||
"use_v2_block_manager": False
|
||||
}])
|
||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("test_llm_kwargs", [{
|
||||
"use_v2_block_manager": True,
|
||||
"preemption_mode": "swap"
|
||||
}, {
|
||||
"use_v2_block_manager": True,
|
||||
"preemption_mode": "recompute"
|
||||
}])
|
||||
@pytest.mark.parametrize("batch_size", [10])
|
||||
@pytest.mark.parametrize("seed", [1])
|
||||
def test_v1_v2_greedy_equality_prefix_caching_enabled_with_preemption(
|
||||
def test_block_manager_prefix_caching_enabled_with_preemption(
|
||||
baseline_llm_generator, test_llm_generator, batch_size):
|
||||
"""Verify block manager v2 produces same outputs as block manager v1, even
|
||||
when there is preemption.
|
||||
"""Verify block manager produces same outputs even when there is preemption.
|
||||
|
||||
This constructs two LLM, each with limited number of GPU blocks. The limit
|
||||
is decided such that as the sequences in the batch grow, sequences must be
|
||||
preempted and removed from cache.
|
||||
|
||||
If the output token ids are equivalent, then we have confidence that the KV
|
||||
cache is not corrupted in the v2 block manager.
|
||||
cache is not corrupted.
|
||||
|
||||
NOTE: We want a significant number of generated tokens so that any incorrect
|
||||
KV mapping has time to build up error.
|
||||
|
||||
NOTE(Kuntai): Though we have removed block manager v1, this test is still
|
||||
useful as it asserts the behavior of block manager v2 (now it is called
|
||||
SelfAttnBlockSpaceManager) is the same when swapping / preemption, so we
|
||||
keep this test.
|
||||
"""
|
||||
output_len = 1024
|
||||
temperature = 0.0
|
||||
@ -339,11 +323,11 @@ def test_v1_v2_greedy_equality_prefix_caching_enabled_with_preemption(
|
||||
temperature=temperature,
|
||||
)
|
||||
|
||||
print('Getting token ids from block manager v1')
|
||||
print('Getting token ids from block manager')
|
||||
baseline_token_ids = get_token_ids_from_llm_generator(
|
||||
baseline_llm_generator, prompts, sampling_params)
|
||||
|
||||
print('Getting token ids from block manager v2')
|
||||
print('Getting token ids from block manager, with preemption')
|
||||
test_token_ids = get_token_ids_from_llm_generator(test_llm_generator,
|
||||
prompts, sampling_params)
|
||||
|
||||
@ -366,9 +350,6 @@ def test_v1_v2_greedy_equality_prefix_caching_enabled_with_preemption(
|
||||
# Allow only 5 sequences of ~1024 tokens in worst case.
|
||||
"block_size": 16,
|
||||
"num_gpu_blocks_override": 5 * (64 + 1),
|
||||
|
||||
# Test APC in v2 block
|
||||
"use_v2_block_manager": True,
|
||||
}])
|
||||
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{
|
||||
@ -444,9 +425,6 @@ def test_auto_prefix_caching_with_preemption(baseline_llm_generator,
|
||||
"max_model_len": 48,
|
||||
"block_size": 16,
|
||||
"num_gpu_blocks_override": 3,
|
||||
|
||||
# Test APC in v2 block
|
||||
"use_v2_block_manager": True,
|
||||
}])
|
||||
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{
|
||||
|
@ -3,7 +3,6 @@ from typing import List
|
||||
|
||||
import pytest
|
||||
|
||||
from tests.utils import check_deprecated_block_manager_usage
|
||||
from vllm import LLM, SamplingParams
|
||||
|
||||
from .conftest import get_text_from_llm_generator
|
||||
@ -13,12 +12,6 @@ MODEL = "bigcode/starcoder2-3b"
|
||||
BLOCK_SIZE = 16
|
||||
|
||||
|
||||
@pytest.fixture(scope="module", autouse=True)
|
||||
def check_deprecated_block_manager():
|
||||
check_deprecated_block_manager_usage(
|
||||
'tests/core/block/e2e/test_correctness_sliding_window.py')
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"common_llm_kwargs",
|
||||
[{
|
||||
@ -31,10 +24,8 @@ def check_deprecated_block_manager():
|
||||
"num_gpu_blocks_override": 100000 // BLOCK_SIZE,
|
||||
}])
|
||||
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{
|
||||
"use_v2_block_manager": False
|
||||
}])
|
||||
@pytest.mark.parametrize("test_llm_kwargs", [{"use_v2_block_manager": True}])
|
||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("test_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("batch_size", [5])
|
||||
@pytest.mark.parametrize("seed", [1])
|
||||
def test_sliding_window_retrival(baseline_llm_generator, test_llm_generator,
|
||||
@ -55,7 +46,6 @@ def test_sliding_window_retrival(baseline_llm_generator, test_llm_generator,
|
||||
|
||||
prompts, answer, indices = prep_prompts(batch_size)
|
||||
|
||||
print('Getting token ids from block manager v1')
|
||||
baseline_texts = get_text_from_llm_generator(baseline_llm_generator,
|
||||
prompts,
|
||||
sampling_params,
|
||||
@ -91,10 +81,7 @@ def test_sliding_window_retrival(baseline_llm_generator, test_llm_generator,
|
||||
"num_gpu_blocks_override": 100000 // BLOCK_SIZE,
|
||||
}])
|
||||
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("test_llm_kwargs", [{
|
||||
"use_v2_block_manager": True,
|
||||
"enable_chunked_prefill": True
|
||||
}])
|
||||
@pytest.mark.parametrize("test_llm_kwargs", [{"enable_chunked_prefill": True}])
|
||||
@pytest.mark.parametrize("batch_size", [5])
|
||||
@pytest.mark.parametrize("seed", [1])
|
||||
def test_sliding_window_chunked_prefill(test_llm_generator, batch_size, seed):
|
||||
|
@ -2,7 +2,7 @@ import pytest
|
||||
|
||||
from vllm.core.block.utils import (STR_NOT_IMPL_ENC_DEC_PREFIX_CACHE,
|
||||
STR_NOT_IMPL_ENC_DEC_SWA)
|
||||
from vllm.core.block_manager_v2 import BlockSpaceManagerV2
|
||||
from vllm.core.block_manager import SelfAttnBlockSpaceManager
|
||||
from vllm.core.interfaces import AllocStatus
|
||||
from vllm.sequence import Logprob, SequenceStatus
|
||||
from vllm.utils import chunk_list
|
||||
@ -17,7 +17,7 @@ from ..utils import (create_dummy_prompt, create_seq_group,
|
||||
@pytest.mark.parametrize("watermark", [0.0, 0.5])
|
||||
def test_can_allocate_seq_group(block_size: int, num_seqs_per_group: int,
|
||||
num_gpu_blocks: int, watermark: float):
|
||||
block_manager = BlockSpaceManagerV2(
|
||||
block_manager = SelfAttnBlockSpaceManager(
|
||||
block_size=block_size,
|
||||
num_gpu_blocks=num_gpu_blocks,
|
||||
num_cpu_blocks=1024,
|
||||
@ -63,7 +63,7 @@ def test_can_allocate_seq_group_encoder_decoder(block_size: int,
|
||||
num_seqs_per_group: int,
|
||||
num_gpu_blocks: int,
|
||||
watermark: float):
|
||||
block_manager = BlockSpaceManagerV2(
|
||||
block_manager = SelfAttnBlockSpaceManager(
|
||||
block_size=block_size,
|
||||
num_gpu_blocks=num_gpu_blocks,
|
||||
num_cpu_blocks=1024,
|
||||
@ -117,16 +117,16 @@ def test_can_allocate_encoder_decoder_fails_with_swa(block_size: int,
|
||||
'''
|
||||
SWA short for Sliding Window Attention.
|
||||
|
||||
At time of writing block manager v2 does not support SWA.
|
||||
At time of writing block manager does not support SWA.
|
||||
|
||||
However even when SWA is implemented for block manager v2,
|
||||
However even when SWA is implemented for block manager,
|
||||
there will still most likely be a separate workstream required
|
||||
to enable SWA for encoder/decoder models.
|
||||
|
||||
Therefore this test enforces that one of the following cases
|
||||
hold true:
|
||||
1. Block manager v2 does not support SWA at all (true at time of writing)
|
||||
2. Block manager v2 fails with NotImplementError when SWA is enabled
|
||||
1. Block manager does not support SWA at all (true at time of writing)
|
||||
2. Block manager fails with NotImplementError when SWA is enabled
|
||||
AND a SequenceGroup with an encoder sequence (i.e. in support of an
|
||||
encoder/decoder model) is passed into can_allocate() as an argument
|
||||
|
||||
@ -135,7 +135,7 @@ def test_can_allocate_encoder_decoder_fails_with_swa(block_size: int,
|
||||
'''
|
||||
|
||||
with pytest.raises((NotImplementedError, AssertionError)) as exc_info:
|
||||
block_manager = BlockSpaceManagerV2(
|
||||
block_manager = SelfAttnBlockSpaceManager(
|
||||
block_size=block_size,
|
||||
num_gpu_blocks=num_gpu_blocks,
|
||||
num_cpu_blocks=1024,
|
||||
@ -158,7 +158,7 @@ def test_can_allocate_encoder_decoder_fails_with_swa(block_size: int,
|
||||
block_manager.can_allocate(seq_group)
|
||||
|
||||
# Assert that either
|
||||
# 1. Block manager v2 constructor fails with assertion that sliding window
|
||||
# 1. Block manager constructor fails with assertion that sliding window
|
||||
# is not yet supported (most likely near-term outcome at time of
|
||||
# writing), or
|
||||
# 2. can_allocate() fails with NotImplementedError due to combination of
|
||||
@ -177,7 +177,7 @@ def test_can_allocate_encoder_decoder_fails_with_prefix_cache(
|
||||
block_size: int, num_seqs_per_group: int, num_gpu_blocks: int,
|
||||
watermark: float):
|
||||
|
||||
block_manager = BlockSpaceManagerV2(
|
||||
block_manager = SelfAttnBlockSpaceManager(
|
||||
block_size=block_size,
|
||||
num_gpu_blocks=num_gpu_blocks,
|
||||
num_cpu_blocks=1024,
|
||||
@ -217,7 +217,7 @@ def test_append_slots(block_size, prompt_len, num_slots_to_append,
|
||||
|
||||
num_gpu_blocks = 1024
|
||||
watermark = 0.1
|
||||
block_manager = BlockSpaceManagerV2(
|
||||
block_manager = SelfAttnBlockSpaceManager(
|
||||
block_size=block_size,
|
||||
num_gpu_blocks=num_gpu_blocks,
|
||||
num_cpu_blocks=0,
|
||||
@ -269,7 +269,7 @@ def test_swap(block_size, num_cpu_blocks, num_gpu_blocks, num_lookahead_slots,
|
||||
"""Verify blocks number on src/desc device is correct after swapping in/out
|
||||
sequence group (not missing or extra blocks).
|
||||
"""
|
||||
block_manager = BlockSpaceManagerV2(block_size,
|
||||
block_manager = SelfAttnBlockSpaceManager(block_size,
|
||||
num_cpu_blocks,
|
||||
num_gpu_blocks,
|
||||
watermark=0,
|
||||
@ -277,6 +277,7 @@ def test_swap(block_size, num_cpu_blocks, num_gpu_blocks, num_lookahead_slots,
|
||||
prompt, seq_group = create_dummy_prompt("1", prompt_length=block_size - 1)
|
||||
prompt.status = SequenceStatus.WAITING
|
||||
block_manager.allocate(seq_group)
|
||||
|
||||
# Emulate a forward pass by appending a single token.
|
||||
# The block manager then knows how many unprocessed
|
||||
# tokens will be written in the next forward pass.
|
||||
@ -321,7 +322,7 @@ def test_can_swap(block_size, num_gpu_blocks, num_lookahead_slots,
|
||||
can be swapped in/out.
|
||||
"""
|
||||
num_cpu_blocks = num_gpu_blocks
|
||||
block_manager = BlockSpaceManagerV2(block_size,
|
||||
block_manager = SelfAttnBlockSpaceManager(block_size,
|
||||
num_cpu_blocks,
|
||||
num_gpu_blocks,
|
||||
watermark=0,
|
||||
@ -382,7 +383,7 @@ def test_swap_in_infeasible(num_lookahead_slots, enable_caching):
|
||||
block_size = 8
|
||||
num_cpu_blocks = 1
|
||||
num_gpu_blocks = 1
|
||||
block_manager = BlockSpaceManagerV2(block_size,
|
||||
block_manager = SelfAttnBlockSpaceManager(block_size,
|
||||
num_cpu_blocks,
|
||||
num_gpu_blocks,
|
||||
watermark=0,
|
||||
@ -434,7 +435,7 @@ def test_sliding_window(block_size, prompt_len, num_slots_to_append,
|
||||
|
||||
num_gpu_blocks = 1024
|
||||
watermark = 0.1
|
||||
block_manager = BlockSpaceManagerV2(
|
||||
block_manager = SelfAttnBlockSpaceManager(
|
||||
block_size=block_size,
|
||||
num_gpu_blocks=num_gpu_blocks,
|
||||
num_cpu_blocks=0,
|
||||
@ -474,7 +475,7 @@ def test_sliding_window(block_size, prompt_len, num_slots_to_append,
|
||||
seq.data.update_num_computed_tokens(prompt_len)
|
||||
check_used(num_blocks(prompt_len))
|
||||
|
||||
# this is how we compute it in BlockSpaceManagerV2.__init__
|
||||
# this is how we compute it in SelfAttnBlockSpaceManager.__init__
|
||||
sliding_blocks = (sliding_window // block_size) + 2
|
||||
# plus one block for null block
|
||||
sliding_blocks += 1
|
@ -1,637 +0,0 @@
|
||||
import time
|
||||
from collections import defaultdict
|
||||
from typing import List
|
||||
|
||||
import pytest
|
||||
|
||||
from vllm import SamplingParams
|
||||
from vllm.block import PhysicalTokenBlock
|
||||
from vllm.core.block.utils import (STR_NOT_IMPL_ENC_DEC_PREFIX_CACHE,
|
||||
STR_NOT_IMPL_ENC_DEC_SWA)
|
||||
from vllm.core.block_manager_v1 import (BlockSpaceManagerV1,
|
||||
UncachedBlockAllocator)
|
||||
from vllm.core.interfaces import AllocStatus
|
||||
from vllm.sequence import Logprob, Sequence, SequenceGroup, SequenceStatus
|
||||
from vllm.utils import Device
|
||||
|
||||
from .utils import create_dummy_prompt, create_dummy_prompt_encoder_decoder
|
||||
|
||||
|
||||
def test_block_allocator_allocate():
|
||||
block_size = 4
|
||||
num_cpu_blocks = 4
|
||||
cpu_allocator = UncachedBlockAllocator(Device.CPU, block_size,
|
||||
num_cpu_blocks)
|
||||
|
||||
# Allocate all available cpu blocks.
|
||||
num_free = num_cpu_blocks
|
||||
assert cpu_allocator.get_num_free_blocks() == num_free
|
||||
for _ in range(num_cpu_blocks):
|
||||
block = cpu_allocator.allocate()
|
||||
num_free -= 1
|
||||
|
||||
assert block not in cpu_allocator.free_blocks
|
||||
assert cpu_allocator.get_num_free_blocks() == num_free
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
cpu_allocator.allocate()
|
||||
|
||||
|
||||
def test_block_allocator_free():
|
||||
block_size = 4
|
||||
num_cpu_blocks = 4
|
||||
cpu_allocator = UncachedBlockAllocator(Device.CPU, block_size,
|
||||
num_cpu_blocks)
|
||||
|
||||
# Allocate all available cpu blocks.
|
||||
blocks: List[PhysicalTokenBlock] = []
|
||||
for _ in range(num_cpu_blocks):
|
||||
block = cpu_allocator.allocate()
|
||||
blocks.append(block)
|
||||
assert block not in cpu_allocator.free_blocks
|
||||
|
||||
# Free all allocated cpu blocks.
|
||||
num_free = 0
|
||||
assert cpu_allocator.get_num_free_blocks() == num_free
|
||||
for block in blocks:
|
||||
cpu_allocator.free(block)
|
||||
num_free += 1
|
||||
assert block in cpu_allocator.free_blocks
|
||||
assert cpu_allocator.get_num_free_blocks() == num_free
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
cpu_allocator.free(block)
|
||||
|
||||
|
||||
def test_allocate():
|
||||
block_size = 4
|
||||
num_cpu_blocks = 4
|
||||
num_gpu_blocks = 4
|
||||
block_manager = BlockSpaceManagerV1(block_size,
|
||||
num_cpu_blocks,
|
||||
num_gpu_blocks,
|
||||
watermark=0)
|
||||
|
||||
# Allocate same sequence group to all available gpu blocks.
|
||||
for i in range(num_gpu_blocks):
|
||||
_, seq_group = create_dummy_prompt(str(i), block_size)
|
||||
assert block_manager.can_allocate(seq_group) == AllocStatus.OK
|
||||
block_manager.allocate(seq_group)
|
||||
assert block_manager.can_allocate(seq_group) != AllocStatus.OK
|
||||
|
||||
# Allocate same sequence group to all available gpu blocks.
|
||||
# Use watermark to reserve one gpu block.
|
||||
block_manager = BlockSpaceManagerV1(block_size,
|
||||
num_cpu_blocks,
|
||||
num_gpu_blocks,
|
||||
watermark=1 / num_gpu_blocks)
|
||||
for i in range(num_gpu_blocks - 1):
|
||||
_, seq_group = create_dummy_prompt(str(i), block_size)
|
||||
assert block_manager.can_allocate(seq_group) == AllocStatus.OK
|
||||
block_manager.allocate(seq_group)
|
||||
assert block_manager.can_allocate(seq_group) != AllocStatus.OK
|
||||
|
||||
|
||||
def test_allocate_encoder_decoder():
|
||||
block_size = 4
|
||||
num_cpu_blocks = 4
|
||||
num_gpu_blocks = 4
|
||||
block_req_per_seq_group = 2
|
||||
block_manager = BlockSpaceManagerV1(block_size,
|
||||
num_cpu_blocks,
|
||||
num_gpu_blocks,
|
||||
watermark=0)
|
||||
|
||||
# Allocate same sequence group to all available gpu blocks.
|
||||
for i in range(num_gpu_blocks // block_req_per_seq_group):
|
||||
_, _, seq_group = create_dummy_prompt_encoder_decoder(
|
||||
str(i),
|
||||
decoder_prompt_length=block_size,
|
||||
encoder_prompt_length=block_size)
|
||||
assert block_manager.can_allocate(seq_group) == AllocStatus.OK
|
||||
block_manager.allocate(seq_group)
|
||||
assert block_manager.can_allocate(seq_group) != AllocStatus.OK
|
||||
|
||||
# Allocate same sequence group to all available gpu blocks.
|
||||
# Use watermark to reserve one gpu block.
|
||||
block_manager = BlockSpaceManagerV1(block_size,
|
||||
num_cpu_blocks,
|
||||
num_gpu_blocks,
|
||||
watermark=1 / num_gpu_blocks)
|
||||
for i in range((num_gpu_blocks - 1) // block_req_per_seq_group):
|
||||
_, _, seq_group = create_dummy_prompt_encoder_decoder(
|
||||
str(i),
|
||||
decoder_prompt_length=block_size,
|
||||
encoder_prompt_length=block_size)
|
||||
assert block_manager.can_allocate(seq_group) == AllocStatus.OK
|
||||
block_manager.allocate(seq_group)
|
||||
assert block_manager.can_allocate(seq_group) != AllocStatus.OK
|
||||
|
||||
|
||||
def test_allocate_encoder_decoder_fails_with_swa():
|
||||
# SWA short for sliding window attention
|
||||
|
||||
block_size = 4
|
||||
num_cpu_blocks = 4
|
||||
num_gpu_blocks = 4
|
||||
block_manager = BlockSpaceManagerV1(block_size,
|
||||
num_cpu_blocks,
|
||||
num_gpu_blocks,
|
||||
watermark=0,
|
||||
sliding_window=5) # swa
|
||||
|
||||
# Allocate same sequence group to all available gpu blocks.
|
||||
_, _, seq_group = create_dummy_prompt_encoder_decoder(
|
||||
"0",
|
||||
decoder_prompt_length=block_size,
|
||||
encoder_prompt_length=block_size)
|
||||
|
||||
# Assert that can_allocate() fails due to SWA
|
||||
with pytest.raises(NotImplementedError) as exc_info:
|
||||
block_manager.can_allocate(seq_group)
|
||||
|
||||
assert str(exc_info.value) == STR_NOT_IMPL_ENC_DEC_SWA
|
||||
|
||||
# Assert that allocate() fails due to SWA
|
||||
with pytest.raises(NotImplementedError) as exc_info:
|
||||
block_manager.allocate(seq_group)
|
||||
|
||||
assert str(exc_info.value) == STR_NOT_IMPL_ENC_DEC_SWA
|
||||
|
||||
|
||||
def test_allocate_encoder_decoder_fails_with_prefix_caching():
|
||||
block_size = 4
|
||||
num_cpu_blocks = 4
|
||||
num_gpu_blocks = 4
|
||||
block_manager = BlockSpaceManagerV1(block_size,
|
||||
num_cpu_blocks,
|
||||
num_gpu_blocks,
|
||||
watermark=0,
|
||||
enable_caching=True) # Prefix cache
|
||||
|
||||
# Allocate same sequence group to all available gpu blocks.
|
||||
_, _, seq_group = create_dummy_prompt_encoder_decoder(
|
||||
"0",
|
||||
decoder_prompt_length=block_size,
|
||||
encoder_prompt_length=block_size)
|
||||
|
||||
# Assert that can_allocate() fails due to prefix caching
|
||||
with pytest.raises(NotImplementedError) as exc_info:
|
||||
block_manager.can_allocate(seq_group)
|
||||
|
||||
assert str(exc_info.value) == STR_NOT_IMPL_ENC_DEC_PREFIX_CACHE
|
||||
|
||||
# Assert that allocate() fails due to prefix caching
|
||||
with pytest.raises(NotImplementedError) as exc_info:
|
||||
block_manager.allocate(seq_group)
|
||||
|
||||
assert str(exc_info.value) == STR_NOT_IMPL_ENC_DEC_PREFIX_CACHE
|
||||
|
||||
|
||||
def test_append_slot_single_seq():
|
||||
block_size = 4
|
||||
num_cpu_blocks = 4
|
||||
num_gpu_blocks = 4
|
||||
block_manager = BlockSpaceManagerV1(block_size,
|
||||
num_cpu_blocks,
|
||||
num_gpu_blocks,
|
||||
watermark=0)
|
||||
|
||||
# Allocate single seq to gpu block.
|
||||
prompt, seq_group = create_dummy_prompt("1", block_size)
|
||||
block_manager.allocate(seq_group)
|
||||
|
||||
# Nothing to append. Sequence has no new logical blocks.
|
||||
assert block_manager.can_append_slots(seq_group)
|
||||
before_blocks = block_manager.get_num_free_gpu_blocks()
|
||||
assert not block_manager.append_slots(prompt)
|
||||
after_blocks = block_manager.get_num_free_gpu_blocks()
|
||||
assert before_blocks == after_blocks
|
||||
|
||||
# Add block_size number of new tokens and append slot.
|
||||
for i in range(block_size):
|
||||
token_id = i + 5
|
||||
prompt.append_token_id(token_id, {token_id: Logprob(0.0)})
|
||||
|
||||
assert block_manager.can_append_slots(seq_group)
|
||||
before_blocks = block_manager.get_num_free_gpu_blocks()
|
||||
assert not block_manager.append_slots(prompt)
|
||||
after_blocks = block_manager.get_num_free_gpu_blocks()
|
||||
assert before_blocks - after_blocks == 1
|
||||
|
||||
|
||||
def test_append_slot_cow():
|
||||
block_size = 4
|
||||
num_cpu_blocks = 4
|
||||
num_gpu_blocks = 4
|
||||
block_manager = BlockSpaceManagerV1(block_size=block_size,
|
||||
num_cpu_blocks=num_cpu_blocks,
|
||||
num_gpu_blocks=num_gpu_blocks,
|
||||
watermark=0)
|
||||
|
||||
# Allocate prompt to gpu block. There is one slot left in the block.
|
||||
prompt = Sequence(seq_id=1,
|
||||
inputs={
|
||||
"prompt": "one two three",
|
||||
"prompt_token_ids": [1, 2, 3],
|
||||
},
|
||||
block_size=block_size)
|
||||
|
||||
# Fork the sequence, such that a COW will be required when we append a new
|
||||
# token id.
|
||||
child = prompt.fork(new_seq_id=2)
|
||||
|
||||
# Allocate space for the sequence group.
|
||||
seq_group = SequenceGroup(request_id="1",
|
||||
seqs=[prompt, child],
|
||||
arrival_time=time.time(),
|
||||
sampling_params=SamplingParams())
|
||||
block_manager.allocate(seq_group)
|
||||
|
||||
# Fork and append a new token id. We expect a COW to be scheduled.
|
||||
token_id = 4
|
||||
child.append_token_id(token_id, {token_id: Logprob(0.0)})
|
||||
block_manager.fork(prompt, child)
|
||||
|
||||
assert block_manager.can_append_slots(seq_group)
|
||||
before_blocks = block_manager.get_num_free_gpu_blocks()
|
||||
|
||||
cows = block_manager.append_slots(child)
|
||||
assert cows
|
||||
dict_cows = defaultdict(list)
|
||||
for src_block, dst_block in cows:
|
||||
dict_cows[src_block].append(dst_block)
|
||||
for src_block, dst_blocks in dict_cows.items():
|
||||
assert src_block not in dst_blocks
|
||||
|
||||
after_blocks = block_manager.get_num_free_gpu_blocks()
|
||||
assert before_blocks - after_blocks == 1
|
||||
|
||||
|
||||
def test_fork():
|
||||
block_size = 4
|
||||
num_cpu_blocks = 4
|
||||
num_gpu_blocks = 4
|
||||
block_manager = BlockSpaceManagerV1(block_size,
|
||||
num_cpu_blocks,
|
||||
num_gpu_blocks,
|
||||
watermark=0)
|
||||
|
||||
prompt, seq_group = create_dummy_prompt("1",
|
||||
block_size - 1,
|
||||
block_size=block_size)
|
||||
block_manager.allocate(seq_group)
|
||||
|
||||
# Fork prompt and copy block tables.
|
||||
child = prompt.fork(2)
|
||||
block_manager.fork(prompt, child)
|
||||
assert block_manager.get_block_table(
|
||||
prompt) == block_manager.get_block_table(child)
|
||||
token_id = 4
|
||||
# Append token to child. Block is shared so copy on write occurs.
|
||||
child.append_token_id(token_id, {token_id: Logprob(0.0)})
|
||||
block_manager.append_slots(child)
|
||||
assert block_manager.get_block_table(
|
||||
prompt) != block_manager.get_block_table(child)
|
||||
|
||||
|
||||
def test_swap():
|
||||
block_size = 4
|
||||
num_cpu_blocks = 4
|
||||
num_gpu_blocks = 4
|
||||
block_manager = BlockSpaceManagerV1(block_size,
|
||||
num_cpu_blocks,
|
||||
num_gpu_blocks,
|
||||
watermark=0)
|
||||
|
||||
prompt, seq_group = create_dummy_prompt("1", prompt_length=block_size - 1)
|
||||
prompt.status = SequenceStatus.WAITING
|
||||
block_manager.allocate(seq_group)
|
||||
|
||||
# Emulate a forward pass by appending a single token.
|
||||
# The block manager then knows how many unprocessed
|
||||
# tokens will be written in the next forward pass.
|
||||
token_id = 0
|
||||
prompt.status = SequenceStatus.RUNNING
|
||||
prompt.append_token_id(token_id, {token_id: Logprob(0.0)})
|
||||
|
||||
# Swap seq group from GPU -> CPU.
|
||||
gpu_blocks = block_manager.get_block_table(prompt)
|
||||
assert block_manager.can_swap_out(seq_group)
|
||||
before_cpu_blocks = block_manager.get_num_free_cpu_blocks()
|
||||
before_gpu_blocks = block_manager.get_num_free_gpu_blocks()
|
||||
mapping = block_manager.swap_out(seq_group)
|
||||
assert [x[0] for x in mapping] == gpu_blocks
|
||||
after_cpu_blocks = block_manager.get_num_free_cpu_blocks()
|
||||
after_gpu_blocks = block_manager.get_num_free_gpu_blocks()
|
||||
assert before_cpu_blocks == after_cpu_blocks + len(gpu_blocks)
|
||||
assert before_gpu_blocks + len(gpu_blocks) == after_gpu_blocks
|
||||
prompt.status = SequenceStatus.SWAPPED
|
||||
|
||||
# Swap seq group from CPU -> GPU.
|
||||
cpu_blocks = block_manager.get_block_table(prompt)
|
||||
assert block_manager.can_swap_in(seq_group) == AllocStatus.OK
|
||||
before_cpu_blocks = block_manager.get_num_free_cpu_blocks()
|
||||
before_gpu_blocks = block_manager.get_num_free_gpu_blocks()
|
||||
mapping = block_manager.swap_in(seq_group)
|
||||
assert [x[0] for x in mapping] == cpu_blocks
|
||||
after_cpu_blocks = block_manager.get_num_free_cpu_blocks()
|
||||
after_gpu_blocks = block_manager.get_num_free_gpu_blocks()
|
||||
assert before_cpu_blocks + len(cpu_blocks) == after_cpu_blocks
|
||||
assert before_gpu_blocks == after_gpu_blocks + len(cpu_blocks)
|
||||
|
||||
|
||||
def test_swap_encoder_decoder():
|
||||
block_size = 4
|
||||
num_cpu_blocks = 4
|
||||
num_gpu_blocks = 4
|
||||
block_manager = BlockSpaceManagerV1(block_size,
|
||||
num_cpu_blocks,
|
||||
num_gpu_blocks,
|
||||
watermark=0)
|
||||
|
||||
decoder_prompt, encoder_prompt, seq_group = \
|
||||
create_dummy_prompt_encoder_decoder(
|
||||
"1",
|
||||
decoder_prompt_length=block_size,
|
||||
encoder_prompt_length=block_size)
|
||||
decoder_prompt.status = SequenceStatus.WAITING
|
||||
encoder_prompt.status = SequenceStatus.WAITING
|
||||
block_manager.allocate(seq_group)
|
||||
|
||||
# Emulate a forward pass by appending a single token.
|
||||
# The block manager then knows how many unprocessed
|
||||
# tokens will be written in the next forward pass.
|
||||
token_id = 0
|
||||
decoder_prompt.status = SequenceStatus.RUNNING
|
||||
decoder_prompt.append_token_id(token_id, {token_id: Logprob(0.0)})
|
||||
|
||||
# Swap encoder/decoder seq group from GPU -> CPU.
|
||||
decoder_gpu_blocks = block_manager.get_block_table(decoder_prompt)
|
||||
cross_gpu_blocks = block_manager.get_cross_block_table(seq_group)
|
||||
gpu_blocks = decoder_gpu_blocks + cross_gpu_blocks
|
||||
assert block_manager.can_swap_out(seq_group)
|
||||
before_cpu_blocks = block_manager.get_num_free_cpu_blocks()
|
||||
before_gpu_blocks = block_manager.get_num_free_gpu_blocks()
|
||||
mapping = block_manager.swap_out(seq_group)
|
||||
assert [x[0] for x in mapping] == gpu_blocks
|
||||
#assert list(mapping.keys()) == gpu_blocks
|
||||
after_cpu_blocks = block_manager.get_num_free_cpu_blocks()
|
||||
after_gpu_blocks = block_manager.get_num_free_gpu_blocks()
|
||||
assert before_cpu_blocks == after_cpu_blocks + len(gpu_blocks)
|
||||
assert before_gpu_blocks + len(gpu_blocks) == after_gpu_blocks
|
||||
decoder_prompt.status = SequenceStatus.SWAPPED
|
||||
|
||||
# Swap encoder/decoder seq group from CPU -> GPU.
|
||||
decoder_cpu_blocks = block_manager.get_block_table(decoder_prompt)
|
||||
cross_cpu_blocks = block_manager.get_cross_block_table(seq_group)
|
||||
cpu_blocks = decoder_cpu_blocks + cross_cpu_blocks
|
||||
assert block_manager.can_swap_in(seq_group) == AllocStatus.OK
|
||||
before_cpu_blocks = block_manager.get_num_free_cpu_blocks()
|
||||
before_gpu_blocks = block_manager.get_num_free_gpu_blocks()
|
||||
mapping = block_manager.swap_in(seq_group)
|
||||
assert [x[0] for x in mapping] == cpu_blocks
|
||||
after_cpu_blocks = block_manager.get_num_free_cpu_blocks()
|
||||
after_gpu_blocks = block_manager.get_num_free_gpu_blocks()
|
||||
assert before_cpu_blocks + len(cpu_blocks) == after_cpu_blocks
|
||||
assert before_gpu_blocks == after_gpu_blocks + len(cpu_blocks)
|
||||
|
||||
|
||||
def test_free():
|
||||
block_size = 4
|
||||
num_cpu_blocks = 4
|
||||
num_gpu_blocks = 4
|
||||
block_manager = BlockSpaceManagerV1(block_size,
|
||||
num_cpu_blocks,
|
||||
num_gpu_blocks,
|
||||
watermark=0)
|
||||
|
||||
prompt, seq_group = create_dummy_prompt("1", block_size)
|
||||
block_manager.allocate(seq_group)
|
||||
|
||||
# Free allocated seq.
|
||||
prompt_blocks = len(block_manager.get_block_table(prompt))
|
||||
before_blocks = block_manager.get_num_free_gpu_blocks()
|
||||
block_manager.free(prompt)
|
||||
after_blocks = block_manager.get_num_free_gpu_blocks()
|
||||
assert after_blocks == before_blocks + prompt_blocks
|
||||
|
||||
# Block table for freed seq is deleted.
|
||||
with pytest.raises(KeyError):
|
||||
block_manager.get_block_table(prompt)
|
||||
|
||||
|
||||
def test_free_encoder_decoder():
|
||||
block_size = 4
|
||||
num_cpu_blocks = 4
|
||||
num_gpu_blocks = 4
|
||||
block_manager = BlockSpaceManagerV1(block_size,
|
||||
num_cpu_blocks,
|
||||
num_gpu_blocks,
|
||||
watermark=0)
|
||||
|
||||
decoder_prompt, encoder_prompt, seq_group = \
|
||||
create_dummy_prompt_encoder_decoder(
|
||||
"1",
|
||||
decoder_prompt_length=block_size,
|
||||
encoder_prompt_length=block_size)
|
||||
block_manager.allocate(seq_group)
|
||||
|
||||
# Free allocated seq.
|
||||
decoder_prompt_blocks = len(block_manager.get_block_table(decoder_prompt))
|
||||
encoder_prompt_blocks = len(block_manager.get_cross_block_table(seq_group))
|
||||
prompt_blocks = decoder_prompt_blocks + encoder_prompt_blocks
|
||||
before_blocks = block_manager.get_num_free_gpu_blocks()
|
||||
block_manager.free(decoder_prompt)
|
||||
block_manager.free_cross(seq_group)
|
||||
after_blocks = block_manager.get_num_free_gpu_blocks()
|
||||
assert after_blocks == before_blocks + prompt_blocks
|
||||
|
||||
# Block table for freed encoder & decoder seq's are deleted.
|
||||
with pytest.raises(KeyError):
|
||||
block_manager.get_block_table(decoder_prompt)
|
||||
|
||||
# Block table for freed encoder & decoder seq's are deleted.
|
||||
with pytest.raises(KeyError):
|
||||
block_manager.get_block_table(encoder_prompt)
|
||||
|
||||
|
||||
def test_reset():
|
||||
block_size = 4
|
||||
num_cpu_blocks = 4
|
||||
num_gpu_blocks = 4
|
||||
block_manager = BlockSpaceManagerV1(block_size,
|
||||
num_cpu_blocks,
|
||||
num_gpu_blocks,
|
||||
watermark=0)
|
||||
|
||||
# Allocate same seq group on all available gpu blocks.
|
||||
original_blocks = block_manager.get_num_free_gpu_blocks()
|
||||
for i in range(num_gpu_blocks):
|
||||
_, seq_group = create_dummy_prompt(str(i), block_size)
|
||||
block_manager.allocate(seq_group)
|
||||
assert block_manager.get_num_free_gpu_blocks() == 0
|
||||
|
||||
# Resetting block manager frees all allocated blocks.
|
||||
block_manager.reset()
|
||||
assert block_manager.get_num_free_gpu_blocks() == original_blocks
|
||||
|
||||
|
||||
def test_reset_encoder_decoder():
|
||||
block_size = 4
|
||||
num_cpu_blocks = 4
|
||||
num_gpu_blocks = 4
|
||||
block_req_per_seq_group = 2
|
||||
block_manager = BlockSpaceManagerV1(block_size,
|
||||
num_cpu_blocks,
|
||||
num_gpu_blocks,
|
||||
watermark=0)
|
||||
|
||||
# Allocate same seq group on all available gpu blocks.
|
||||
original_blocks = block_manager.get_num_free_gpu_blocks()
|
||||
for i in range(num_gpu_blocks // block_req_per_seq_group):
|
||||
_, _, seq_group = create_dummy_prompt_encoder_decoder(
|
||||
f"{i}",
|
||||
decoder_prompt_length=block_size,
|
||||
encoder_prompt_length=block_size)
|
||||
block_manager.allocate(seq_group)
|
||||
assert block_manager.get_num_free_gpu_blocks() == 0
|
||||
|
||||
# Resetting block manager frees all allocated blocks.
|
||||
block_manager.reset()
|
||||
assert block_manager.get_num_free_gpu_blocks() == original_blocks
|
||||
|
||||
|
||||
def test_sliding_window_multi_seq():
|
||||
"""
|
||||
Tests that memory allocation and deallocation is handled
|
||||
correctly with multiple sequences that exceed the sliding
|
||||
window's capacity.
|
||||
"""
|
||||
block_size = 1
|
||||
num_cpu_blocks = 8
|
||||
num_gpu_blocks = 8
|
||||
sliding_window = 2
|
||||
block_manager = BlockSpaceManagerV1(block_size,
|
||||
num_cpu_blocks,
|
||||
num_gpu_blocks,
|
||||
sliding_window=sliding_window,
|
||||
watermark=0)
|
||||
|
||||
assert block_manager.get_num_free_gpu_blocks() == num_gpu_blocks
|
||||
|
||||
parent = Sequence(seq_id=1,
|
||||
inputs={
|
||||
"prompt": "one two three",
|
||||
"prompt_token_ids": [0, 1, 2],
|
||||
},
|
||||
block_size=block_size)
|
||||
seq_group = SequenceGroup(request_id="1",
|
||||
seqs=[parent],
|
||||
arrival_time=time.time(),
|
||||
sampling_params=SamplingParams(),
|
||||
lora_request=None)
|
||||
block_manager.allocate(seq_group)
|
||||
|
||||
# assert the number of blocks allocated is correct
|
||||
# the parent seq has len 3, but since sliding_window is 2,
|
||||
# we will use at most 2 blocks
|
||||
assert block_manager.get_num_free_gpu_blocks(
|
||||
) == num_gpu_blocks - sliding_window
|
||||
|
||||
# Fork prompt and copy block tables.
|
||||
child = parent.fork(2)
|
||||
block_manager.fork(parent, child)
|
||||
|
||||
# assert the number of blocks allocated is correct
|
||||
# forking does not increase memory consumption
|
||||
assert block_manager.get_num_free_gpu_blocks(
|
||||
) == num_gpu_blocks - sliding_window
|
||||
|
||||
# assert both parent and child share all blocks
|
||||
assert block_manager.get_block_table(
|
||||
parent) == block_manager.get_block_table(child)
|
||||
|
||||
token_id = 4
|
||||
# Append token to child. Block is shared so copy on write occurs.
|
||||
child.append_token_id(token_id, {token_id: Logprob(0.0)})
|
||||
block_manager.append_slots(child)
|
||||
|
||||
# assert the number of blocks allocated is correct
|
||||
# we will use now one block more. Each seq will use 2 blocks,
|
||||
# but only one can be shared
|
||||
assert block_manager.get_num_free_gpu_blocks(
|
||||
) == num_gpu_blocks - sliding_window - 1
|
||||
|
||||
token_id = 5
|
||||
parent.append_token_id(token_id, {token_id: Logprob(0.0)})
|
||||
block_manager.append_slots(parent)
|
||||
|
||||
# assert the number of blocks allocated is correct
|
||||
# no change, because both sequences are still just sharing one block
|
||||
assert block_manager.get_num_free_gpu_blocks(
|
||||
) == num_gpu_blocks - sliding_window - 1
|
||||
|
||||
block_table_parent = block_manager.get_block_table(parent)
|
||||
block_table_child = block_manager.get_block_table(child)
|
||||
|
||||
assert block_table_parent != block_table_child
|
||||
|
||||
# assert both blocks are sharing the second-last block
|
||||
assert block_table_parent[-2] == block_table_child[-2]
|
||||
|
||||
# now let's clean up...
|
||||
block_manager.free(parent)
|
||||
|
||||
# assert the number of blocks allocated is correct
|
||||
# We have freed one seq, reducing the ref count of two blocks by one.
|
||||
# One of the two was only used by the parent seq, so this is now free.
|
||||
# The child seq still consumes sliding_window blocks
|
||||
assert block_manager.get_num_free_gpu_blocks(
|
||||
) == num_gpu_blocks - sliding_window
|
||||
|
||||
# free all blocks
|
||||
block_manager.free(child)
|
||||
|
||||
# assert all blocks are free now
|
||||
assert block_manager.get_num_free_gpu_blocks() == num_gpu_blocks
|
||||
|
||||
|
||||
def test_mark_blocks_as_computed_with_prefix_cache_and_chunked_prefill():
|
||||
"""When prefix cache and chunked prefill are enabled, the block manager
|
||||
should only mark a chunk of blocks as computed instead of all blocks.
|
||||
"""
|
||||
|
||||
block_size = 4
|
||||
num_cpu_blocks = 0
|
||||
num_gpu_blocks = 16
|
||||
block_manager = BlockSpaceManagerV1(block_size,
|
||||
num_gpu_blocks,
|
||||
num_cpu_blocks,
|
||||
watermark=0,
|
||||
enable_caching=True)
|
||||
|
||||
# Set prompt size to have num_gpu_blocks - 1 full blocks.
|
||||
prompt_length = block_size * num_gpu_blocks - 1
|
||||
|
||||
# Allocate (reserve) all blocks.
|
||||
_, seq_group = create_dummy_prompt("0",
|
||||
prompt_length,
|
||||
block_size=block_size)
|
||||
block_manager.allocate(seq_group)
|
||||
assert seq_group.seqs[0].n_blocks == num_gpu_blocks
|
||||
|
||||
# 1st chunk: Compute 2 and half blocks. Should mark 2 blocks as computed.
|
||||
token_chunk_size = int(block_size * 2.5)
|
||||
block_manager.mark_blocks_as_computed(seq_group, token_chunk_size)
|
||||
computed_blocks = block_manager.get_all_computed_blocks(seq_group.seqs[0])
|
||||
assert len(computed_blocks) == 2
|
||||
|
||||
# Actual computed tokens.
|
||||
seq_group.seqs[0].data.update_num_computed_tokens(token_chunk_size)
|
||||
|
||||
# 2nd chunk: Complete 3rd block and additional 4 blocks.
|
||||
token_chunk_size = int(block_size * 4.5)
|
||||
block_manager.mark_blocks_as_computed(seq_group, token_chunk_size)
|
||||
computed_blocks = block_manager.get_all_computed_blocks(seq_group.seqs[0])
|
||||
assert len(computed_blocks) == 7
|
@ -8,7 +8,6 @@ from vllm.core.interfaces import AllocStatus
|
||||
from vllm.core.scheduler import Scheduler
|
||||
from vllm.sequence import Logprob, SequenceGroup
|
||||
|
||||
from ..utils import check_deprecated_block_manager_usage
|
||||
from .utils import create_dummy_prompt
|
||||
|
||||
|
||||
@ -28,25 +27,16 @@ def schedule_and_update_computed_tokens(scheduler):
|
||||
return metas, out
|
||||
|
||||
|
||||
@pytest.fixture(scope="module", autouse=True)
|
||||
def check_deprecated_block_manager():
|
||||
check_deprecated_block_manager_usage(
|
||||
'tests/core/test_chunked_prefill_scheduler.py')
|
||||
|
||||
|
||||
@pytest.mark.parametrize('use_v2_block_manager', [True, False])
|
||||
def test_simple(use_v2_block_manager: bool):
|
||||
def test_simple():
|
||||
"""Verify basic scheduling works."""
|
||||
block_size = 4
|
||||
num_seq_group = 4
|
||||
max_model_len = 16
|
||||
max_num_batched_tokens = 64
|
||||
scheduler_config = SchedulerConfig(
|
||||
max_num_batched_tokens,
|
||||
scheduler_config = SchedulerConfig(max_num_batched_tokens,
|
||||
num_seq_group,
|
||||
max_model_len,
|
||||
enable_chunked_prefill=True,
|
||||
use_v2_block_manager=use_v2_block_manager)
|
||||
enable_chunked_prefill=True)
|
||||
cache_config = CacheConfig(block_size, 1.0, 1, "auto")
|
||||
cache_config.num_cpu_blocks = 8
|
||||
cache_config.num_gpu_blocks = 8
|
||||
@ -81,8 +71,7 @@ def test_simple(use_v2_block_manager: bool):
|
||||
assert len(seq_group_meta) == num_seq_group
|
||||
|
||||
|
||||
@pytest.mark.parametrize('use_v2_block_manager', [True, False])
|
||||
def test_chunk(use_v2_block_manager: bool):
|
||||
def test_chunk():
|
||||
"""Verify prefills are chunked properly."""
|
||||
block_size = 4
|
||||
max_seqs = 60
|
||||
@ -93,7 +82,7 @@ def test_chunk(use_v2_block_manager: bool):
|
||||
max_seqs,
|
||||
max_model_len,
|
||||
enable_chunked_prefill=True,
|
||||
use_v2_block_manager=use_v2_block_manager)
|
||||
)
|
||||
cache_config = CacheConfig(block_size, 1.0, 1, "auto")
|
||||
cache_config.num_cpu_blocks = 32
|
||||
cache_config.num_gpu_blocks = 32
|
||||
@ -131,8 +120,7 @@ def test_chunk(use_v2_block_manager: bool):
|
||||
assert out.num_batched_tokens == 57
|
||||
|
||||
|
||||
@pytest.mark.parametrize('use_v2_block_manager', [True, False])
|
||||
def test_complex(use_v2_block_manager: bool):
|
||||
def test_complex():
|
||||
block_size = 4
|
||||
max_seqs = 60
|
||||
max_model_len = 80
|
||||
@ -142,7 +130,7 @@ def test_complex(use_v2_block_manager: bool):
|
||||
max_seqs,
|
||||
max_model_len,
|
||||
enable_chunked_prefill=True,
|
||||
use_v2_block_manager=use_v2_block_manager)
|
||||
)
|
||||
cache_config = CacheConfig(block_size, 1.0, 1, "auto")
|
||||
cache_config.num_cpu_blocks = 64
|
||||
cache_config.num_gpu_blocks = 64
|
||||
@ -201,8 +189,7 @@ def test_complex(use_v2_block_manager: bool):
|
||||
assert running[2].is_prefill()
|
||||
|
||||
|
||||
@pytest.mark.parametrize('use_v2_block_manager', [True, False])
|
||||
def test_maximal_decoding(use_v2_block_manager: bool):
|
||||
def test_maximal_decoding():
|
||||
"""Verify decoding requests are prioritized."""
|
||||
block_size = 4
|
||||
max_seqs = 2
|
||||
@ -213,7 +200,7 @@ def test_maximal_decoding(use_v2_block_manager: bool):
|
||||
max_seqs,
|
||||
max_model_len,
|
||||
enable_chunked_prefill=True,
|
||||
use_v2_block_manager=use_v2_block_manager)
|
||||
)
|
||||
cache_config = CacheConfig(block_size, 1.0, 1, "auto")
|
||||
cache_config.num_cpu_blocks = 8
|
||||
cache_config.num_gpu_blocks = 8
|
||||
@ -295,8 +282,7 @@ def test_maximal_decoding(use_v2_block_manager: bool):
|
||||
assert out.num_batched_tokens == 2
|
||||
|
||||
|
||||
@pytest.mark.parametrize('use_v2_block_manager', [True, False])
|
||||
def test_prompt_limit(use_v2_block_manager: bool):
|
||||
def test_prompt_limit():
|
||||
"""Verify max_num_batched_tokens < max_model_len is possible."""
|
||||
block_size = 4
|
||||
max_seqs = 32
|
||||
@ -307,7 +293,7 @@ def test_prompt_limit(use_v2_block_manager: bool):
|
||||
max_seqs,
|
||||
max_model_len,
|
||||
enable_chunked_prefill=True,
|
||||
use_v2_block_manager=use_v2_block_manager)
|
||||
)
|
||||
cache_config = CacheConfig(block_size, 1.0, 1, "auto")
|
||||
cache_config.num_cpu_blocks = 16
|
||||
cache_config.num_gpu_blocks = 16
|
||||
@ -330,8 +316,7 @@ def test_prompt_limit(use_v2_block_manager: bool):
|
||||
assert out.num_batched_tokens == 32
|
||||
|
||||
|
||||
@pytest.mark.parametrize('use_v2_block_manager', [True, False])
|
||||
def test_prompt_limit_exceed(use_v2_block_manager: bool):
|
||||
def test_prompt_limit_exceed():
|
||||
block_size = 4
|
||||
max_seqs = 64
|
||||
max_model_len = 32
|
||||
@ -356,8 +341,7 @@ def test_prompt_limit_exceed(use_v2_block_manager: bool):
|
||||
assert out.ignored_seq_groups[0] == seq_group
|
||||
|
||||
|
||||
@pytest.mark.parametrize('use_v2_block_manager', [True, False])
|
||||
def test_swap(use_v2_block_manager: bool):
|
||||
def test_swap():
|
||||
"""Verify swapping works with chunked prefill requests"""
|
||||
block_size = 4
|
||||
max_seqs = 30
|
||||
@ -368,7 +352,7 @@ def test_swap(use_v2_block_manager: bool):
|
||||
max_seqs,
|
||||
max_model_len,
|
||||
enable_chunked_prefill=True,
|
||||
use_v2_block_manager=use_v2_block_manager)
|
||||
)
|
||||
cache_config = CacheConfig(block_size, 1.0, 1, "auto")
|
||||
cache_config.num_cpu_blocks = 16
|
||||
cache_config.num_gpu_blocks = 16
|
||||
@ -414,8 +398,7 @@ def test_swap(use_v2_block_manager: bool):
|
||||
assert out.blocks_to_swap_out == []
|
||||
|
||||
|
||||
@pytest.mark.parametrize('use_v2_block_manager', [True, False])
|
||||
def test_running_prefill_prioritized_over_swap(use_v2_block_manager: bool):
|
||||
def test_running_prefill_prioritized_over_swap():
|
||||
block_size = 4
|
||||
max_seqs = 30
|
||||
max_model_len = 200
|
||||
@ -425,7 +408,7 @@ def test_running_prefill_prioritized_over_swap(use_v2_block_manager: bool):
|
||||
max_seqs,
|
||||
max_model_len,
|
||||
enable_chunked_prefill=True,
|
||||
use_v2_block_manager=use_v2_block_manager)
|
||||
)
|
||||
cache_config = CacheConfig(block_size, 1.0, 1, "auto")
|
||||
cache_config.num_cpu_blocks = 32
|
||||
cache_config.num_gpu_blocks = 32
|
||||
@ -508,8 +491,7 @@ def test_running_prefill_prioritized_over_swap(use_v2_block_manager: bool):
|
||||
assert out.blocks_to_swap_out == []
|
||||
|
||||
|
||||
@pytest.mark.parametrize('use_v2_block_manager', [True, False])
|
||||
def test_chunked_prefill_preempt(use_v2_block_manager: bool):
|
||||
def test_chunked_prefill_preempt():
|
||||
"""Verify preempt works with chunked prefill requests"""
|
||||
block_size = 4
|
||||
max_seqs = 30
|
||||
@ -520,7 +502,7 @@ def test_chunked_prefill_preempt(use_v2_block_manager: bool):
|
||||
max_seqs,
|
||||
max_model_len,
|
||||
enable_chunked_prefill=True,
|
||||
use_v2_block_manager=use_v2_block_manager)
|
||||
)
|
||||
cache_config = CacheConfig(block_size, 1.0, 1, "auto")
|
||||
cache_config.num_cpu_blocks = 16
|
||||
cache_config.num_gpu_blocks = 16
|
||||
@ -575,8 +557,7 @@ def test_chunked_prefill_preempt(use_v2_block_manager: bool):
|
||||
assert out.num_batched_tokens == max_num_batched_tokens
|
||||
|
||||
|
||||
@pytest.mark.parametrize('use_v2_block_manager', [True, False])
|
||||
def test_chunked_prefill_max_seqs(use_v2_block_manager: bool):
|
||||
def test_chunked_prefill_max_seqs():
|
||||
block_size = 4
|
||||
max_seqs = 2
|
||||
max_model_len = 80
|
||||
@ -586,7 +567,7 @@ def test_chunked_prefill_max_seqs(use_v2_block_manager: bool):
|
||||
max_seqs,
|
||||
max_model_len,
|
||||
enable_chunked_prefill=True,
|
||||
use_v2_block_manager=use_v2_block_manager)
|
||||
)
|
||||
cache_config = CacheConfig(block_size, 1.0, 1, "auto")
|
||||
cache_config.num_cpu_blocks = 128
|
||||
cache_config.num_gpu_blocks = 128
|
||||
@ -629,8 +610,7 @@ def test_chunked_prefill_max_seqs(use_v2_block_manager: bool):
|
||||
assert not running[1].is_prefill()
|
||||
|
||||
|
||||
@pytest.mark.parametrize('use_v2_block_manager', [True, False])
|
||||
def test_perfix_caching(use_v2_block_manager: bool):
|
||||
def test_perfix_caching():
|
||||
"""Verify allocating full blocks when prefix caching is enabled."""
|
||||
block_size = 4
|
||||
max_seqs = 10
|
||||
@ -641,7 +621,7 @@ def test_perfix_caching(use_v2_block_manager: bool):
|
||||
max_seqs,
|
||||
max_model_len,
|
||||
enable_chunked_prefill=True,
|
||||
use_v2_block_manager=use_v2_block_manager)
|
||||
)
|
||||
cache_config = CacheConfig(block_size,
|
||||
1.0,
|
||||
1,
|
||||
@ -31,7 +31,6 @@ def test_num_computed_tokens_update(num_scheduler_steps: int,
# Make a vllm engine
runner = VllmRunner(model_name=MODEL,
gpu_memory_utilization=0.7,
use_v2_block_manager=True,
num_scheduler_steps=num_scheduler_steps,
enable_chunked_prefill=enable_chunked_prefill,
enforce_eager=enforce_eager)
@ -3,7 +3,7 @@ from collections import deque
|
||||
from typing import List, Set, Tuple
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
import pytest # noqa
|
||||
from torch import Use # noqa
|
||||
|
||||
from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig
|
||||
@ -12,23 +12,18 @@ from vllm.core.scheduler import Scheduler, SchedulingBudget
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.sequence import SequenceGroup, SequenceStatus
|
||||
|
||||
from ..utils import check_deprecated_block_manager_usage
|
||||
from .utils import (append_new_token, append_new_token_seq_group,
|
||||
create_dummy_prompt, get_sequence_groups,
|
||||
schedule_and_update_computed_tokens)
|
||||
|
||||
|
||||
@pytest.fixture(scope="module", autouse=True)
|
||||
def check_deprecated_block_manager():
|
||||
check_deprecated_block_manager_usage(
|
||||
"tests/core/test_chunked_prefill_scheduler.py")
|
||||
|
||||
|
||||
@pytest.mark.parametrize('use_v2_block_manager', [True, False])
|
||||
def test_scheduler_add_seq_group(use_v2_block_manager: bool):
|
||||
def test_scheduler_add_seq_group():
|
||||
block_size = 4
|
||||
scheduler_config = SchedulerConfig(
|
||||
100, 64, 1, use_v2_block_manager=use_v2_block_manager)
|
||||
100,
|
||||
64,
|
||||
1,
|
||||
)
|
||||
cache_config = CacheConfig(block_size, 1.0, 1, cache_dtype="auto")
|
||||
cache_config.num_cpu_blocks = 4
|
||||
cache_config.num_gpu_blocks = 4
|
||||
@ -44,11 +39,13 @@ def test_scheduler_add_seq_group(use_v2_block_manager: bool):
|
||||
assert scheduler.get_num_unfinished_seq_groups() == i + 1
|
||||
|
||||
|
||||
@pytest.mark.parametrize('use_v2_block_manager', [True, False])
|
||||
def test_scheduler_abort_seq_group(use_v2_block_manager: bool):
|
||||
def test_scheduler_abort_seq_group():
|
||||
block_size = 4
|
||||
scheduler_config = SchedulerConfig(
|
||||
100, 64, 1, use_v2_block_manager=use_v2_block_manager)
|
||||
100,
|
||||
64,
|
||||
1,
|
||||
)
|
||||
cache_config = CacheConfig(block_size, 1.0, 1, "auto")
|
||||
cache_config.num_cpu_blocks = 4
|
||||
cache_config.num_gpu_blocks = 4
|
||||
@ -68,8 +65,7 @@ def test_scheduler_abort_seq_group(use_v2_block_manager: bool):
|
||||
assert scheduler.get_num_unfinished_seq_groups() == 0
|
||||
|
||||
|
||||
@pytest.mark.parametrize('use_v2_block_manager', [True, False])
|
||||
def test_scheduler_schedule_simple(use_v2_block_manager: bool):
|
||||
def test_scheduler_schedule_simple():
|
||||
block_size = 4
|
||||
num_seq_group = 4
|
||||
max_model_len = 16
|
||||
@ -77,7 +73,7 @@ def test_scheduler_schedule_simple(use_v2_block_manager: bool):
|
||||
64,
|
||||
num_seq_group,
|
||||
max_model_len,
|
||||
use_v2_block_manager=use_v2_block_manager)
|
||||
)
|
||||
cache_config = CacheConfig(block_size, 1.0, 1, "auto")
|
||||
cache_config.num_cpu_blocks = 8
|
||||
cache_config.num_gpu_blocks = 8
|
||||
@ -112,8 +108,7 @@ def test_scheduler_schedule_simple(use_v2_block_manager: bool):
|
||||
append_new_token(out, 1)
|
||||
|
||||
|
||||
@pytest.mark.parametrize('use_v2_block_manager', [True, False])
|
||||
def test_scheduler_prefill_prioritized(use_v2_block_manager: bool):
|
||||
def test_scheduler_prefill_prioritized():
|
||||
"""Verify running batched tokens are not applied to prefill requests."""
|
||||
block_size = 4
|
||||
max_model_len = 30
|
||||
@ -122,7 +117,7 @@ def test_scheduler_prefill_prioritized(use_v2_block_manager: bool):
|
||||
max_batched_num_tokens,
|
||||
2,
|
||||
max_model_len,
|
||||
use_v2_block_manager=use_v2_block_manager)
|
||||
)
|
||||
cache_config = CacheConfig(block_size, 1.0, 1, "auto")
|
||||
cache_config.num_cpu_blocks = 16
|
||||
cache_config.num_gpu_blocks = 16
|
||||
@ -146,12 +141,14 @@ def test_scheduler_prefill_prioritized(use_v2_block_manager: bool):
|
||||
assert get_sequence_groups(out) == [seq_group_b]
|
||||
|
||||
|
||||
@pytest.mark.parametrize('use_v2_block_manager', [True, False])
|
||||
def test_scheduler_schedule_preempt_abort(use_v2_block_manager: bool):
|
||||
def test_scheduler_schedule_preempt_abort():
|
||||
block_size = 4
|
||||
max_model_len = 16
|
||||
scheduler_config = SchedulerConfig(
|
||||
64, 2, max_model_len, use_v2_block_manager=use_v2_block_manager)
|
||||
64,
|
||||
2,
|
||||
max_model_len,
|
||||
)
|
||||
cache_config = CacheConfig(block_size, 1.0, 1, "auto")
|
||||
cache_config.num_cpu_blocks = 2
|
||||
cache_config.num_gpu_blocks = 2
|
||||
@ -201,8 +198,7 @@ def test_scheduler_schedule_preempt_abort(use_v2_block_manager: bool):
|
||||
assert scheduler.get_num_unfinished_seq_groups() == 1
|
||||
|
||||
|
||||
@pytest.mark.parametrize('use_v2_block_manager', [True, False])
|
||||
def test_scheduler_max_seqs(use_v2_block_manager: bool):
|
||||
def test_scheduler_max_seqs():
|
||||
block_size = 4
|
||||
num_seq_group = 4
|
||||
max_seq_group = 2
|
||||
@ -211,7 +207,7 @@ def test_scheduler_max_seqs(use_v2_block_manager: bool):
|
||||
64,
|
||||
max_seq_group,
|
||||
max_model_len,
|
||||
use_v2_block_manager=use_v2_block_manager)
|
||||
)
|
||||
cache_config = CacheConfig(block_size, 1.0, 1, "auto")
|
||||
cache_config.num_cpu_blocks = 8
|
||||
cache_config.num_gpu_blocks = 8
|
||||
@ -249,15 +245,14 @@ def test_scheduler_max_seqs(use_v2_block_manager: bool):
|
||||
assert set(get_sequence_groups(out)) == set([all_seq_groups[1]])
|
||||
|
||||
|
||||
@pytest.mark.parametrize('use_v2_block_manager', [True, False])
|
||||
def test_scheduler_delay_factor(use_v2_block_manager: bool):
|
||||
def test_scheduler_delay_factor():
|
||||
block_size = 4
|
||||
scheduler_config = SchedulerConfig(
|
||||
100,
|
||||
64,
|
||||
16,
|
||||
delay_factor=0.5,
|
||||
use_v2_block_manager=use_v2_block_manager)
|
||||
)
|
||||
cache_config = CacheConfig(block_size, 1.0, 1, "auto")
|
||||
cache_config.num_cpu_blocks = 8
|
||||
cache_config.num_gpu_blocks = 8
|
||||
@ -294,12 +289,10 @@ def test_scheduler_delay_factor(use_v2_block_manager: bool):
|
||||
append_new_token(out, 1)
|
||||
|
||||
|
||||
@pytest.mark.parametrize('use_v2_block_manager', [True, False])
|
||||
def test_swapped_out_prioritized(use_v2_block_manager: bool):
|
||||
def test_swapped_out_prioritized():
|
||||
block_size = 4
|
||||
scheduler = initialize_scheduler(max_num_seqs=6,
|
||||
block_size=block_size,
|
||||
use_v2_block_manager=use_v2_block_manager,
|
||||
num_cpu_blocks=64,
|
||||
num_gpu_blocks=64)
|
||||
# best_of=2 * 3 == 6 sequences.
|
||||
@ -351,7 +344,6 @@ def initialize_scheduler(
|
||||
max_token_budget=1000,
|
||||
max_model_len=1000,
|
||||
lora_config=None,
|
||||
use_v2_block_manager=False,
|
||||
block_size=4,
|
||||
num_cpu_blocks=8,
|
||||
num_gpu_blocks=8,
|
||||
@ -361,7 +353,7 @@ def initialize_scheduler(
|
||||
max_token_budget,
|
||||
max_num_seqs,
|
||||
max_model_len,
|
||||
use_v2_block_manager=use_v2_block_manager)
|
||||
)
|
||||
cache_config = CacheConfig(block_size, 1.0, 1, "auto")
|
||||
cache_config.num_cpu_blocks = num_cpu_blocks
|
||||
cache_config.num_gpu_blocks = num_gpu_blocks
|
||||
@ -386,15 +378,12 @@ def add_token_budget(budget: SchedulingBudget,
|
||||
budget.add_num_seqs(mock_seq_group.request_id, num_curr_seqs)
|
||||
|
||||
|
||||
@pytest.mark.parametrize('use_v2_block_manager', [True, False])
|
||||
def test_prefill_schedule_max_prompt_len(use_v2_block_manager: bool):
|
||||
def test_prefill_schedule_max_prompt_len():
|
||||
"""
|
||||
Test prompt longer than max_prompt_len is aborted.
|
||||
"""
|
||||
block_size = 4
|
||||
scheduler = initialize_scheduler(max_model_len=30,
|
||||
use_v2_block_manager=use_v2_block_manager,
|
||||
block_size=block_size)
|
||||
scheduler = initialize_scheduler(max_model_len=30, block_size=block_size)
|
||||
_, seq_group = create_dummy_prompt("0",
|
||||
prompt_length=60,
|
||||
block_size=block_size)
|
||||
@ -409,14 +398,12 @@ def test_prefill_schedule_max_prompt_len(use_v2_block_manager: bool):
|
||||
assert len(remaining_waiting) == 0
|
||||
|
||||
|
||||
@pytest.mark.parametrize('use_v2_block_manager', [True, False])
|
||||
def test_prefill_schedule_token_budget(use_v2_block_manager: bool):
|
||||
def test_prefill_schedule_token_budget():
|
||||
"""
|
||||
Test token budget respected.
|
||||
"""
|
||||
block_size = 4
|
||||
scheduler = initialize_scheduler(use_v2_block_manager=use_v2_block_manager,
|
||||
block_size=block_size,
|
||||
scheduler = initialize_scheduler(block_size=block_size,
|
||||
num_cpu_blocks=64,
|
||||
num_gpu_blocks=64)
|
||||
budget = create_token_budget(token_budget=0)
|
||||
@ -446,8 +433,7 @@ def test_prefill_schedule_token_budget(use_v2_block_manager: bool):
|
||||
assert len(remaining_waiting) == 1
|
||||
|
||||
# Test when current_batched_tokens respected.
|
||||
scheduler = initialize_scheduler(use_v2_block_manager=use_v2_block_manager,
|
||||
block_size=block_size,
|
||||
scheduler = initialize_scheduler(block_size=block_size,
|
||||
num_cpu_blocks=16,
|
||||
num_gpu_blocks=16)
|
||||
budget = create_token_budget(token_budget=60)
|
||||
@ -474,14 +460,12 @@ def test_prefill_schedule_token_budget(use_v2_block_manager: bool):
|
||||
assert len(remaining_waiting) == 0
|
||||
|
||||
|
||||
@pytest.mark.parametrize('use_v2_block_manager', [True, False])
|
||||
def test_prefill_schedule_max_seqs(use_v2_block_manager: bool):
|
||||
def test_prefill_schedule_max_seqs():
|
||||
"""
|
||||
Test max seq respected.
|
||||
"""
|
||||
block_size = 4
|
||||
scheduler = initialize_scheduler(use_v2_block_manager=use_v2_block_manager,
|
||||
block_size=block_size,
|
||||
scheduler = initialize_scheduler(block_size=block_size,
|
||||
num_cpu_blocks=64,
|
||||
num_gpu_blocks=64)
|
||||
budget = create_token_budget(max_num_seqs=2)
|
||||
@ -515,15 +499,13 @@ def test_prefill_schedule_max_seqs(use_v2_block_manager: bool):
|
||||
assert len(remaining_waiting) == 1
|
||||
|
||||
|
||||
@pytest.mark.parametrize('use_v2_block_manager', [True, False])
|
||||
def test_prefill_schedule_max_lora(use_v2_block_manager: bool):
|
||||
def test_prefill_schedule_max_lora():
|
||||
"""
|
||||
Test max lora is respected and prioritized.
|
||||
"""
|
||||
block_size = 4
|
||||
lora_config = LoRAConfig(max_lora_rank=8, max_loras=1)
|
||||
scheduler = initialize_scheduler(lora_config=lora_config,
|
||||
use_v2_block_manager=use_v2_block_manager,
|
||||
block_size=block_size,
|
||||
num_cpu_blocks=64,
|
||||
num_gpu_blocks=64)
|
||||
@ -570,14 +552,12 @@ def test_prefill_schedule_max_lora(use_v2_block_manager: bool):
|
||||
assert budget.num_batched_tokens == 60
|
||||
|
||||
|
||||
@pytest.mark.parametrize('use_v2_block_manager', [True, False])
|
||||
def test_prefill_schedule_no_block_manager_capacity(use_v2_block_manager):
|
||||
def test_prefill_schedule_no_block_manager_capacity():
|
||||
"""
|
||||
Test sequence cannot be scheduled due to block manager has no capacity.
|
||||
"""
|
||||
block_size = 4
|
||||
scheduler = initialize_scheduler(use_v2_block_manager=use_v2_block_manager,
|
||||
block_size=block_size,
|
||||
scheduler = initialize_scheduler(block_size=block_size,
|
||||
num_gpu_blocks=128,
|
||||
num_cpu_blocks=128)
|
||||
budget = create_token_budget()
|
||||
@ -614,14 +594,12 @@ def test_prefill_schedule_no_block_manager_capacity(use_v2_block_manager):
|
||||
assert len(remaining_waiting) == 0
|
||||
|
||||
|
||||
@pytest.mark.parametrize('use_v2_block_manager', [True, False])
|
||||
def test_decode_schedule_preempted(use_v2_block_manager: bool):
|
||||
def test_decode_schedule_preempted():
|
||||
"""
|
||||
Test decodes cannot be scheduled and preempted.
|
||||
"""
|
||||
block_size = 4
|
||||
scheduler = initialize_scheduler(use_v2_block_manager=use_v2_block_manager,
|
||||
block_size=block_size,
|
||||
scheduler = initialize_scheduler(block_size=block_size,
|
||||
num_cpu_blocks=64,
|
||||
num_gpu_blocks=64)
|
||||
curr_loras = None
|
||||
@ -660,14 +638,12 @@ def test_decode_schedule_preempted(use_v2_block_manager: bool):
|
||||
assert output.blocks_to_copy == []
|
||||
|
||||
|
||||
@pytest.mark.parametrize('use_v2_block_manager', [True, False])
|
||||
def test_decode_swap_beam_search(use_v2_block_manager: bool):
|
||||
def test_decode_swap_beam_search():
|
||||
"""
|
||||
Test best_of > 1 swap out blocks
|
||||
"""
|
||||
block_size = 4
|
||||
scheduler = initialize_scheduler(use_v2_block_manager=use_v2_block_manager,
|
||||
block_size=block_size,
|
||||
scheduler = initialize_scheduler(block_size=block_size,
|
||||
num_gpu_blocks=64,
|
||||
num_cpu_blocks=64)
|
||||
curr_loras = None
|
||||
@ -716,14 +692,12 @@ def test_decode_swap_beam_search(use_v2_block_manager: bool):
|
||||
assert output.blocks_to_copy == []
|
||||
|
||||
|
||||
@pytest.mark.parametrize('use_v2_block_manager', [True, False])
|
||||
def test_schedule_decode_blocks_to_copy_update(use_v2_block_manager: bool):
|
||||
def test_schedule_decode_blocks_to_copy_update():
|
||||
"""
|
||||
Verify blocks_to_copy is updated.
|
||||
"""
|
||||
block_size = 4
|
||||
scheduler = initialize_scheduler(use_v2_block_manager=use_v2_block_manager,
|
||||
block_size=4,
|
||||
scheduler = initialize_scheduler(block_size=4,
|
||||
num_cpu_blocks=16,
|
||||
num_gpu_blocks=16)
|
||||
_, seq_group = create_dummy_prompt("1",
|
||||
@ -754,11 +728,9 @@ def test_schedule_decode_blocks_to_copy_update(use_v2_block_manager: bool):
|
||||
assert output.blocks_to_copy == [(2, 3)]
|
||||
|
||||
|
||||
@pytest.mark.parametrize('use_v2_block_manager', [True, False])
|
||||
def test_schedule_swapped_simple(use_v2_block_manager: bool):
|
||||
def test_schedule_swapped_simple():
|
||||
block_size = 4
|
||||
scheduler = initialize_scheduler(use_v2_block_manager=use_v2_block_manager,
|
||||
block_size=block_size)
|
||||
scheduler = initialize_scheduler(block_size=block_size)
|
||||
curr_loras = None
|
||||
blocks_to_swap_out: List[Tuple[int, int]] = []
|
||||
_, seq_group = create_dummy_prompt("1",
|
||||
@ -785,11 +757,9 @@ def test_schedule_swapped_simple(use_v2_block_manager: bool):
|
||||
assert blocks_to_swap_out == blocks_to_swap_in_reverse
|
||||
|
||||
|
||||
@pytest.mark.parametrize('use_v2_block_manager', [True, False])
|
||||
def test_schedule_swapped_max_token_budget(use_v2_block_manager: bool):
|
||||
def test_schedule_swapped_max_token_budget():
|
||||
block_size = 4
|
||||
scheduler = initialize_scheduler(use_v2_block_manager=use_v2_block_manager,
|
||||
block_size=block_size,
|
||||
scheduler = initialize_scheduler(block_size=block_size,
|
||||
num_cpu_blocks=32,
|
||||
num_gpu_blocks=32)
|
||||
curr_loras = None
|
||||
@ -822,11 +792,9 @@ def test_schedule_swapped_max_token_budget(use_v2_block_manager: bool):
|
||||
assert len(output.prefill_seq_groups) == 0
|
||||
|
||||
|
||||
@pytest.mark.parametrize('use_v2_block_manager', [True, False])
|
||||
def test_schedule_swapped_max_seqs(use_v2_block_manager: bool):
|
||||
def test_schedule_swapped_max_seqs():
|
||||
block_size = 4
|
||||
scheduler = initialize_scheduler(use_v2_block_manager=use_v2_block_manager,
|
||||
block_size=block_size,
|
||||
scheduler = initialize_scheduler(block_size=block_size,
|
||||
num_cpu_blocks=64,
|
||||
num_gpu_blocks=64)
|
||||
curr_loras = None
|
||||
@ -859,12 +827,10 @@ def test_schedule_swapped_max_seqs(use_v2_block_manager: bool):
|
||||
assert len(output.prefill_seq_groups) == 0
|
||||
|
||||
|
||||
@pytest.mark.parametrize('use_v2_block_manager', [True, False])
|
||||
def test_schedule_swapped_max_loras(use_v2_block_manager: bool):
|
||||
def test_schedule_swapped_max_loras():
|
||||
block_size = 4
|
||||
lora_config = LoRAConfig(max_lora_rank=8, max_loras=1)
|
||||
scheduler = initialize_scheduler(lora_config=lora_config,
|
||||
use_v2_block_manager=use_v2_block_manager,
|
||||
block_size=block_size,
|
||||
num_cpu_blocks=32,
|
||||
num_gpu_blocks=32)
|
||||
@ -894,11 +860,9 @@ def test_schedule_swapped_max_loras(use_v2_block_manager: bool):
|
||||
assert len(curr_loras) == 1
|
||||
|
||||
|
||||
@pytest.mark.parametrize('use_v2_block_manager', [True, False])
|
||||
def test_schedule_swapped_cannot_swap_in(use_v2_block_manager: bool):
|
||||
def test_schedule_swapped_cannot_swap_in():
|
||||
block_size = 4
|
||||
scheduler = initialize_scheduler(use_v2_block_manager=use_v2_block_manager,
|
||||
block_size=block_size,
|
||||
scheduler = initialize_scheduler(block_size=block_size,
|
||||
num_cpu_blocks=32,
|
||||
num_gpu_blocks=32)
|
||||
curr_loras = None
|
||||
@ -927,11 +891,9 @@ def test_schedule_swapped_cannot_swap_in(use_v2_block_manager: bool):
|
||||
assert len(output.prefill_seq_groups) == 0
|
||||
|
||||
|
||||
@pytest.mark.parametrize('use_v2_block_manager', [True, False])
|
||||
def test_infeasible_swap(use_v2_block_manager: bool):
|
||||
def test_infeasible_swap():
|
||||
block_size = 4
|
||||
scheduler = initialize_scheduler(use_v2_block_manager=use_v2_block_manager,
|
||||
block_size=block_size,
|
||||
scheduler = initialize_scheduler(block_size=block_size,
|
||||
num_cpu_blocks=32,
|
||||
num_gpu_blocks=32)
|
||||
curr_loras = None
|
||||
@ -961,11 +923,9 @@ def test_infeasible_swap(use_v2_block_manager: bool):
|
||||
assert len(output.prefill_seq_groups) == 0
|
||||
|
||||
|
||||
@pytest.mark.parametrize('use_v2_block_manager', [True, False])
|
||||
def test_schedule_swapped_blocks_to_copy(use_v2_block_manager: bool):
|
||||
def test_schedule_swapped_blocks_to_copy():
|
||||
block_size = 4
|
||||
scheduler = initialize_scheduler(use_v2_block_manager=use_v2_block_manager,
|
||||
block_size=block_size,
|
||||
scheduler = initialize_scheduler(block_size=block_size,
|
||||
num_cpu_blocks=32,
|
||||
num_gpu_blocks=32)
|
||||
curr_loras = None
|
||||
|
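The test_scheduler.py hunks above also drop the use_v2_block_manager plumbing from the shared initialize_scheduler helper. A condensed sketch of the resulting helper follows; the max_num_seqs default and the final Scheduler(...) return are reconstructed from the visible call sites, so they are assumptions rather than lines shown in this diff, and the imports are the ones already present in that test file.

def initialize_scheduler(
    max_num_seqs=100,            # assumed default; not visible in the hunk
    max_token_budget=1000,
    max_model_len=1000,
    lora_config=None,
    block_size=4,
    num_cpu_blocks=8,
    num_gpu_blocks=8,
):
    scheduler_config = SchedulerConfig(
        max_token_budget,
        max_num_seqs,
        max_model_len,
    )
    cache_config = CacheConfig(block_size, 1.0, 1, "auto")
    cache_config.num_cpu_blocks = num_cpu_blocks
    cache_config.num_gpu_blocks = num_gpu_blocks
    # Return value assumed from how the tests use the helper.
    return Scheduler(scheduler_config, cache_config, lora_config)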
@ -185,13 +185,14 @@ def test_metric_spec_decode(
) -> None:
k = 5

with vllm_runner(model,
with vllm_runner(
model,
dtype=dtype,
disable_log_stats=False,
gpu_memory_utilization=0.4,
speculative_model=model,
num_speculative_tokens=k,
use_v2_block_manager=True) as vllm_model:
) as vllm_model:

# Force log interval to be 0 to catch all metrics.
stat_logger = vllm_model.model.llm_engine.stat_loggers['prometheus']
@ -242,7 +243,6 @@ def test_metric_spec_decode_interval(
gpu_memory_utilization=0.4,
speculative_model=model,
num_speculative_tokens=k,
use_v2_block_manager=True,
enforce_eager=True)

engine = LLMEngine.from_engine_args(engine_args)
@ -17,7 +17,6 @@ NUM_PROMPTS = [10]

DEFAULT_SERVER_ARGS: List[str] = [
"--disable-log-requests",
"--use-v2-block-manager",
"--worker-use-ray",
"--gpu-memory-utilization",
"0.85",
@ -76,7 +76,6 @@ def test_multi_step_llm(
|
||||
enforce_eager=enforce_eager,
|
||||
gpu_memory_utilization=0.7,
|
||||
tensor_parallel_size=tp_size,
|
||||
use_v2_block_manager=True,
|
||||
enable_chunked_prefill=enable_chunked_prefill,
|
||||
num_scheduler_steps=num_scheduler_steps,
|
||||
) as vllm_model:
|
||||
@ -169,7 +168,6 @@ def test_multi_step_llm_w_prompt_logprobs(
|
||||
enforce_eager=enforce_eager,
|
||||
gpu_memory_utilization=0.7,
|
||||
tensor_parallel_size=tp_size,
|
||||
use_v2_block_manager=True,
|
||||
num_scheduler_steps=num_scheduler_steps,
|
||||
) as vllm_model:
|
||||
vllm_outputs = vllm_model.generate_greedy_logprobs(
|
||||
@ -305,7 +303,6 @@ def test_multi_step_llm_chunked_prefill_prefix_cache(
|
||||
enforce_eager=enforce_eager,
|
||||
gpu_memory_utilization=0.7,
|
||||
tensor_parallel_size=tp_size,
|
||||
use_v2_block_manager=True,
|
||||
num_scheduler_steps=num_scheduler_steps,
|
||||
max_model_len=48,
|
||||
max_num_batched_tokens=48,
|
||||
@ -324,7 +321,6 @@ def test_multi_step_llm_chunked_prefill_prefix_cache(
|
||||
enforce_eager=enforce_eager,
|
||||
gpu_memory_utilization=0.7,
|
||||
tensor_parallel_size=tp_size,
|
||||
use_v2_block_manager=True,
|
||||
enable_chunked_prefill=True,
|
||||
enable_prefix_caching=True,
|
||||
num_scheduler_steps=num_scheduler_steps,
|
||||
|
@ -2,15 +2,9 @@
|
||||
|
||||
Run `pytest tests/prefix_caching/test_prefix_caching.py`.
|
||||
"""
|
||||
from typing import List
|
||||
|
||||
import pytest
|
||||
|
||||
from tests.kernels.utils import override_backend_env_variable
|
||||
from tests.utils import check_deprecated_block_manager_usage
|
||||
from vllm.block import PhysicalTokenBlock
|
||||
from vllm.core.block_manager_v1 import CachedBlockAllocator
|
||||
from vllm.utils import Device
|
||||
|
||||
from ..models.utils import check_outputs_equal
|
||||
|
||||
@ -19,92 +13,11 @@ MODELS = [
|
||||
]
|
||||
|
||||
|
||||
@pytest.fixture(scope="module", autouse=True)
|
||||
def check_deprecated_block_manager():
|
||||
check_deprecated_block_manager_usage(
|
||||
'tests/prefix_caching/test_prefix_caching.py')
|
||||
|
||||
|
||||
@pytest.mark.parametrize("block_size", [16])
|
||||
@pytest.mark.parametrize("num_blocks", [16])
|
||||
def test_block_allocator(
|
||||
block_size: int,
|
||||
num_blocks: int,
|
||||
):
|
||||
block_hash = 1
|
||||
block_allocator = CachedBlockAllocator(Device.CPU, block_size, num_blocks)
|
||||
|
||||
# Allocate two PysicalTokenBlocks with the same hash and check
|
||||
# that they are the same PhysicalTokenBlock
|
||||
first_block = block_allocator.allocate(block_hash, 0)
|
||||
second_block = block_allocator.allocate(block_hash, 0)
|
||||
assert (first_block == second_block)
|
||||
assert (second_block.ref_count == 2)
|
||||
|
||||
# Check metric: 1 hit of 2 queries
|
||||
assert block_allocator.get_prefix_cache_hit_rate() == 0.5
|
||||
|
||||
# Free the first_block and confirm that the ref_count is correctly
|
||||
# decremented on the second block
|
||||
block_allocator.free(first_block)
|
||||
assert (second_block.ref_count == 1)
|
||||
|
||||
# Free the second block
|
||||
block_allocator.free(second_block)
|
||||
|
||||
# Reallocate the first block and confirm that, even after the block
|
||||
# had its ref_count go to 0, we still get the same block back
|
||||
first_block = block_allocator.allocate(block_hash, 0)
|
||||
assert (first_block == second_block)
|
||||
assert (first_block.block_hash == block_hash)
|
||||
|
||||
# Allocate one more time to get 3/4 hit rate for easy checking
|
||||
block_allocator.allocate(block_hash, 0)
|
||||
assert block_allocator.get_prefix_cache_hit_rate() == 0.75
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_blocks", [16])
|
||||
def test_eviction(num_blocks: int, ):
|
||||
block_size = 16
|
||||
block_allocator = CachedBlockAllocator(Device.CPU, block_size, num_blocks)
|
||||
blocks: List[PhysicalTokenBlock] = []
|
||||
|
||||
for i in range(num_blocks):
|
||||
# use i as the block_hash
|
||||
blocks.append(block_allocator.allocate(i, 0))
|
||||
|
||||
#Free all blocks
|
||||
for block in blocks:
|
||||
block_allocator.free(block)
|
||||
|
||||
# Allocate a new block and confirm that it's the first block freed.
|
||||
# I.E The Least Recently Used block
|
||||
new_block_hash = block_size
|
||||
new_block = block_allocator.allocate(new_block_hash, 0)
|
||||
assert (new_block == blocks[0])
|
||||
assert (new_block.block_hash == new_block_hash)
|
||||
|
||||
# Reallocate the second in blocks to remove it from the free list
|
||||
realloc_block_hash = 1
|
||||
realloc_block = block_allocator.allocate(realloc_block_hash, 0)
|
||||
assert (realloc_block == blocks[realloc_block_hash])
|
||||
assert (realloc_block.block_hash == realloc_block_hash)
|
||||
|
||||
# Allocate a new block and confirm that it's not the realloc_block,
|
||||
# since the realloc_block shouldn't be in the free list
|
||||
new_block_hash = block_size + 1
|
||||
new_block = block_allocator.allocate(new_block_hash, 0)
|
||||
assert (realloc_block != new_block)
|
||||
assert (new_block.block_hash == new_block_hash)
|
||||
assert (new_block.block_number == 2)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model", MODELS)
|
||||
@pytest.mark.parametrize("backend", ["FLASH_ATTN", "FLASHINFER", "XFORMERS"])
|
||||
@pytest.mark.parametrize("dtype", ["half"])
|
||||
@pytest.mark.parametrize("max_tokens", [5])
|
||||
@pytest.mark.parametrize("cached_position", [0, 1])
|
||||
@pytest.mark.parametrize("use_v2_block_manager", [False, True])
|
||||
def test_mixed_requests(
|
||||
hf_runner,
|
||||
vllm_runner,
|
||||
@ -114,7 +27,6 @@ def test_mixed_requests(
|
||||
dtype: str,
|
||||
max_tokens: int,
|
||||
cached_position: int,
|
||||
use_v2_block_manager: bool,
|
||||
monkeypatch,
|
||||
) -> None:
|
||||
"""
|
||||
@ -132,7 +44,6 @@ def test_mixed_requests(
|
||||
model,
|
||||
dtype=dtype,
|
||||
enable_prefix_caching=True,
|
||||
use_v2_block_manager=use_v2_block_manager,
|
||||
) as vllm_model:
|
||||
# Run the first prompt so the cache is populated
|
||||
vllm_outputs = vllm_model.generate_greedy([cached_prompt], max_tokens)
|
||||
|
@ -1,26 +1,14 @@
import pytest

from tests.utils import check_deprecated_block_manager_usage
from vllm import SamplingParams

from .conftest import get_output_from_llm_generator


@pytest.fixture(scope="module", autouse=True)
def check_deprecated_block_manager():
check_deprecated_block_manager_usage(
'tests/spec_decode/e2e/test_compatibility.py')


@pytest.mark.parametrize(
"common_llm_kwargs",
[{
@pytest.mark.parametrize("common_llm_kwargs", [{
"model": "JackFram/llama-68m",
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 5,

# Required for spec decode.
"use_v2_block_manager": True
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [
{
@ -51,15 +39,10 @@ def test_spec_decode_xfail_chunked_prefill(test_llm_generator):
sampling_params)


@pytest.mark.parametrize(
"common_llm_kwargs",
[{
@pytest.mark.parametrize("common_llm_kwargs", [{
"model": "meta-llama/Llama-2-7b-chat-hf",
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 5,

# Required for spec decode.
"use_v2_block_manager": True
}])
@pytest.mark.parametrize(
"per_test_common_llm_kwargs",
@ -101,34 +84,3 @@ def test_spec_decode_xfail_spec_max_model_len(test_llm_generator):
with pytest.raises(ValueError, match="cannot be larger than"):
get_output_from_llm_generator(test_llm_generator, prompts,
sampling_params)


@pytest.mark.parametrize("common_llm_kwargs", [{
"model": "JackFram/llama-68m",
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
"use_v2_block_manager": False,
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [{}])
@pytest.mark.parametrize("seed", [1])
def test_spec_decode_xfail_block_manager_v1(test_llm_generator):
"""Verify that speculative decoding with block manager v1 fails.
"""
output_len = 128
temperature = 0.0

prompts = [
"Hello, my name is",
]

sampling_params = SamplingParams(
max_tokens=output_len,
ignore_eos=True,
temperature=temperature,
)

with pytest.raises(ValueError,
match="Speculative decoding requires usage of the V2"):
get_output_from_llm_generator(test_llm_generator, prompts,
sampling_params)
@ -43,9 +43,6 @@ PRECISION = "float32"
|
||||
# Skip cuda graph recording for fast test.
|
||||
"enforce_eager": True,
|
||||
|
||||
# Required for spec decode.
|
||||
"use_v2_block_manager": True,
|
||||
|
||||
# Print spec metrics.
|
||||
"disable_log_stats": False,
|
||||
|
||||
@ -86,9 +83,6 @@ def test_eagle_e2e_greedy_correctness(vllm_runner, common_llm_kwargs,
|
||||
# Skip cuda graph recording for fast test.
|
||||
"enforce_eager": True,
|
||||
|
||||
# Required for spec decode.
|
||||
"use_v2_block_manager": True,
|
||||
|
||||
# Print spec metrics.
|
||||
"disable_log_stats": False,
|
||||
|
||||
@ -143,9 +137,6 @@ def test_eagle_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs,
|
||||
[{
|
||||
"enforce_eager": False,
|
||||
|
||||
# Required for spec decode.
|
||||
"use_v2_block_manager": True,
|
||||
|
||||
# Print spec metrics.
|
||||
"disable_log_stats": False,
|
||||
|
||||
@ -191,9 +182,6 @@ def test_eagle_e2e_greedy_correctness_cuda_graph(
|
||||
# Skip cuda graph recording for fast test.
|
||||
"enforce_eager": True,
|
||||
|
||||
# Required for spec decode.
|
||||
"use_v2_block_manager": True,
|
||||
|
||||
# Precision
|
||||
"dtype": PRECISION,
|
||||
|
||||
@ -235,9 +223,6 @@ def test_eagle_e2e_greedy_correctness_with_preemption(
|
||||
# Skip cuda graph recording for fast test.
|
||||
"enforce_eager": True,
|
||||
|
||||
# Required for spec decode.
|
||||
"use_v2_block_manager": True,
|
||||
|
||||
# Precision
|
||||
"dtype": PRECISION,
|
||||
|
||||
@ -283,9 +268,6 @@ def test_eagle_different_k(vllm_runner, common_llm_kwargs,
|
||||
# Skip cuda graph recording for fast test.
|
||||
"enforce_eager": True,
|
||||
|
||||
# Required for spec decode.
|
||||
"use_v2_block_manager": True,
|
||||
|
||||
# Precision
|
||||
"dtype": PRECISION,
|
||||
|
||||
|
@ -12,8 +12,6 @@ MAIN_MODEL = "JackFram/llama-68m"
|
||||
@pytest.mark.parametrize(
|
||||
"common_llm_kwargs",
|
||||
[{
|
||||
# Required for spec decode.
|
||||
"use_v2_block_manager": True,
|
||||
|
||||
# Verify equality when cuda graphs allowed.
|
||||
"enforce_eager": False,
|
||||
@ -57,9 +55,6 @@ def test_spec_decode_cuda_graph(vllm_runner, common_llm_kwargs,
|
||||
|
||||
# Skip cuda graph recording for fast test.
|
||||
"enforce_eager": True,
|
||||
|
||||
# Required for spec decode.
|
||||
"use_v2_block_manager": True,
|
||||
}])
|
||||
@pytest.mark.parametrize("per_test_common_llm_kwargs", [
|
||||
{
|
||||
@ -111,9 +106,6 @@ def test_speculative_model_quantization_config(vllm_runner, common_llm_kwargs,
|
||||
|
||||
# Skip cuda graph recording for fast test.
|
||||
"enforce_eager": True,
|
||||
|
||||
# Required for spec decode.
|
||||
"use_v2_block_manager": True,
|
||||
"speculative_model": "JackFram/llama-68m",
|
||||
"num_speculative_tokens": 3,
|
||||
}])
|
||||
|
@ -17,9 +17,6 @@ from .conftest import run_equality_correctness_test_tp
|
||||
[[
|
||||
# Skip cuda graph recording for fast test.
|
||||
"--enforce-eager",
|
||||
|
||||
# Required for spec decode.
|
||||
"--use-v2-block-manager",
|
||||
"--tensor-parallel-size",
|
||||
"2"
|
||||
]])
|
||||
@ -74,9 +71,6 @@ def test_target_model_tp_gt_1(common_llm_kwargs, per_test_common_llm_kwargs,
|
||||
[[
|
||||
# Skip cuda graph recording for fast test.
|
||||
"--enforce-eager",
|
||||
|
||||
# Required for spec decode.
|
||||
"--use_v2_block_manager",
|
||||
"--tensor_parallel_size",
|
||||
"2",
|
||||
|
||||
|
@ -19,9 +19,6 @@ SPEC_MODEL = "JackFram/llama-68m"
|
||||
[[
|
||||
# Skip cuda graph recording for fast test.
|
||||
"--enforce_eager",
|
||||
|
||||
# Required for spec decode.
|
||||
"--use-v2-block-manager",
|
||||
"--tensor-parallel-size",
|
||||
"4",
|
||||
]])
|
||||
@ -71,9 +68,6 @@ def test_draft_model_tp_lt_target_model_tp4(common_llm_kwargs,
|
||||
|
||||
# Skip cuda graph recording for fast test.
|
||||
"--enforce-eager",
|
||||
|
||||
# Required for spec decode.
|
||||
"--use-v2-block-manager",
|
||||
"--tensor-parallel-size",
|
||||
"4",
|
||||
]])
|
||||
|
@ -14,9 +14,6 @@ from .conftest import run_equality_correctness_test
|
||||
|
||||
# Skip cuda graph recording for fast test.
|
||||
"enforce_eager": True,
|
||||
|
||||
# Required for spec decode.
|
||||
"use_v2_block_manager": True,
|
||||
}])
|
||||
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||
@ -67,9 +64,6 @@ def test_logprobs_equality(vllm_runner, common_llm_kwargs,
|
||||
|
||||
# Skip cuda graph recording for fast test.
|
||||
"enforce_eager": True,
|
||||
|
||||
# Required for spec decode.
|
||||
"use_v2_block_manager": True
|
||||
}])
|
||||
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||
@ -119,9 +113,6 @@ def test_logprobs_different_k(vllm_runner, common_llm_kwargs,
|
||||
|
||||
# Skip cuda graph recording for fast test.
|
||||
"enforce_eager": True,
|
||||
|
||||
# Required for spec decode.
|
||||
"use_v2_block_manager": True
|
||||
}])
|
||||
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||
@ -173,9 +164,6 @@ def test_logprobs_when_skip_speculation(vllm_runner, common_llm_kwargs,
|
||||
|
||||
# Skip cuda graph recording for fast test.
|
||||
"enforce_eager": True,
|
||||
|
||||
# Required for spec decode.
|
||||
"use_v2_block_manager": True
|
||||
}])
|
||||
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||
@ -251,8 +239,6 @@ def test_logprobs_temp_1(vllm_runner, common_llm_kwargs,
|
||||
"model_name": "JackFram/llama-160m",
|
||||
# Skip cuda graph recording for fast test.
|
||||
"enforce_eager": True,
|
||||
# Required for spec decode.
|
||||
"use_v2_block_manager": True,
|
||||
}])
|
||||
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||
|
@ -45,9 +45,6 @@ PRECISION = "float32"
|
||||
# Skip cuda graph recording for fast test.
|
||||
"enforce_eager": True,
|
||||
|
||||
# Required for spec decode.
|
||||
"use_v2_block_manager": True,
|
||||
|
||||
# Print spec metrics.
|
||||
"disable_log_stats": False,
|
||||
|
||||
@ -93,9 +90,6 @@ def test_medusa_e2e_greedy_correctness(vllm_runner, common_llm_kwargs,
|
||||
# Skip cuda graph recording for fast test.
|
||||
"enforce_eager": True,
|
||||
|
||||
# Required for spec decode.
|
||||
"use_v2_block_manager": True,
|
||||
|
||||
# Print spec metrics.
|
||||
"disable_log_stats": False,
|
||||
|
||||
@ -151,9 +145,6 @@ def test_medusa_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs,
|
||||
[{
|
||||
"enforce_eager": False,
|
||||
|
||||
# Required for spec decode.
|
||||
"use_v2_block_manager": True,
|
||||
|
||||
# Print spec metrics.
|
||||
"disable_log_stats": False,
|
||||
|
||||
@ -204,9 +195,6 @@ def test_medusa_e2e_greedy_correctness_cuda_graph(
|
||||
# Skip cuda graph recording for fast test.
|
||||
"enforce_eager": True,
|
||||
|
||||
# Required for spec decode.
|
||||
"use_v2_block_manager": True,
|
||||
|
||||
# Precision
|
||||
"dtype": PRECISION,
|
||||
|
||||
@ -253,9 +241,6 @@ def test_medusa_e2e_greedy_correctness_with_preemption(
|
||||
# Skip cuda graph recording for fast test.
|
||||
"enforce_eager": True,
|
||||
|
||||
# Required for spec decode.
|
||||
"use_v2_block_manager": True,
|
||||
|
||||
# Precision
|
||||
"dtype": PRECISION,
|
||||
|
||||
@ -306,9 +291,6 @@ def test_medusa_different_k(vllm_runner, common_llm_kwargs,
|
||||
# Skip cuda graph recording for fast test.
|
||||
"enforce_eager": True,
|
||||
|
||||
# Required for spec decode.
|
||||
"use_v2_block_manager": True,
|
||||
|
||||
# Precision
|
||||
"dtype": PRECISION,
|
||||
|
||||
@ -356,9 +338,6 @@ def test_medusa_disable_queue(vllm_runner, common_llm_kwargs,
|
||||
# Skip cuda graph recording for fast test.
|
||||
"enforce_eager": True,
|
||||
|
||||
# Required for spec decode.
|
||||
"use_v2_block_manager": True,
|
||||
|
||||
# Precision
|
||||
"dtype": PRECISION,
|
||||
|
||||
|
@ -47,9 +47,6 @@ PRECISION = "float32"
|
||||
# Skip cuda graph recording for fast test.
|
||||
"enforce_eager": True,
|
||||
|
||||
# Required for spec decode.
|
||||
"use_v2_block_manager": True,
|
||||
|
||||
# Print spec metrics.
|
||||
"disable_log_stats": False,
|
||||
|
||||
@ -94,9 +91,6 @@ def test_mlp_e2e_greedy_correctness(vllm_runner, common_llm_kwargs,
|
||||
# Skip cuda graph recording for fast test.
|
||||
"enforce_eager": True,
|
||||
|
||||
# Required for spec decode.
|
||||
"use_v2_block_manager": True,
|
||||
|
||||
# Print spec metrics.
|
||||
"disable_log_stats": False,
|
||||
|
||||
@ -149,9 +143,6 @@ def test_mlp_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs,
|
||||
# Skip cuda graph recording for fast test.
|
||||
"enforce_eager": True,
|
||||
|
||||
# Required for spec decode.
|
||||
"use_v2_block_manager": True,
|
||||
|
||||
# Print spec metrics.
|
||||
"disable_log_stats": False,
|
||||
|
||||
@ -195,9 +186,6 @@ def test_mlp_e2e_acceptance_rate(vllm_runner, common_llm_kwargs,
|
||||
# Skip cuda graph recording for fast test.
|
||||
"enforce_eager": True,
|
||||
|
||||
# Required for spec decode.
|
||||
"use_v2_block_manager": True,
|
||||
|
||||
# Print spec metrics.
|
||||
"disable_log_stats": False,
|
||||
|
||||
@ -258,9 +246,6 @@ def test_mlp_e2e_seeded_correctness(vllm_runner, common_llm_kwargs,
|
||||
# Skip cuda graph recording for fast test.
|
||||
"enforce_eager": True,
|
||||
|
||||
# Required for spec decode.
|
||||
"use_v2_block_manager": True,
|
||||
|
||||
# Precision
|
||||
"dtype": PRECISION,
|
||||
|
||||
@ -311,9 +296,6 @@ def test_mlp_e2e_greedy_correctness_with_preemption(
|
||||
# Skip cuda graph recording for fast test.
|
||||
"enforce_eager": True,
|
||||
|
||||
# Required for spec decode.
|
||||
"use_v2_block_manager": True,
|
||||
|
||||
# Precision
|
||||
"dtype": PRECISION,
|
||||
|
||||
@ -366,9 +348,6 @@ def test_mlp_e2e_greedy_correctness_with_padding(
|
||||
# Skip cuda graph recording for fast test.
|
||||
"enforce_eager": True,
|
||||
|
||||
# Required for spec decode.
|
||||
"use_v2_block_manager": True,
|
||||
|
||||
# Precision
|
||||
"dtype": PRECISION,
|
||||
|
||||
@ -419,9 +398,6 @@ def test_mlp_different_k(vllm_runner, common_llm_kwargs,
|
||||
# Skip cuda graph recording for fast test.
|
||||
"enforce_eager": True,
|
||||
|
||||
# Required for spec decode.
|
||||
"use_v2_block_manager": True,
|
||||
|
||||
# Precision
|
||||
"dtype": PRECISION,
|
||||
|
||||
@ -469,9 +445,6 @@ def test_mlp_disable_queue(vllm_runner, common_llm_kwargs,
|
||||
|
||||
# Skip cuda graph recording for fast test.
|
||||
"enforce_eager": True,
|
||||
|
||||
# Required for spec decode.
|
||||
"use_v2_block_manager": True,
|
||||
"speculative_model": SPEC_MODEL,
|
||||
}])
|
||||
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
|
||||
|
@ -55,9 +55,6 @@ from .conftest import (get_output_from_llm_generator,
|
||||
|
||||
# Skip cuda graph recording for fast test.
|
||||
"enforce_eager": True,
|
||||
|
||||
# Required for spec decode.
|
||||
"use_v2_block_manager": True,
|
||||
}])
|
||||
@pytest.mark.parametrize(
|
||||
"per_test_common_llm_kwargs",
|
||||
@ -124,9 +121,6 @@ def test_spec_decode_e2e_with_detokenization(test_llm_generator,
|
||||
# Skip cuda graph recording for fast test.
|
||||
"enforce_eager": True,
|
||||
|
||||
# Required for spec decode.
|
||||
"use_v2_block_manager": True,
|
||||
|
||||
# Print spec metrics.
|
||||
"disable_log_stats": False,
|
||||
}])
|
||||
@ -190,9 +184,6 @@ def test_spec_decode_e2e_greedy_correctness_tiny_model_bs1(
|
||||
# Skip cuda graph recording for fast test.
|
||||
"enforce_eager": True,
|
||||
|
||||
# Required for spec decode.
|
||||
"use_v2_block_manager": True,
|
||||
|
||||
# Print spec metrics.
|
||||
"disable_log_stats": False,
|
||||
}])
|
||||
@ -246,9 +237,6 @@ def test_spec_decode_e2e_greedy_correctness_tiny_model_large_bs(
|
||||
[{
|
||||
# Skip cuda graph recording for fast test.
|
||||
"enforce_eager": True,
|
||||
|
||||
# Required for spec decode.
|
||||
"use_v2_block_manager": True
|
||||
}])
|
||||
@pytest.mark.parametrize(
|
||||
"per_test_common_llm_kwargs",
|
||||
@ -303,9 +291,6 @@ def test_spec_decode_e2e_greedy_correctness_tiny_model_large_bs_diff_output_len(
|
||||
# Skip cuda graph recording for fast test.
|
||||
"enforce_eager": True,
|
||||
|
||||
# Required for spec decode.
|
||||
"use_v2_block_manager": True,
|
||||
|
||||
# Print spec metrics.
|
||||
"disable_log_stats": False,
|
||||
}])
|
||||
@ -353,9 +338,6 @@ def test_spec_decode_e2e_greedy_correctness_real_model_bs1(
|
||||
# Skip cuda graph recording for fast test.
|
||||
"enforce_eager": True,
|
||||
|
||||
# Required for spec decode.
|
||||
"use_v2_block_manager": True,
|
||||
|
||||
# Print spec metrics.
|
||||
"disable_log_stats": False,
|
||||
}])
|
||||
@ -404,9 +386,6 @@ def test_spec_decode_e2e_greedy_correctness_real_model_large_bs(
|
||||
|
||||
# Skip cuda graph recording for fast test.
|
||||
"enforce_eager": True,
|
||||
|
||||
# Required for spec decode.
|
||||
"use_v2_block_manager": True
|
||||
}])
|
||||
@pytest.mark.parametrize("per_test_common_llm_kwargs", [
|
||||
{
|
||||
@ -454,9 +433,6 @@ def test_spec_decode_e2e_greedy_correctness_with_preemption(
|
||||
|
||||
# Skip cuda graph recording for fast test.
|
||||
"enforce_eager": True,
|
||||
|
||||
# Required for spec decode.
|
||||
"use_v2_block_manager": True
|
||||
}])
|
||||
@pytest.mark.parametrize(
|
||||
"per_test_common_llm_kwargs",
|
||||
@ -514,9 +490,6 @@ def test_spec_decode_different_block_size(vllm_runner, common_llm_kwargs,
|
||||
|
||||
# Skip cuda graph recording for fast test.
|
||||
"enforce_eager": True,
|
||||
|
||||
# Required for spec decode.
|
||||
"use_v2_block_manager": True
|
||||
}])
|
||||
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||
@ -570,9 +543,6 @@ def test_skip_speculation(vllm_runner, common_llm_kwargs,
|
||||
|
||||
# Skip cuda graph recording for fast test.
|
||||
"enforce_eager": True,
|
||||
|
||||
# Required for spec decode.
|
||||
"use_v2_block_manager": True
|
||||
}])
|
||||
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||
@ -611,9 +581,6 @@ def test_disable_speculation(vllm_runner, common_llm_kwargs,
|
||||
|
||||
# Skip cuda graph recording for fast test.
|
||||
"enforce_eager": True,
|
||||
|
||||
# Required for spec decode.
|
||||
"use_v2_block_manager": True
|
||||
}])
|
||||
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||
@ -660,9 +627,6 @@ def test_many_k(vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs,
|
||||
|
||||
# Skip cuda graph recording for fast test.
|
||||
"enforce_eager": True,
|
||||
|
||||
# Required for spec decode.
|
||||
"use_v2_block_manager": True
|
||||
}])
|
||||
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||
|
@ -35,9 +35,6 @@ from .conftest import run_equality_correctness_test
|
||||
# Skip cuda graph recording for fast test.
|
||||
"enforce_eager": True,
|
||||
|
||||
# Required for spec decode.
|
||||
"use_v2_block_manager": True,
|
||||
|
||||
# Print spec metrics.
|
||||
"disable_log_stats": False,
|
||||
}])
|
||||
@ -82,9 +79,6 @@ def test_ngram_e2e_greedy_correctness(vllm_runner, common_llm_kwargs,
|
||||
# Skip cuda graph recording for fast test.
|
||||
"enforce_eager": True,
|
||||
|
||||
# Required for spec decode.
|
||||
"use_v2_block_manager": True,
|
||||
|
||||
# Print spec metrics.
|
||||
"disable_log_stats": False,
|
||||
}])
|
||||
@ -145,9 +139,6 @@ def test_ngram_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs,
|
||||
|
||||
# Skip cuda graph recording for fast test.
|
||||
"enforce_eager": True,
|
||||
|
||||
# Required for spec decode.
|
||||
"use_v2_block_manager": True
|
||||
}])
|
||||
@pytest.mark.parametrize("per_test_common_llm_kwargs", [
|
||||
{
|
||||
@ -195,9 +186,6 @@ def test_ngram_e2e_greedy_correctness_with_preemption(
|
||||
|
||||
# Skip cuda graph recording for fast test.
|
||||
"enforce_eager": True,
|
||||
|
||||
# Required for spec decode.
|
||||
"use_v2_block_manager": True
|
||||
}])
|
||||
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||
@ -254,9 +242,6 @@ def test_ngram_different_k(vllm_runner, common_llm_kwargs,
|
||||
|
||||
# Skip cuda graph recording for fast test.
|
||||
"enforce_eager": True,
|
||||
|
||||
# Required for spec decode.
|
||||
"use_v2_block_manager": True
|
||||
}])
|
||||
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||
@ -303,7 +288,6 @@ def test_ngram_disable_queue(vllm_runner, common_llm_kwargs,
|
||||
"enforce_eager": True,
|
||||
|
||||
# Required for spec decode.
|
||||
"use_v2_block_manager": True,
|
||||
"speculative_model": "[ngram]",
|
||||
"num_speculative_tokens": 5,
|
||||
"ngram_prompt_lookup_max": 3,
|
||||
|
@ -17,9 +17,6 @@ SPEC_MODEL = "JackFram/llama-160m"
|
||||
# Skip cuda graph recording for fast test.
|
||||
"enforce_eager": True,
|
||||
|
||||
# Required for spec decode.
|
||||
"use_v2_block_manager": True,
|
||||
|
||||
# speculative model
|
||||
"speculative_model": "JackFram/llama-160m",
|
||||
|
||||
|
@ -678,12 +678,3 @@ def get_client_text_logprob_generations(
return [(text_generations, text,
(None if x.logprobs is None else x.logprobs.top_logprobs))
for completion in completions for x in completion.choices]


def check_deprecated_block_manager_usage(test_name: str):
assert envs.VLLM_ALLOW_DEPRECATED_BLOCK_MANAGER_V1 is True, (
f"To allow the use of deprecated BlockSpaceManagerV1, set the "
f"environment variable VLLM_ALLOW_DEPRECATED_BLOCK_MANAGER_V1=1. "
f"You can run the tests with: "
f"`VLLM_ALLOW_DEPRECATED_BLOCK_MANAGER_V1=1 pytest {test_name}`" #noqa
)
@ -305,8 +305,6 @@ class FlashAttentionMetadataBuilder(
self.runner = input_builder.runner
self.sliding_window = input_builder.sliding_window
self.block_size = input_builder.block_size
self.use_v2_block_manager = (
input_builder.scheduler_config.use_v2_block_manager)

def _add_seq_group(
self, inter_data: "ModelInputForGPUBuilder.InterDataForSeqGroup",
@ -355,9 +353,9 @@ class FlashAttentionMetadataBuilder(

# Compute slot mapping.
is_profile_run = is_block_tables_empty(block_tables)
start_idx = compute_slot_mapping_start_idx(
is_prompt, query_len, context_len, self.sliding_window,
self.use_v2_block_manager)
start_idx = compute_slot_mapping_start_idx(is_prompt, query_len,
context_len,
self.sliding_window)
compute_slot_mapping(is_profile_run, self.slot_mapping, seq_id,
seq_len, context_len, start_idx,
self.block_size, inter_data.block_tables)
@ -475,8 +475,6 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):

self.sliding_window = input_builder.sliding_window
self.block_size = input_builder.block_size
self.use_v2_block_manager = (
input_builder.scheduler_config.use_v2_block_manager)

# Please follow https://docs.flashinfer.ai/tutorials/kv_layout.html#page-layout
# for the precise definition of the following fields.
@ -542,9 +540,9 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
is_profile_run = is_block_tables_empty(block_tables)

# Compute slot mapping.
start_idx = compute_slot_mapping_start_idx(
is_prompt, query_len, context_len, self.sliding_window,
self.use_v2_block_manager)
start_idx = compute_slot_mapping_start_idx(is_prompt, query_len,
context_len,
self.sliding_window)
compute_slot_mapping(is_profile_run, self.slot_mapping, seq_id,
seq_len, context_len, start_idx,
self.block_size, inter_data.block_tables)
@ -38,18 +38,12 @@ def is_block_tables_empty(block_tables: Union[None, Dict]):


def compute_slot_mapping_start_idx(is_prompt: bool, query_len: int,
context_len: int, sliding_window: int,
use_v2_block_manager: bool):
context_len: int, sliding_window: int):
"""
Compute the start index of slot mapping.
"""
start_idx = 0
if is_prompt and sliding_window is not None:
assert use_v2_block_manager or context_len == 0, (
"Prefix caching is currently not supported with "
"sliding window attention in V1 block manager")
# When prefill, we use it to not write slots to kv cache
# to save memory.
start_idx = max(0, query_len - sliding_window)
return start_idx

@ -138,8 +132,6 @@ class CommonMetadataBuilder(AttentionMetadataBuilder[TAttentionMetadata]):

self.sliding_window = input_builder.sliding_window
self.block_size = input_builder.block_size
self.use_v2_block_manager = (
input_builder.scheduler_config.use_v2_block_manager)

def _add_seq_group(
self, inter_data: "ModelInputForGPUBuilder.InterDataForSeqGroup",
@ -180,9 +172,9 @@ class CommonMetadataBuilder(AttentionMetadataBuilder[TAttentionMetadata]):

# Compute slot mapping.
is_profile_run = is_block_tables_empty(block_tables)
start_idx = compute_slot_mapping_start_idx(
is_prompt, query_len, context_len, self.sliding_window,
self.use_v2_block_manager)
start_idx = compute_slot_mapping_start_idx(is_prompt, query_len,
context_len,
self.sliding_window)
compute_slot_mapping(is_profile_run, self.slot_mapping, seq_id,
seq_len, context_len, start_idx,
self.block_size, inter_data.block_tables)
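All three metadata builders now call the helper with the shorter signature shown above. A minimal sketch of the updated call follows; the import path is assumed from vLLM's attention-backend layout, since the file name is not visible in this rendering, and the argument values are purely illustrative.

from vllm.attention.backends.utils import compute_slot_mapping_start_idx

# use_v2_block_manager is no longer threaded through; only the sliding
# window decides where the slot mapping starts during prefill.
start_idx = compute_slot_mapping_start_idx(is_prompt=True,
                                           query_len=16,
                                           context_len=0,
                                           sliding_window=None)
assert start_idx == 0  # no sliding window, so slots are written from index 0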
1 vllm/commit_id.py Normal file
@ -0,0 +1 @@
__commit__ = "93ec62b8556e279d2c050bdc1c3247831bd39466"
@ -949,7 +949,6 @@ class SchedulerConfig:
            iteration.
        max_model_len: Maximum length of a sequence (including prompt
            and generated text).
        use_v2_block_manager: Whether to use the BlockSpaceManagerV2 or not.
        num_lookahead_slots: The number of slots to allocate per sequence per
            step, beyond the known token ids. This is used in speculative
            decoding to store KV activations of tokens which may or may not be
@ -976,7 +975,6 @@ class SchedulerConfig:
                 max_num_batched_tokens: Optional[int],
                 max_num_seqs: int,
                 max_model_len: int,
                 use_v2_block_manager: bool = True,
                 num_lookahead_slots: int = 0,
                 delay_factor: float = 0.0,
                 enable_chunked_prefill: bool = False,
@ -1026,7 +1024,6 @@ class SchedulerConfig:

        self.max_num_seqs = max_num_seqs
        self.max_model_len = max_model_len
        self.use_v2_block_manager = use_v2_block_manager
        self.num_lookahead_slots = num_lookahead_slots
        self.delay_factor = delay_factor
        self.chunked_prefill_enabled = enable_chunked_prefill
@ -1067,18 +1064,6 @@ class SchedulerConfig:
                f"({self.num_scheduler_steps}) must be greater than or "
                "equal to 1.")

        if (not self.use_v2_block_manager \
                and not envs.VLLM_ALLOW_DEPRECATED_BLOCK_MANAGER_V1):
            raise ValueError(
                "The use of BlockSpaceManagerV1 is deprecated and will "
                "be removed in a future release. Please switch to "
                "BlockSpaceManagerV2 by setting --use-v2-block-manager to "
                "True. If you wish to suppress this error temporarily, "
                "you can set the environment variable "
                "`VLLM_ALLOW_DEPRECATED_BLOCK_MANAGER_V1=1. If your use "
                "case is not supported in BlockSpaceManagerV2, please "
                "file an issue with detailed information.")

    @property
    def is_multi_step(self) -> bool:
        return self.num_scheduler_steps > 1
@ -1137,7 +1122,6 @@ class SpeculativeConfig:
        speculative_disable_mqa_scorer: Optional[bool],
        speculative_max_model_len: Optional[int],
        enable_chunked_prefill: bool,
        use_v2_block_manager: bool,
        disable_log_stats: bool,
        speculative_disable_by_batch_size: Optional[int],
        ngram_prompt_lookup_max: Optional[int],
@ -1178,9 +1162,6 @@ class SpeculativeConfig:
            enable_chunked_prefill (bool): Whether vLLM is configured to use
                chunked prefill or not. Used for raising an error since its not
                yet compatible with spec decode.
            use_v2_block_manager (bool): Whether vLLM is configured to use the
                v2 block manager or not. Used for raising an error since the v2
                block manager is required with spec decode.
            speculative_disable_by_batch_size (Optional[int]): Disable
                speculative decoding for new incoming requests when the number
                of enqueue requests is larger than this value, if provided.
@ -1231,11 +1212,6 @@ class SpeculativeConfig:
                    "Speculative decoding and chunked prefill are "
                    f"currently mutually exclusive ({enable_chunked_prefill=}).")

            if not use_v2_block_manager:
                raise ValueError(
                    "Speculative decoding requires usage of the V2 "
                    "block manager. Enable it with --use-v2-block-manager.")

        # TODO: The user should be able to specify revision/max model len
        # for the draft model. It is not currently supported.
        draft_revision = None
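To make the net effect of the SchedulerConfig hunks concrete, here is a stripped-down stand-in (not the real class) showing that the remaining validation concerns num_scheduler_steps only and that is_multi_step is unchanged:

from dataclasses import dataclass


@dataclass
class TinySchedulerConfig:
    # Only the fields needed for this illustration.
    num_scheduler_steps: int = 1

    def __post_init__(self):
        # The block-manager deprecation gate is gone; only this check remains.
        if self.num_scheduler_steps < 1:
            raise ValueError(
                f"num_scheduler_steps ({self.num_scheduler_steps}) must be "
                "greater than or equal to 1.")

    @property
    def is_multi_step(self) -> bool:
        return self.num_scheduler_steps > 1


assert not TinySchedulerConfig(num_scheduler_steps=1).is_multi_step
assert TinySchedulerConfig(num_scheduler_steps=8).is_multi_step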
@ -4,28 +4,6 @@ from vllm.utils import (STR_NOT_IMPL_ENC_DEC_PREFIX_CACHE,
                        STR_NOT_IMPL_ENC_DEC_SWA)


def _get_block_mgr_sliding_window_attr(block_mgr):
    '''
    BlockManagerV1 and BlockManagerV2 have slightly different
    members related to sliding window attention (SWA). This
    function extracts the appropriate member to use for determining
    whether SWA is enabled.

    Arguments:

    * block_mgr: BlockManagerV1 or BlockManagerV2 instance
    '''

    if hasattr(block_mgr, 'block_sliding_window'):
        return block_mgr.block_sliding_window
    if hasattr(block_mgr, 'max_block_sliding_window'):
        return block_mgr.max_block_sliding_window

    raise AttributeError("Block manager instance has neither " + \
                         "block_sliding_window nor " + \
                         "max_block_sliding_window attributes.")


def check_no_caching_or_swa_for_blockmgr_encdec(
        block_mgr, seq_group: SequenceGroup) -> None:
    '''
@ -41,7 +19,7 @@ def check_no_caching_or_swa_for_blockmgr_encdec(
    '''

    if seq_group.is_encoder_decoder():
        if _get_block_mgr_sliding_window_attr(block_mgr) is not None:
        if block_mgr.max_block_sliding_window is not None:
            raise NotImplementedError(STR_NOT_IMPL_ENC_DEC_SWA)

        if block_mgr.enable_caching:
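With a single block manager left, the attribute-sniffing helper above is deleted and the encoder/decoder guard reads the manager's fields directly. A hedged sketch of that simplified check, using stand-in names rather than the vLLM classes:

class _FakeBlockManager:
    def __init__(self, max_block_sliding_window=None, enable_caching=False):
        self.max_block_sliding_window = max_block_sliding_window
        self.enable_caching = enable_caching


def check_encdec_supported(block_mgr, is_encoder_decoder: bool) -> None:
    # Sliding-window attention and prefix caching are rejected for
    # encoder/decoder models, by direct attribute access.
    if is_encoder_decoder:
        if block_mgr.max_block_sliding_window is not None:
            raise NotImplementedError("SWA + encoder/decoder not supported")
        if block_mgr.enable_caching:
            raise NotImplementedError(
                "Prefix caching + encoder/decoder not supported")


check_encdec_supported(_FakeBlockManager(), is_encoder_decoder=True)  # passes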
@ -17,7 +17,7 @@ SeqId = int
EncoderSeqId = str


class BlockSpaceManagerV2(BlockSpaceManager):
class SelfAttnBlockSpaceManager(BlockSpaceManager):
    """BlockSpaceManager which manages the allocation of KV cache.

    It owns responsibility for allocation, swapping, allocating memory for
@ -1,743 +0,0 @@
"""A block manager that manages token blocks."""
import math
from abc import ABC, abstractmethod
from itertools import count, takewhile
from os.path import commonprefix
from typing import Dict, List, Optional
from typing import Sequence as GenericSequence
from typing import Set, Tuple

from vllm.block import BlockTable, PhysicalTokenBlock
from vllm.core.block.common import CacheMetricData
from vllm.core.block.utils import check_no_caching_or_swa_for_blockmgr_encdec
from vllm.core.evictor_v1 import EvictionPolicy, Evictor, make_evictor
from vllm.core.interfaces import AllocStatus, BlockSpaceManager
from vllm.logger import init_logger
from vllm.sequence import Sequence, SequenceGroup, SequenceStatus
from vllm.utils import Device

logger = init_logger(__name__)


class BlockAllocatorBase(ABC):
    """Manages free physical token blocks for a device.

    The allocator maintains a list of free blocks and allocates a block when
    requested. When a block is freed, its reference count is decremented. If
    the reference count becomes zero, the block is added back to the free list.
    """

    @abstractmethod
    def __init__(self,
                 device: Device,
                 block_size: int,
                 num_blocks: int,
                 eviction_policy: EvictionPolicy = EvictionPolicy.LRU):
        pass

    @abstractmethod
    def allocate(self,
                 block_hash: Optional[int] = None,
                 num_hashed_tokens: int = 0) -> PhysicalTokenBlock:
        pass

    @abstractmethod
    def free(self, block: PhysicalTokenBlock) -> None:
        pass

    @abstractmethod
    def get_num_free_blocks(self) -> int:
        pass

    @abstractmethod
    def get_num_total_blocks(self) -> int:
        pass

    @abstractmethod
    def contains_block(self, block_hash: int) -> bool:
        pass

    @abstractmethod
    def update_hash(self, block_hash: int, block: PhysicalTokenBlock):
        pass

    @abstractmethod
    def get_prefix_cache_hit_rate(self) -> float:
        """Prefix cache hit rate. -1 means not supported or disabled."""
        pass


class CachedBlockAllocator(BlockAllocatorBase):
    """Manages free physical token blocks for a device.

    The allocator maintains a list of free blocks and allocates a block when
    requested. When a block is freed, its reference count is decremented. If
    the reference count becomes zero, the block is added back to the free list.
    """

    def __init__(self,
                 device: Device,
                 block_size: int,
                 num_blocks: int,
                 eviction_policy: EvictionPolicy = EvictionPolicy.LRU) -> None:
        self.device = device
        self.block_size = block_size
        self.num_blocks = num_blocks

        self.current_num_blocks = 0
        self.cached_blocks: Dict[int, PhysicalTokenBlock] = {}

        self.evictor: Evictor = make_evictor(eviction_policy)

        self.default_hash_ctr = count()

        self.cache_metric_data = CacheMetricData()

    def allocate_block(self, block_hash: int,
                       num_hashed_tokens: int) -> PhysicalTokenBlock:
        if self.current_num_blocks == self.num_blocks:
            block = self.evictor.evict()
            block.block_hash = block_hash
            block.num_hashed_tokens = num_hashed_tokens
            return block
        block = PhysicalTokenBlock(device=self.device,
                                   block_number=self.current_num_blocks,
                                   block_size=self.block_size,
                                   block_hash=block_hash,
                                   num_hashed_tokens=num_hashed_tokens)
        self.current_num_blocks += 1
        return block

    def allocate(self,
                 block_hash: Optional[int] = None,
                 num_hashed_tokens: int = 0) -> PhysicalTokenBlock:
        if block_hash is None:
            block_hash = next(self.default_hash_ctr)

        if block_hash in self.evictor:
            assert block_hash not in self.cached_blocks
            block = self.evictor.remove(block_hash)
            assert block.ref_count == 0
            self.cached_blocks[block_hash] = block

        if block_hash in self.cached_blocks:
            self.cache_metric_data.query(hit=True)
        else:
            self.cache_metric_data.query(hit=False)
            self.cached_blocks[block_hash] = self.allocate_block(
                block_hash, num_hashed_tokens)
        block = self.cached_blocks[block_hash]
        assert block.block_hash == block_hash
        block.ref_count += 1
        return block

    def free(self, block: PhysicalTokenBlock) -> None:
        if block.ref_count == 0:
            raise ValueError(f"Double free! {block} is already freed.")
        block.ref_count -= 1
        if block.ref_count == 0:
            assert block.block_hash not in self.evictor
            self.evictor.add(block)

            # Remove the block from the cached_blocks
            del self.cached_blocks[block.block_hash]

    def get_num_free_blocks(self) -> int:
        return (self.num_blocks - self.current_num_blocks +
                self.evictor.num_blocks)

    def get_num_total_blocks(self) -> int:
        return self.num_blocks

    def contains_block(self, block_hash: int) -> bool:
        return block_hash in self.cached_blocks or block_hash in self.evictor

    def update_hash(self, block_hash: int, block: PhysicalTokenBlock):
        # Update the hash of block and the cached_blocks dictionary.
        assert not self.contains_block(block_hash)
        old_hash = block.block_hash
        block.block_hash = block_hash
        del self.cached_blocks[old_hash]
        self.cached_blocks[block_hash] = block

    def get_prefix_cache_hit_rate(self) -> float:
        return self.cache_metric_data.get_hit_rate()


class UncachedBlockAllocator(BlockAllocatorBase):
    """Manages free physical token blocks for a device.

    The allocator maintains a list of free blocks and allocates a block when
    requested. When a block is freed, its reference count is decremented. If
    the reference count becomes zero, the block is added back to the free list.
    """

    def __init__(
        self,
        device: Device,
        block_size: int,
        num_blocks: int,
    ) -> None:
        self.device = device
        self.block_size = block_size
        self.num_blocks = num_blocks

        # Initialize the free blocks.
        self.free_blocks: List[PhysicalTokenBlock] = []
        for i in range(num_blocks):
            block = PhysicalTokenBlock(device=device,
                                       block_number=i,
                                       block_size=block_size,
                                       block_hash=-1,
                                       num_hashed_tokens=0)
            self.free_blocks.append(block)

    def allocate(self,
                 block_hash: Optional[int] = None,
                 num_hashed_tokens: int = 0) -> PhysicalTokenBlock:
        if not self.free_blocks:
            raise ValueError("Out of memory! No free blocks are available.")
        block = self.free_blocks.pop()
        block.ref_count = 1
        return block

    def free(self, block: PhysicalTokenBlock) -> None:
        if block.ref_count == 0:
            raise ValueError(f"Double free! {block} is already freed.")
        block.ref_count -= 1
        if block.ref_count == 0:
            self.free_blocks.append(block)

    def get_num_free_blocks(self) -> int:
        return len(self.free_blocks)

    def get_num_total_blocks(self) -> int:
        return self.num_blocks

    def contains_block(self, block_hash: int) -> bool:
        raise NotImplementedError(
            "Invalid codepath for uncached block allocator.")

    def update_hash(self, block_hash: int, block: PhysicalTokenBlock):
        raise NotImplementedError(
            "Invalid codepath for uncached block allocator.")

    def get_prefix_cache_hit_rate(self) -> float:
        return -1


class BlockSpaceManagerV1(BlockSpaceManager):
    """Manages the mapping between logical and physical token blocks."""

    def __init__(
        self,
        block_size: int,
        num_gpu_blocks: int,
        num_cpu_blocks: int,
        watermark: float = 0.01,
        sliding_window: Optional[int] = None,
        enable_caching: bool = False,
    ) -> None:
        self.block_size = block_size
        self.num_total_gpu_blocks = num_gpu_blocks
        self.num_total_cpu_blocks = num_cpu_blocks

        if enable_caching and sliding_window is not None:
            raise NotImplementedError(
                "Sliding window is not allowed with prefix caching enabled!")

        self.block_sliding_window = None
        if sliding_window is not None:
            # Round up to nearest block size to regularize sliding window
            # allocation sizes.
            self.block_sliding_window = math.ceil(sliding_window / block_size)

        self.watermark = watermark
        assert watermark >= 0.0

        self.enable_caching = enable_caching

        self.watermark_blocks = int(watermark * num_gpu_blocks)

        if self.enable_caching:
            logger.info("Automatic prefix caching is enabled.")
            self.gpu_allocator: BlockAllocatorBase = CachedBlockAllocator(
                Device.GPU, block_size, num_gpu_blocks)
            self.cpu_allocator: BlockAllocatorBase = CachedBlockAllocator(
                Device.CPU, block_size, num_cpu_blocks)
        else:
            self.gpu_allocator = UncachedBlockAllocator(
                Device.GPU, block_size, num_gpu_blocks)
            self.cpu_allocator = UncachedBlockAllocator(
                Device.CPU, block_size, num_cpu_blocks)
        # Mapping: seq_id -> BlockTable.
        self.block_tables: Dict[int, BlockTable] = {}

        # Mapping: req_id -> BlockTable
        # Note that each SequenceGroup has a unique
        # request ID
        self.cross_block_tables: Dict[str, BlockTable] = {}

    def _get_seq_num_required_blocks(self, seq: Optional[Sequence]) -> int:
        return 0 if seq is None else seq.n_blocks

    def can_allocate(self,
                     seq_group: SequenceGroup,
                     num_lookahead_slots: int = 0) -> AllocStatus:
        # FIXME(woosuk): Here we assume that all sequences in the group share
        # the same prompt. This may not be true for preempted sequences.

        assert (num_lookahead_slots == 0
                ), "lookahead allocation not supported in BlockSpaceManagerV1"

        check_no_caching_or_swa_for_blockmgr_encdec(self, seq_group)

        self_num_required_blocks = self._get_seq_num_required_blocks(
            seq_group.get_seqs(status=SequenceStatus.WAITING)[0])
        cross_num_required_blocks = self._get_seq_num_required_blocks(
            seq_group.get_encoder_seq())
        num_required_blocks = self_num_required_blocks + \
                              cross_num_required_blocks

        if self.block_sliding_window is not None:

            num_required_blocks = min(num_required_blocks,
                                      self.block_sliding_window)
        num_free_gpu_blocks = self.gpu_allocator.get_num_free_blocks()

        # Use watermark to avoid frequent cache eviction.
        if (self.num_total_gpu_blocks - num_required_blocks <
                self.watermark_blocks):
            return AllocStatus.NEVER
        if num_free_gpu_blocks - num_required_blocks >= self.watermark_blocks:
            return AllocStatus.OK
        else:
            return AllocStatus.LATER

    def _allocate_sequence(self, \
                           seq: Optional[Sequence], \
                           ref_count: int, \
                           is_encoder_decoder: bool = True) -> BlockTable:
        # Allocate new physical token blocks that will store the prompt tokens.
        num_prompt_blocks = self._get_seq_num_required_blocks(seq)

        block_table: BlockTable = BlockTable()
        assert seq is not None
        for logical_idx in range(num_prompt_blocks):
            if (self.block_sliding_window is not None
                    and logical_idx >= self.block_sliding_window):
                block = block_table[logical_idx % self.block_sliding_window]
                # Set the reference counts of the token blocks.
                block.ref_count = ref_count
            elif not is_encoder_decoder and self.enable_caching:
                block = self.gpu_allocator.allocate(
                    seq.hash_of_block(logical_idx),
                    seq.num_hashed_tokens_of_block(logical_idx))
            else:
                block = self.gpu_allocator.allocate()
                # Set the reference counts of the token blocks.
                block.ref_count = ref_count
            block_table.append(block)

        return block_table

    def allocate(self, seq_group: SequenceGroup) -> None:
        is_encoder_decoder = seq_group.is_encoder_decoder()
        check_no_caching_or_swa_for_blockmgr_encdec(self, seq_group)

        # Allocate decoder sequences
        #
        # NOTE: Here we assume that all sequences in the group have the same
        # decoder prompt.
        wait_seqs = seq_group.get_seqs(status=SequenceStatus.WAITING)
        seq = wait_seqs[0]
        block_table: BlockTable = \
            self._allocate_sequence(seq,
                                    seq_group.num_seqs(),
                                    is_encoder_decoder)

        # Assign the self-attention block tables for each sequence.
        if len(wait_seqs) == 1:
            self.block_tables[seq.seq_id] = block_table
        else:
            for seq in wait_seqs:
                self.block_tables[seq.seq_id] = block_table.copy()

        # Allocate encoder sequence
        if is_encoder_decoder:
            # A SequenceGroup has only a single encoder sequence (at most),
            # thus allocate with a ref count of 1
            block_table = self._allocate_sequence(seq_group.get_encoder_seq(),
                                                  1, is_encoder_decoder)
            # Assign the cross-attention block table for the SequenceGroup.
            self.cross_block_tables[seq_group.request_id] = block_table

    def can_append_slots(self,
                         seq_group: SequenceGroup,
                         num_lookahead_slots: int = 0) -> bool:
        assert (num_lookahead_slots == 0
                ), "lookahead allocation not supported in BlockSpaceManagerV1"

        # Simple heuristic: If there is at least one free block
        # for each sequence, we can append.
        num_free_gpu_blocks = self.gpu_allocator.get_num_free_blocks()
        num_seqs = seq_group.num_seqs(status=SequenceStatus.RUNNING)
        return num_seqs <= num_free_gpu_blocks

    def _promote_last_block(
        self,
        seq: Sequence,
        last_block: PhysicalTokenBlock,
    ) -> PhysicalTokenBlock:
        assert self.enable_caching

        # Compute a new hash for the block so that it can be shared by other
        # Sequences
        new_hash = seq.hash_of_block(seq.n_blocks - 1)

        # if new_hash is already in the cached table, then free last_block
        # and return the cached version
        if self.gpu_allocator.contains_block(new_hash):
            self.gpu_allocator.free(last_block)
            return self.gpu_allocator.allocate(new_hash)
        else:
            self.gpu_allocator.update_hash(new_hash, last_block)
            return last_block

    def _is_last_block_full(
        self,
        seq: Sequence,
    ) -> bool:
        token_ids_len = seq.data.get_len()
        return token_ids_len > 0 and token_ids_len % seq.block_size == 0

    def _maybe_promote_last_block(
        self,
        seq: Sequence,
        last_block: PhysicalTokenBlock,
    ) -> PhysicalTokenBlock:
        if self._is_last_block_full(seq):
            return self._promote_last_block(seq, last_block)
        else:
            return last_block

    def _allocate_last_physical_block(
        self,
        seq: Sequence,
    ) -> PhysicalTokenBlock:
        # Called before a new block is appended.
        # This is in charge of allocating a new physical block (to be appended).

        # None if the last block is not full. Otherwise, we set it to the
        # content hash.
        if not self.enable_caching:
            return self.gpu_allocator.allocate()
        block_hash: Optional[int] = None
        n_blocks = seq.n_blocks
        if (self._is_last_block_full(seq)):
            block_hash = seq.hash_of_block(n_blocks - 1)
        num_hashed_tokens = seq.num_hashed_tokens_of_block(n_blocks - 1)

        # num_hashed_tokens is used to compute future hashes
        # (e.g. in the hashing function, it is used to ask the sequence for
        # prefix tokens)
        new_block = self.gpu_allocator.allocate(block_hash, num_hashed_tokens)

        # If the block_hash is None, then the block is not full.
        # If the block is not full, then we expect it to have a refcount of 1.
        if block_hash is None:
            assert new_block.ref_count == 1
        return new_block

    def append_slots(
        self,
        seq: Sequence,
        num_lookahead_slots: int = 0,
    ) -> List[Tuple[int, int]]:
        """Allocate a physical slot for a new token."""
        n_blocks = seq.n_blocks
        block_table = self.block_tables[seq.seq_id]
        # If we need to allocate a new physical block
        if len(block_table) < n_blocks:
            # Currently this code only supports adding one physical block
            assert len(block_table) == n_blocks - 1

            if (self.block_sliding_window
                    and len(block_table) >= self.block_sliding_window):
                # reuse a block
                block_table.append(block_table[len(block_table) %
                                               self.block_sliding_window])
            else:
                # The sequence hash a new logical block.
                # Allocate a new physical block.
                new_block = self._allocate_last_physical_block(seq)
                block_table.append(new_block)
                return []

        # We want to append the token to the last physical block.
        last_block = block_table[-1]
        assert last_block.device == Device.GPU
        if last_block.ref_count == 1:
            # Not shared with other sequences. Appendable.
            if self.enable_caching:
                # If the last block is now complete, we may reuse an old block
                # to save memory.
                maybe_new_block = self._maybe_promote_last_block(
                    seq, last_block)
                block_table[-1] = maybe_new_block
            return []
        else:
            # The last block is shared with other sequences.
            # Copy on Write: Allocate a new block and copy the tokens.
            new_block = self._allocate_last_physical_block(seq)

            block_table[-1] = new_block
            self.gpu_allocator.free(last_block)
            return [(last_block.block_number, new_block.block_number)]

    def fork(self, parent_seq: Sequence, child_seq: Sequence) -> None:
        # NOTE: fork does not allocate a new physical block.
        # Thus, it is always safe from OOM.
        if parent_seq.seq_id not in self.block_tables:
            # Parent sequence has either been freed or never existed.
            return
        src_block_table = self.block_tables[parent_seq.seq_id]
        self.block_tables[child_seq.seq_id] = src_block_table.copy()

        # When using a sliding window, blocks will be eventually reused.
        # In this case the block tables will contain repeated blocks.
        # When forking, we must make sure that each block's `ref_count`
        # is only incremented by one, so we deduplicate them by wrapping
        # them in a set.
        for block in set(src_block_table):
            block.ref_count += 1

    def _get_physical_blocks(
            self, seq_group: SequenceGroup) -> List[PhysicalTokenBlock]:

        # NOTE: Here, we assume that the physical blocks are only shared by
        # the sequences in the same group.
        request_id = seq_group.request_id
        blocks: Set[PhysicalTokenBlock] = set()
        for seq in seq_group.get_seqs():
            if seq.is_finished():
                continue
            blocks.update(self.block_tables[seq.seq_id])
        # Cross-attention blocks
        if seq_group.is_encoder_decoder():
            blocks.update(self.cross_block_tables[request_id])
        return list(blocks)

    def can_swap_in(self,
                    seq_group: SequenceGroup,
                    num_lookahead_slots: int = 0) -> AllocStatus:
        assert (num_lookahead_slots == 0
                ), "BlockSpaceManagerV1 does not support lookahead allocation"

        blocks = self._get_physical_blocks(seq_group)
        num_swapped_seqs = seq_group.num_seqs(status=SequenceStatus.SWAPPED)
        if seq_group.is_encoder_decoder():
            num_swapped_seqs += 1
        num_free_blocks = self.gpu_allocator.get_num_free_blocks()
        # NOTE: Conservatively, we assume that every sequence will allocate
        # at least one free block right after the swap-in.
        # NOTE: This should match the logic in can_append_slot().
        num_required_blocks = len(blocks) + num_swapped_seqs
        if self.gpu_allocator.get_num_total_blocks() < num_required_blocks:
            return AllocStatus.NEVER
        elif num_free_blocks - num_required_blocks >= self.watermark_blocks:
            return AllocStatus.OK
        else:
            return AllocStatus.LATER

    def _swap_block_table(
            self, block_table: BlockTable, src_allocator: BlockAllocatorBase,
            dest_allocator: BlockAllocatorBase,
            mapping: Dict[PhysicalTokenBlock,
                          PhysicalTokenBlock]) -> BlockTable:
        new_block_table: BlockTable = BlockTable()

        for from_block in block_table:
            if from_block in mapping:
                to_block = mapping[from_block]
                to_block.ref_count += 1
            else:
                to_block = dest_allocator.allocate(
                    from_block.block_hash, from_block.num_hashed_tokens)
                mapping[from_block] = to_block
            new_block_table.append(to_block)
            # Free the source block swapped in to destination.
            src_allocator.free(from_block)

        return new_block_table

    def swap_in(self, seq_group: SequenceGroup) -> List[Tuple[int, int]]:

        request_id = seq_group.request_id

        # CPU block -> GPU block.
        # dict is efficient in lookup `if cpu_block in mapping`
        mapping: Dict[PhysicalTokenBlock, PhysicalTokenBlock] = {}
        for seq in seq_group.get_seqs(status=SequenceStatus.SWAPPED):
            self.block_tables[seq.seq_id] = \
                self._swap_block_table(self.block_tables[seq.seq_id],
                                       self.cpu_allocator, self.gpu_allocator,
                                       mapping)

        if seq_group.is_encoder_decoder():
            self.cross_block_tables[request_id] = \
                self._swap_block_table(self.cross_block_tables[request_id],
                                       self.cpu_allocator,
                                       self.gpu_allocator,
                                       mapping)

        return [(cpu_block.block_number, gpu_block.block_number)
                for cpu_block, gpu_block in mapping.items()]

    def can_swap_out(self, seq_group: SequenceGroup) -> bool:
        blocks = self._get_physical_blocks(seq_group)
        return len(blocks) <= self.cpu_allocator.get_num_free_blocks()

    def swap_out(self, seq_group: SequenceGroup) -> List[Tuple[int, int]]:
        request_id = seq_group.request_id

        # GPU block -> CPU block.
        # dict is efficient in lookup `if gpu_block in mapping`
        mapping: Dict[PhysicalTokenBlock, PhysicalTokenBlock] = {}
        for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
            self.block_tables[seq.seq_id] = \
                self._swap_block_table(self.block_tables[seq.seq_id],
                                       self.gpu_allocator, self.cpu_allocator,
                                       mapping)

        if seq_group.is_encoder_decoder():
            self.cross_block_tables[request_id] = \
                self._swap_block_table(self.cross_block_tables[request_id],
                                       self.gpu_allocator,
                                       self.cpu_allocator,
                                       mapping)

        return [(cpu_block.block_number, gpu_block.block_number)
                for cpu_block, gpu_block in mapping.items()]

    def _free_block_table(self, block_table: BlockTable) -> None:
        # when using a sliding window, each seq will only use up
        # to `self.block_sliding_window` blocks. When freeing
        # the block table, we must make sure to not free blocks more
        # than once. If no sliding window is used, there is no block
        # reuse in the block table, so we must free all blocks.
        blocks_to_free = (block_table[-self.block_sliding_window:]
                          if self.block_sliding_window is not None else
                          block_table)
        for block in set(blocks_to_free):
            if block.device == Device.GPU:
                self.gpu_allocator.free(block)
            else:
                self.cpu_allocator.free(block)

    def free(self, seq: Sequence) -> None:
        if seq.seq_id not in self.block_tables:
            # Already freed or haven't been scheduled yet.
            return
        block_table = self.block_tables[seq.seq_id]
        self._free_block_table(block_table)
        del self.block_tables[seq.seq_id]

    def free_cross(self, seq_group: SequenceGroup) -> None:
        if seq_group.request_id not in self.cross_block_tables:
            # Already freed or hasn't ben scheduled yet.
            return
        block_table = self.cross_block_tables[seq_group.request_id]
        self._free_block_table(block_table)
        del self.cross_block_tables[seq_group.request_id]

    def reset(self) -> None:
        # Free decoder block tables
        for block_table in self.block_tables.values():
            self._free_block_table(block_table)
        self.block_tables.clear()
        # Free cross-attention block tables
        for block_table in self.cross_block_tables.values():
            self._free_block_table(block_table)
        self.cross_block_tables.clear()

    def get_block_table(self, seq: Sequence) -> List[int]:
        return self.block_tables[seq.seq_id].ids()

    def get_cross_block_table(self, seq_group: SequenceGroup) -> List[int]:
        block_table = self.cross_block_tables[seq_group.request_id]
        return [block.block_number for block in block_table]

    def get_num_free_gpu_blocks(self) -> int:
        return self.gpu_allocator.get_num_free_blocks()

    def get_num_free_cpu_blocks(self) -> int:
        return self.cpu_allocator.get_num_free_blocks()

    def access_all_blocks_in_seq(
        self,
        seq: Sequence,
        access_time: float,
    ) -> None:
        if self.enable_caching:
            # Update the last accessed time of all the blocks accessed
            # in this step.
            block_table = self.block_tables[seq.seq_id]
            for block in block_table:
                block.last_accessed = access_time

    def compute_full_blocks_in_seq(self, seq: Sequence, token_chunk_size: int):
        if seq.seq_id not in self.block_tables:
            return

        # When chunked prefill is enabled, the computed full blocks
        # should be calculated based on the number of computed tokens.
        max_computed_tokens = (seq.data.get_num_computed_tokens() +
                               token_chunk_size)
        computed_full_blocks = max_computed_tokens // self.block_size

        block_table = self.block_tables[seq.seq_id]
        if computed_full_blocks == 0:
            return
        for i in reversed(range(computed_full_blocks)):
            if block_table[i].computed:
                break
            block_table[i].computed = True

    def get_all_computed_blocks(self, seq: Sequence) -> List[int]:
        if seq.seq_id not in self.block_tables:
            return []
        block_table = self.block_tables[seq.seq_id]
        # NOTE We exclude the last block to avoid the case where the entire
        # prompt is cached. This would cause erroneous behavior in model
        # runner.
        return [
            b.block_number
            for b in takewhile(lambda b: b.computed, block_table[:-1])
        ]

    def get_common_computed_block_ids(
            self, seqs: List[Sequence]) -> GenericSequence[int]:
        """Return the block ids that are common for a given sequence group.

        Used in prefill (can skip prefill of some blocks).
        """
        # Can return non-empty result only with prefix caching enabled.
        if not self.enable_caching:
            return []

        ids_list = [self.get_all_computed_blocks(seq) for seq in seqs]
        return commonprefix([ids for ids in ids_list if ids != []])

    def mark_blocks_as_computed(self, seq_group: SequenceGroup,
                                token_chunk_size: int):
        if self.enable_caching:
            for seq in seq_group.get_seqs():
                self.compute_full_blocks_in_seq(seq, token_chunk_size)

    def get_prefix_cache_hit_rate(self, device: Device) -> float:
        if device == Device.GPU:
            return self.gpu_allocator.get_prefix_cache_hit_rate()
        if device == Device.CPU:
            return self.cpu_allocator.get_prefix_cache_hit_rate()
        raise ValueError(f"Invalid device: {device}")
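The file removed above contained the v1 prefix-caching allocator. As background only, a toy (non-vLLM) illustration of the core idea it implemented: physical blocks keyed by a content hash and shared through reference counts:

from dataclasses import dataclass
from typing import Dict


@dataclass
class ToyBlock:
    block_hash: int
    ref_count: int = 0


class ToyCachedAllocator:
    def __init__(self) -> None:
        self.cached: Dict[int, ToyBlock] = {}

    def allocate(self, block_hash: int) -> ToyBlock:
        # Cache hit: bump the ref count instead of allocating a new block.
        block = self.cached.setdefault(block_hash, ToyBlock(block_hash))
        block.ref_count += 1
        return block


alloc = ToyCachedAllocator()
a = alloc.allocate(block_hash=42)
b = alloc.allocate(block_hash=42)
assert a is b and a.ref_count == 2  # identical prefixes share one block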
@ -28,13 +28,9 @@ class BlockSpaceManager(ABC):
    def get_block_space_manager_class(version: str):
        version = version.lower()

        if version == "v1":
            from vllm.core.block_manager_v1 import BlockSpaceManagerV1
            return BlockSpaceManagerV1

        if version == "v2":
            from vllm.core.block_manager_v2 import BlockSpaceManagerV2
            return BlockSpaceManagerV2
        if version == "selfattn":
            from vllm.core.block_manager import SelfAttnBlockSpaceManager
            return SelfAttnBlockSpaceManager

        if version == "placeholder":
            from vllm.core.placeholder_block_space_manager import (
@ -312,9 +312,7 @@ class Scheduler:
        # LoRAs. This should be improved in the future.
        self.lora_config = lora_config

        version = "v1"
        if self.scheduler_config.use_v2_block_manager:
            version = "v2"
        version = "selfattn"
        if (self.scheduler_config.embedding_mode
                or self.cache_config.is_attention_free):
            version = "placeholder"
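Putting the interfaces.py and scheduler.py hunks together, block-manager selection now reduces to the small decision below (a compressed sketch, not the full Scheduler.__init__):

def resolve_block_manager_version(embedding_mode: bool,
                                  is_attention_free: bool) -> str:
    # "selfattn" (the former v2 manager) is always used, unless the
    # placeholder manager is needed for attention-free or embedding models.
    version = "selfattn"
    if embedding_mode or is_attention_free:
        version = "placeholder"
    return version


assert resolve_block_manager_version(False, False) == "selfattn"
assert resolve_block_manager_version(True, False) == "placeholder"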
@ -373,12 +373,13 @@ class EngineArgs:
            action='store_true',
            help='Disables sliding window, '
            'capping to sliding window size')
        parser.add_argument(
            '--use-v2-block-manager',
            default=EngineArgs.use_v2_block_manager,
        parser.add_argument('--use-v2-block-manager',
                            action='store_true',
            help='Use BlockSpaceMangerV2. By default this is set to True. '
            'Set to False to use BlockSpaceManagerV1')
                            help='[DEPRECATED] block manager v1 has been '
                            'removed and SelfAttnBlockSpaceManager (i.e. '
                            'block manager v2) is now the default. '
                            'Setting this flag to True or False'
                            ' has no effect on vLLM behavior.')
        parser.add_argument(
            '--num-lookahead-slots',
            type=int,
@ -969,12 +970,6 @@ class EngineArgs:
                "in low performance due to small KV cache space. Consider "
                "setting --max-model-len to a smaller value.", max_model_len)

        if self.num_scheduler_steps > 1 and not self.use_v2_block_manager:
            self.use_v2_block_manager = True
            logger.warning(
                "Enabled BlockSpaceManagerV2 because it is "
                "required for multi-step (--num-scheduler-steps > 1)")

        speculative_config = SpeculativeConfig.maybe_create_spec_config(
            target_model_config=model_config,
            target_parallel_config=parallel_config,
@ -990,7 +985,6 @@ class EngineArgs:
            speculative_disable_by_batch_size,
            speculative_max_model_len=self.speculative_max_model_len,
            enable_chunked_prefill=self.enable_chunked_prefill,
            use_v2_block_manager=self.use_v2_block_manager,
            disable_log_stats=self.disable_log_stats,
            ngram_prompt_lookup_max=self.ngram_prompt_lookup_max,
            ngram_prompt_lookup_min=self.ngram_prompt_lookup_min,
@ -1021,11 +1015,20 @@ class EngineArgs:
            if speculative_config is None \
            else speculative_config.num_lookahead_slots

        if not self.use_v2_block_manager:
            logger.warning(
                "[DEPRECATED] Block manager v1 has been removed, "
                "and setting --use-v2-block-manager to True or False has "
                "no effect on vLLM behavior. Please remove "
                "--use-v2-block-manager in your engine argument. "
                "If your use case is not supported by "
                "SelfAttnBlockSpaceManager (i.e. block manager v2),"
                " please file an issue with detailed information.")

        scheduler_config = SchedulerConfig(
            max_num_batched_tokens=self.max_num_batched_tokens,
            max_num_seqs=self.max_num_seqs,
            max_model_len=model_config.max_model_len,
            use_v2_block_manager=self.use_v2_block_manager,
            num_lookahead_slots=num_lookahead_slots,
            delay_factor=self.scheduler_delay_factor,
            enable_chunked_prefill=self.enable_chunked_prefill,
@ -1081,13 +1084,6 @@ class EngineArgs:
            or "all" in detailed_trace_modules,
        )

        if (model_config.get_sliding_window() is not None
                and scheduler_config.chunked_prefill_enabled
                and not scheduler_config.use_v2_block_manager):
            raise ValueError(
                "Chunked prefill is not supported with sliding window. "
                "Set --disable-sliding-window to disable sliding window.")

        return EngineConfig(
            model_config=model_config,
            cache_config=cache_config,
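The argparse change above keeps --use-v2-block-manager parseable so existing launch commands do not break. One common way to express such a removed, no-op flag in plain argparse (an illustration of the pattern, not vLLM's exact warning logic):

import argparse
import warnings

parser = argparse.ArgumentParser()
parser.add_argument('--use-v2-block-manager',
                    action='store_true',
                    help='[DEPRECATED] no effect; the self-attention block '
                    'manager (block manager v2) is always used.')
args = parser.parse_args(['--use-v2-block-manager'])
if args.use_v2_block_manager:
    # The flag is accepted but ignored; only a warning is emitted.
    warnings.warn("--use-v2-block-manager is deprecated and has no effect; "
                  "please remove it from your arguments.")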
@ -247,7 +247,7 @@ class LLMEngine:
            "enforce_eager=%s, kv_cache_dtype=%s, "
            "quantization_param_path=%s, device_config=%s, "
            "decoding_config=%r, observability_config=%r, "
            "seed=%d, served_model_name=%s, use_v2_block_manager=%s, "
            "seed=%d, served_model_name=%s, "
            "num_scheduler_steps=%d, chunked_prefill_enabled=%s "
            "multi_step_stream_outputs=%s, enable_prefix_caching=%s, "
            "use_async_output_proc=%s, use_cached_outputs=%s, "
@ -280,7 +280,6 @@ class LLMEngine:
            observability_config,
            model_config.seed,
            model_config.served_model_name,
            scheduler_config.use_v2_block_manager,
            scheduler_config.num_scheduler_steps,
            scheduler_config.chunked_prefill_enabled,
            scheduler_config.multi_step_stream_outputs,
@ -64,7 +64,6 @@ if TYPE_CHECKING:
    VLLM_USE_TRITON_AWQ: bool = False
    VLLM_ALLOW_RUNTIME_LORA_UPDATING: bool = False
    VLLM_SKIP_P2P_CHECK: bool = False
    VLLM_ALLOW_DEPRECATED_BLOCK_MANAGER_V1: bool = False
    VLLM_TORCH_COMPILE_LEVEL: int = 0
    VLLM_DISABLED_KERNELS: List[str] = []

@ -427,11 +426,6 @@ environment_variables: Dict[str, Callable[[], Any]] = {
    "VLLM_SKIP_P2P_CHECK":
    lambda: os.getenv("VLLM_SKIP_P2P_CHECK", "0") == "1",

    # If set, allowing the use of deprecated block manager V1
    "VLLM_ALLOW_DEPRECATED_BLOCK_MANAGER_V1":
    lambda: os.environ.get("VLLM_ALLOW_DEPRECATED_BLOCK_MANAGER_V1", "0"
                           ) == "1",

    # List of quantization kernels that should be disabled, used for testing
    # and performance comparisons. Currently only affects MPLinearKernel
    # selection
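For context, vllm/envs.py keeps a mapping from variable name to a lazy parser, so removing VLLM_ALLOW_DEPRECATED_BLOCK_MANAGER_V1 is simply deleting its entry. A minimal stand-alone sketch of that registry pattern:

import os
from typing import Any, Callable, Dict

environment_variables_sketch: Dict[str, Callable[[], Any]] = {
    # Each entry is evaluated lazily, only when the variable is read.
    "VLLM_SKIP_P2P_CHECK":
    lambda: os.getenv("VLLM_SKIP_P2P_CHECK", "0") == "1",
}

assert environment_variables_sketch["VLLM_SKIP_P2P_CHECK"]() in (True, False)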
@ -574,17 +574,12 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
            # paged attn. We can remove it if we make paged attn kernel
            # to properly handle slinding window attn.
            curr_sliding_window_block = self.sliding_window_blocks
            if self.scheduler_config.use_v2_block_manager:
                # number of elements in last block
                suff_len = inter_data.seq_lens[seq_idx] % self.block_size
                sliding_seq_len = min(
                    inter_data.seq_lens[seq_idx],
            sliding_seq_len = min(inter_data.seq_lens[seq_idx],
                                  self.block_aligned_sliding_window + suff_len)
            if suff_len > 0:
                curr_sliding_window_block += 1
            else:
                sliding_seq_len = min(inter_data.seq_lens[seq_idx],
                                      self.sliding_window)

            inter_data.curr_sliding_window_blocks[
                seq_idx] = curr_sliding_window_block
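A small worked example of the block-aligned sliding-window arithmetic above, with made-up numbers (block_size=16, a 256-token block-aligned window, a 1000-token sequence):

block_size = 16
block_aligned_sliding_window = 256                   # window rounded to whole blocks
sliding_window_blocks = block_aligned_sliding_window // block_size  # 16 blocks
seq_len = 1000

suff_len = seq_len % block_size                      # 8 tokens sit in a partial last block
sliding_seq_len = min(seq_len, block_aligned_sliding_window + suff_len)  # 264 tokens
curr_sliding_window_blocks = sliding_window_blocks + (1 if suff_len > 0 else 0)

assert (suff_len, sliding_seq_len, curr_sliding_window_blocks) == (8, 264, 17)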