[Core] Add an environment variable which needs to be set explicitly to allow BlockSpaceManagerV1 (#9149)
This commit is contained in:
parent
a64e7b9407
commit
f3a507f1d3
@ -77,8 +77,8 @@ steps:
|
|||||||
- vllm/
|
- vllm/
|
||||||
- tests/basic_correctness/test_chunked_prefill
|
- tests/basic_correctness/test_chunked_prefill
|
||||||
commands:
|
commands:
|
||||||
- VLLM_ATTENTION_BACKEND=XFORMERS pytest -v -s basic_correctness/test_chunked_prefill.py
|
- VLLM_ATTENTION_BACKEND=XFORMERS VLLM_ALLOW_DEPRECATED_BLOCK_MANAGER_V1=1 pytest -v -s basic_correctness/test_chunked_prefill.py
|
||||||
- VLLM_ATTENTION_BACKEND=FLASH_ATTN pytest -v -s basic_correctness/test_chunked_prefill.py
|
- VLLM_ATTENTION_BACKEND=FLASH_ATTN VLLM_ALLOW_DEPRECATED_BLOCK_MANAGER_V1=1 pytest -v -s basic_correctness/test_chunked_prefill.py
|
||||||
|
|
||||||
- label: Core Test # 10min
|
- label: Core Test # 10min
|
||||||
mirror_hardwares: [amd]
|
mirror_hardwares: [amd]
|
||||||
@ -88,7 +88,11 @@ steps:
|
|||||||
- vllm/distributed
|
- vllm/distributed
|
||||||
- tests/core
|
- tests/core
|
||||||
commands:
|
commands:
|
||||||
- pytest -v -s core
|
- VLLM_ALLOW_DEPRECATED_BLOCK_MANAGER_V1=1 pytest -v -s core/test_scheduler.py
|
||||||
|
# Only this file needs the deprecated V1 block manager; passing a bare `core`
# as well would make pytest collect the entire directory under the override.
- VLLM_ALLOW_DEPRECATED_BLOCK_MANAGER_V1=1 pytest -v -s core/test_chunked_prefill_scheduler.py
|
||||||
|
# Only this file needs the deprecated V1 block manager; passing a bare `core`
# as well would make pytest collect the entire directory under the override.
- VLLM_ALLOW_DEPRECATED_BLOCK_MANAGER_V1=1 pytest -v -s core/block/e2e/test_correctness.py
|
||||||
|
# Only this file needs the deprecated V1 block manager; passing a bare `core`
# as well would make pytest collect the entire directory under the override.
- VLLM_ALLOW_DEPRECATED_BLOCK_MANAGER_V1=1 pytest -v -s core/block/e2e/test_correctness_sliding_window.py
|
||||||
|
# Run the rest of the core tests without the V1 override; each file that was
# executed above with the override is ignored exactly once.
- pytest -v -s core --ignore=core/test_scheduler.py --ignore=core/test_chunked_prefill_scheduler.py --ignore=core/block/e2e/test_correctness.py --ignore=core/block/e2e/test_correctness_sliding_window.py
|
||||||
|
|
||||||
- label: Entrypoints Test # 40min
|
- label: Entrypoints Test # 40min
|
||||||
working_dir: "/vllm-workspace/tests"
|
working_dir: "/vllm-workspace/tests"
|
||||||
@ -185,7 +189,8 @@ steps:
|
|||||||
- vllm/
|
- vllm/
|
||||||
- tests/prefix_caching
|
- tests/prefix_caching
|
||||||
commands:
|
commands:
|
||||||
- pytest -v -s prefix_caching
|
- VLLM_ALLOW_DEPRECATED_BLOCK_MANAGER_V1=1 pytest -v -s prefix_caching/test_prefix_caching.py
|
||||||
|
- pytest -v -s prefix_caching --ignore=prefix_caching/test_prefix_caching.py
|
||||||
|
|
||||||
- label: Samplers Test # 36min
|
- label: Samplers Test # 36min
|
||||||
source_file_dependencies:
|
source_file_dependencies:
|
||||||
@ -209,7 +214,8 @@ steps:
|
|||||||
- tests/spec_decode
|
- tests/spec_decode
|
||||||
commands:
|
commands:
|
||||||
- pytest -v -s spec_decode/e2e/test_multistep_correctness.py
|
- pytest -v -s spec_decode/e2e/test_multistep_correctness.py
|
||||||
- VLLM_ATTENTION_BACKEND=FLASH_ATTN pytest -v -s spec_decode --ignore=spec_decode/e2e/test_multistep_correctness.py
|
- VLLM_ALLOW_DEPRECATED_BLOCK_MANAGER_V1=1 pytest -v -s spec_decode/e2e/test_compatibility.py
|
||||||
|
- VLLM_ATTENTION_BACKEND=FLASH_ATTN pytest -v -s spec_decode --ignore=spec_decode/e2e/test_multistep_correctness.py --ignore=spec_decode/e2e/test_compatibility.py
|
||||||
|
|
||||||
- label: LoRA Test %N # 15min each
|
- label: LoRA Test %N # 15min each
|
||||||
mirror_hardwares: [amd]
|
mirror_hardwares: [amd]
|
||||||
@ -391,7 +397,7 @@ steps:
|
|||||||
- pytest -v -s ./compile/test_full_graph_multi_gpu.py
|
- pytest -v -s ./compile/test_full_graph_multi_gpu.py
|
||||||
- pytest -v -s ./compile/test_wrapper.py
|
- pytest -v -s ./compile/test_wrapper.py
|
||||||
- VLLM_TEST_SAME_HOST=1 torchrun --nproc-per-node=4 distributed/test_same_node.py | grep -q 'Same node test passed'
|
- VLLM_TEST_SAME_HOST=1 torchrun --nproc-per-node=4 distributed/test_same_node.py | grep -q 'Same node test passed'
|
||||||
- TARGET_TEST_SUITE=L4 pytest basic_correctness/ -v -s -m distributed_2_gpus
|
- TARGET_TEST_SUITE=L4 VLLM_ALLOW_DEPRECATED_BLOCK_MANAGER_V1=1 pytest basic_correctness/ -v -s -m distributed_2_gpus
|
||||||
# Avoid importing model tests that cause CUDA reinitialization error
|
# Avoid importing model tests that cause CUDA reinitialization error
|
||||||
- pytest models/encoder_decoder/language/test_bart.py -v -s -m distributed_2_gpus
|
- pytest models/encoder_decoder/language/test_bart.py -v -s -m distributed_2_gpus
|
||||||
- pytest models/encoder_decoder/vision_language/test_broadcast.py -v -s -m distributed_2_gpus
|
- pytest models/encoder_decoder/vision_language/test_broadcast.py -v -s -m distributed_2_gpus
|
||||||
|
@ -221,7 +221,9 @@ if __name__ == '__main__':
|
|||||||
parser.add_argument("--enable-prefix-caching",
|
parser.add_argument("--enable-prefix-caching",
|
||||||
action='store_true',
|
action='store_true',
|
||||||
help="Enable automatic prefix caching")
|
help="Enable automatic prefix caching")
|
||||||
parser.add_argument('--use-v2-block-manager', action='store_true')
|
parser.add_argument('--use-v2-block-manager',
|
||||||
|
action='store_true',
|
||||||
|
default=EngineArgs.use_v2_block_manager)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--ray-workers-use-nsight",
|
"--ray-workers-use-nsight",
|
||||||
action='store_true',
|
action='store_true',
|
||||||
|
@ -33,6 +33,7 @@ from typing import List, Optional, Tuple
|
|||||||
from transformers import PreTrainedTokenizerBase
|
from transformers import PreTrainedTokenizerBase
|
||||||
|
|
||||||
from vllm import LLM, SamplingParams
|
from vllm import LLM, SamplingParams
|
||||||
|
from vllm.engine.arg_utils import EngineArgs
|
||||||
from vllm.utils import FlexibleArgumentParser
|
from vllm.utils import FlexibleArgumentParser
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@ -177,6 +178,7 @@ if __name__ == "__main__":
|
|||||||
help='enable prefix caching')
|
help='enable prefix caching')
|
||||||
parser.add_argument('--use-v2-block-manager',
|
parser.add_argument('--use-v2-block-manager',
|
||||||
action='store_true',
|
action='store_true',
|
||||||
|
default=EngineArgs.use_v2_block_manager,
|
||||||
help='Use BlockSpaceMangerV2')
|
help='Use BlockSpaceMangerV2')
|
||||||
parser.add_argument('--num-prompts',
|
parser.add_argument('--num-prompts',
|
||||||
type=int,
|
type=int,
|
||||||
|
@ -473,6 +473,7 @@ if __name__ == "__main__":
|
|||||||
help="Maximum number of forward steps per scheduler call.")
|
help="Maximum number of forward steps per scheduler call.")
|
||||||
parser.add_argument("--use-v2-block-manager",
|
parser.add_argument("--use-v2-block-manager",
|
||||||
action='store_true',
|
action='store_true',
|
||||||
|
default=EngineArgs.use_v2_block_manager,
|
||||||
help="Enable block manager v2.")
|
help="Enable block manager v2.")
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--enable-prefix-caching",
|
"--enable-prefix-caching",
|
||||||
|
@ -12,7 +12,7 @@ from contextlib import nullcontext
|
|||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from ..models.utils import check_logprobs_close, check_outputs_equal
|
from ..models.utils import check_logprobs_close, check_outputs_equal
|
||||||
from ..utils import multi_gpu_test
|
from ..utils import check_deprecated_block_manager_usage, multi_gpu_test
|
||||||
|
|
||||||
MODELS = [
|
MODELS = [
|
||||||
"facebook/opt-125m",
|
"facebook/opt-125m",
|
||||||
@ -20,6 +20,12 @@ MODELS = [
|
|||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="module", autouse=True)
def check_deprecated_block_manager():
    """Module-wide guard for tests that exercise BlockSpaceManagerV1.

    Delegates to ``check_deprecated_block_manager_usage``, passing this
    module's path so the failure message can suggest a runnable command.
    """
    this_module = 'tests/basic_correctness/test_chunked_prefill.py'
    check_deprecated_block_manager_usage(this_module)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("model", MODELS)
|
@pytest.mark.parametrize("model", MODELS)
|
||||||
@pytest.mark.parametrize("dtype", ["half"])
|
@pytest.mark.parametrize("dtype", ["half"])
|
||||||
@pytest.mark.parametrize("max_tokens", [32])
|
@pytest.mark.parametrize("max_tokens", [32])
|
||||||
|
@ -2,11 +2,18 @@ from itertools import cycle
|
|||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
|
from tests.utils import check_deprecated_block_manager_usage
|
||||||
from vllm import SamplingParams
|
from vllm import SamplingParams
|
||||||
|
|
||||||
from .conftest import get_token_ids_from_llm_generator
|
from .conftest import get_token_ids_from_llm_generator
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="module", autouse=True)
def check_deprecated_block_manager():
    """Module-wide guard for tests that exercise BlockSpaceManagerV1.

    Delegates to ``check_deprecated_block_manager_usage``, passing this
    module's path so the failure message can suggest a runnable command.
    """
    this_module = 'tests/core/block/e2e/test_correctness.py'
    check_deprecated_block_manager_usage(this_module)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"common_llm_kwargs",
|
"common_llm_kwargs",
|
||||||
[{
|
[{
|
||||||
|
@ -3,6 +3,7 @@ from typing import List
|
|||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
|
from tests.utils import check_deprecated_block_manager_usage
|
||||||
from vllm import LLM, SamplingParams
|
from vllm import LLM, SamplingParams
|
||||||
|
|
||||||
from .conftest import get_text_from_llm_generator
|
from .conftest import get_text_from_llm_generator
|
||||||
@ -12,6 +13,12 @@ MODEL = "bigcode/starcoder2-3b"
|
|||||||
BLOCK_SIZE = 16
|
BLOCK_SIZE = 16
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="module", autouse=True)
def check_deprecated_block_manager():
    """Module-wide guard for tests that exercise BlockSpaceManagerV1.

    Delegates to ``check_deprecated_block_manager_usage``, passing this
    module's path so the failure message can suggest a runnable command.
    """
    this_module = 'tests/core/block/e2e/test_correctness_sliding_window.py'
    check_deprecated_block_manager_usage(this_module)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"common_llm_kwargs",
|
"common_llm_kwargs",
|
||||||
[{
|
[{
|
||||||
|
@ -8,6 +8,7 @@ from vllm.core.interfaces import AllocStatus
|
|||||||
from vllm.core.scheduler import Scheduler
|
from vllm.core.scheduler import Scheduler
|
||||||
from vllm.sequence import Logprob, SequenceGroup
|
from vllm.sequence import Logprob, SequenceGroup
|
||||||
|
|
||||||
|
from ..utils import check_deprecated_block_manager_usage
|
||||||
from .utils import create_dummy_prompt
|
from .utils import create_dummy_prompt
|
||||||
|
|
||||||
|
|
||||||
@ -27,6 +28,12 @@ def schedule_and_update_computed_tokens(scheduler):
|
|||||||
return metas, out
|
return metas, out
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="module", autouse=True)
def check_deprecated_block_manager():
    """Module-wide guard for tests that exercise BlockSpaceManagerV1.

    Delegates to ``check_deprecated_block_manager_usage``, passing this
    module's path so the failure message can suggest a runnable command.
    """
    this_module = 'tests/core/test_chunked_prefill_scheduler.py'
    check_deprecated_block_manager_usage(this_module)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize('use_v2_block_manager', [True, False])
|
@pytest.mark.parametrize('use_v2_block_manager', [True, False])
|
||||||
def test_simple(use_v2_block_manager: bool):
|
def test_simple(use_v2_block_manager: bool):
|
||||||
"""Verify basic scheduling works."""
|
"""Verify basic scheduling works."""
|
||||||
|
@ -12,11 +12,18 @@ from vllm.core.scheduler import Scheduler, SchedulingBudget
|
|||||||
from vllm.lora.request import LoRARequest
|
from vllm.lora.request import LoRARequest
|
||||||
from vllm.sequence import SequenceGroup, SequenceStatus
|
from vllm.sequence import SequenceGroup, SequenceStatus
|
||||||
|
|
||||||
|
from ..utils import check_deprecated_block_manager_usage
|
||||||
from .utils import (append_new_token, append_new_token_seq_group,
|
from .utils import (append_new_token, append_new_token_seq_group,
|
||||||
create_dummy_prompt, get_sequence_groups,
|
create_dummy_prompt, get_sequence_groups,
|
||||||
schedule_and_update_computed_tokens)
|
schedule_and_update_computed_tokens)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="module", autouse=True)
def check_deprecated_block_manager():
    """Module-wide guard for the scheduler tests that use BlockSpaceManagerV1.

    Delegates to ``check_deprecated_block_manager_usage``, whose failure
    message embeds the given path in a suggested pytest command.
    """
    # The path must name this module (the scheduler tests), not the sibling
    # chunked-prefill scheduler tests, or the suggested command is wrong.
    check_deprecated_block_manager_usage("tests/core/test_scheduler.py")
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize('use_v2_block_manager', [True, False])
|
@pytest.mark.parametrize('use_v2_block_manager', [True, False])
|
||||||
def test_scheduler_add_seq_group(use_v2_block_manager: bool):
|
def test_scheduler_add_seq_group(use_v2_block_manager: bool):
|
||||||
block_size = 4
|
block_size = 4
|
||||||
|
@ -7,6 +7,7 @@ from typing import List
|
|||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from tests.kernels.utils import override_backend_env_variable
|
from tests.kernels.utils import override_backend_env_variable
|
||||||
|
from tests.utils import check_deprecated_block_manager_usage
|
||||||
from vllm.block import PhysicalTokenBlock
|
from vllm.block import PhysicalTokenBlock
|
||||||
from vllm.core.block_manager_v1 import CachedBlockAllocator
|
from vllm.core.block_manager_v1 import CachedBlockAllocator
|
||||||
from vllm.utils import Device
|
from vllm.utils import Device
|
||||||
@ -18,6 +19,12 @@ MODELS = [
|
|||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="module", autouse=True)
def check_deprecated_block_manager():
    """Module-wide guard for tests that exercise BlockSpaceManagerV1.

    Delegates to ``check_deprecated_block_manager_usage``, passing this
    module's path so the failure message can suggest a runnable command.
    """
    this_module = 'tests/prefix_caching/test_prefix_caching.py'
    check_deprecated_block_manager_usage(this_module)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("block_size", [16])
|
@pytest.mark.parametrize("block_size", [16])
|
||||||
@pytest.mark.parametrize("num_blocks", [16])
|
@pytest.mark.parametrize("num_blocks", [16])
|
||||||
def test_block_allocator(
|
def test_block_allocator(
|
||||||
|
@ -1,10 +1,17 @@
|
|||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
|
from tests.utils import check_deprecated_block_manager_usage
|
||||||
from vllm import SamplingParams
|
from vllm import SamplingParams
|
||||||
|
|
||||||
from .conftest import get_output_from_llm_generator
|
from .conftest import get_output_from_llm_generator
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="module", autouse=True)
def check_deprecated_block_manager():
    """Module-wide guard for tests that exercise BlockSpaceManagerV1.

    Delegates to ``check_deprecated_block_manager_usage``, passing this
    module's path so the failure message can suggest a runnable command.
    """
    this_module = 'tests/spec_decode/e2e/test_compatibility.py'
    check_deprecated_block_manager_usage(this_module)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"common_llm_kwargs",
|
"common_llm_kwargs",
|
||||||
[{
|
[{
|
||||||
|
@ -678,3 +678,12 @@ def get_client_text_logprob_generations(
|
|||||||
return [(text_generations, text,
|
return [(text_generations, text,
|
||||||
(None if x.logprobs is None else x.logprobs.top_logprobs))
|
(None if x.logprobs is None else x.logprobs.top_logprobs))
|
||||||
for completion in completions for x in completion.choices]
|
for completion in completions for x in completion.choices]
|
||||||
|
|
||||||
|
|
||||||
|
def check_deprecated_block_manager_usage(test_name: str):
    """Assert that the deprecated BlockSpaceManagerV1 was explicitly enabled.

    Args:
        test_name: Path of the calling test file; it is interpolated into
            the pytest command suggested by the failure message.

    Raises:
        AssertionError: If ``envs.VLLM_ALLOW_DEPRECATED_BLOCK_MANAGER_V1``
            is not ``True``.
    """
    hint = (
        "To allow the use of deprecated BlockSpaceManagerV1, set the "
        "environment variable VLLM_ALLOW_DEPRECATED_BLOCK_MANAGER_V1=1. "
        "You can run the tests with: "
        f"`VLLM_ALLOW_DEPRECATED_BLOCK_MANAGER_V1=1 pytest {test_name}`"  # noqa: E501
    )
    v1_allowed = envs.VLLM_ALLOW_DEPRECATED_BLOCK_MANAGER_V1
    assert v1_allowed is True, hint
|
||||||
|
@ -1037,6 +1037,18 @@ class SchedulerConfig:
|
|||||||
f"({self.num_scheduler_steps}) must be greater than or "
|
f"({self.num_scheduler_steps}) must be greater than or "
|
||||||
"equal to 1.")
|
"equal to 1.")
|
||||||
|
|
||||||
|
if (not self.use_v2_block_manager \
|
||||||
|
and not envs.VLLM_ALLOW_DEPRECATED_BLOCK_MANAGER_V1):
|
||||||
|
raise ValueError(
|
||||||
|
"The use of BlockSpaceManagerV1 is deprecated and will "
|
||||||
|
"be removed in a future release. Please switch to "
|
||||||
|
"BlockSpaceManagerV2 by setting --use-v2-block-manager to "
|
||||||
|
"True. If you wish to suppress this error temporarily, "
|
||||||
|
"you can set the environment variable "
|
||||||
|
"`VLLM_ALLOW_DEPRECATED_BLOCK_MANAGER_V1=1. If your use "
|
||||||
|
"case is not supported in BlockSpaceManagerV2, please "
|
||||||
|
"file an issue with detailed information.")
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def is_multi_step(self) -> bool:
|
def is_multi_step(self) -> bool:
|
||||||
return self.num_scheduler_steps > 1
|
return self.num_scheduler_steps > 1
|
||||||
|
@ -64,6 +64,7 @@ if TYPE_CHECKING:
|
|||||||
VLLM_USE_TRITON_AWQ: bool = False
|
VLLM_USE_TRITON_AWQ: bool = False
|
||||||
VLLM_ALLOW_RUNTIME_LORA_UPDATING: bool = False
|
VLLM_ALLOW_RUNTIME_LORA_UPDATING: bool = False
|
||||||
VLLM_SKIP_P2P_CHECK: bool = False
|
VLLM_SKIP_P2P_CHECK: bool = False
|
||||||
|
VLLM_ALLOW_DEPRECATED_BLOCK_MANAGER_V1: bool = False
|
||||||
|
|
||||||
|
|
||||||
def get_default_cache_root():
|
def get_default_cache_root():
|
||||||
@ -434,6 +435,11 @@ environment_variables: Dict[str, Callable[[], Any]] = {
|
|||||||
# and trust the driver's peer-to-peer capability report.
|
# and trust the driver's peer-to-peer capability report.
|
||||||
"VLLM_SKIP_P2P_CHECK":
|
"VLLM_SKIP_P2P_CHECK":
|
||||||
lambda: os.getenv("VLLM_SKIP_P2P_CHECK", "0") == "1",
|
lambda: os.getenv("VLLM_SKIP_P2P_CHECK", "0") == "1",
|
||||||
|
|
||||||
|
# If set, allows the use of the deprecated block manager V1
|
||||||
|
"VLLM_ALLOW_DEPRECATED_BLOCK_MANAGER_V1":
|
||||||
|
lambda: os.environ.get("VLLM_ALLOW_DEPRECATED_BLOCK_MANAGER_V1", "0"
|
||||||
|
) == "1",
|
||||||
}
|
}
|
||||||
|
|
||||||
# end-env-vars-definition
|
# end-env-vars-definition
|
||||||
|
Loading…
x
Reference in New Issue
Block a user