[Bugfix][ROCm] Run new processes using the spawn method for ROCm in tests (#14810)

Signed-off-by: vllmellm <vllm.ellm@embeddedllm.com>
Signed-off-by: tjtanaa <tunjian.tan@embeddedllm.com>
Co-authored-by: TJian <tunjian.tan@embeddedllm.com>
Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
vllmellm 2025-03-17 19:33:35 +08:00 committed by GitHub
parent 6eaf1e5c52
commit 2bb0e1a799
GPG Key ID: B5690EEEBB952194
21 changed files with 174 additions and 99 deletions
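
Note on the decorator change repeated across the diffs below: `fork_new_process_for_each_test` was a plain decorator, while `create_new_process_for_each_test()` is a decorator factory, so every call site gains parentheses. The factory picks the process start method per platform ("spawn" on ROCm, "fork" elsewhere) unless one is passed explicitly. A minimal usage sketch (the test name and body are illustrative only; the import path assumes the vLLM tests package):

from tests.utils import create_new_process_for_each_test

# Default: "spawn" on ROCm, "fork" on other platforms.
@create_new_process_for_each_test()
def test_runs_in_its_own_process():
    assert 1 + 1 == 2  # placeholder body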

View File

@@ -7,10 +7,10 @@ from vllm import LLM, SamplingParams
 from vllm.device_allocator.cumem import CuMemAllocator
 from vllm.utils import GiB_bytes
-from ..utils import fork_new_process_for_each_test
+from ..utils import create_new_process_for_each_test
-@fork_new_process_for_each_test
+@create_new_process_for_each_test()
 def test_python_error():
     """
     Test if Python error occurs when there's low-level
@@ -36,7 +36,7 @@ def test_python_error():
     allocator.wake_up()
-@fork_new_process_for_each_test
+@create_new_process_for_each_test()
 def test_basic_cumem():
     # some tensors from default memory pool
     shape = (1024, 1024)
@@ -69,7 +69,7 @@ def test_basic_cumem():
     assert torch.allclose(output, torch.ones_like(output) * 3)
-@fork_new_process_for_each_test
+@create_new_process_for_each_test()
 def test_cumem_with_cudagraph():
     allocator = CuMemAllocator.get_instance()
     with allocator.use_memory_pool():
@@ -114,7 +114,7 @@ def test_cumem_with_cudagraph():
     assert torch.allclose(y, x + 1)
-@fork_new_process_for_each_test
+@create_new_process_for_each_test()
 @pytest.mark.parametrize(
     "model, use_v1",
     [

View File

@@ -12,7 +12,7 @@ from vllm import LLM, SamplingParams
 from vllm.config import CompilationLevel
 from vllm.platforms import current_platform
-from ..utils import fork_new_process_for_each_test
+from ..utils import create_new_process_for_each_test
 @pytest.fixture(params=None, name="model_info")
@@ -78,7 +78,7 @@ def models_list_fixture(request):
     [CompilationLevel.DYNAMO_ONCE, CompilationLevel.PIECEWISE],
 )
 @pytest.mark.parametrize("model_info", "", indirect=True)
-@fork_new_process_for_each_test
+@create_new_process_for_each_test()
 def test_full_graph(
     monkeypatch: pytest.MonkeyPatch,
     model_info: tuple[str, dict[str, Any]],

View File

@@ -8,7 +8,7 @@ import pytest
 from vllm.config import TaskOption
 from vllm.logger import init_logger
-from ..utils import compare_two_settings, fork_new_process_for_each_test
+from ..utils import compare_two_settings, create_new_process_for_each_test
 logger = init_logger("test_expert_parallel")
@@ -209,7 +209,7 @@ def _compare_tp(
         for params in settings.iter_params(model_name)
     ],
 )
-@fork_new_process_for_each_test
+@create_new_process_for_each_test()
 def test_ep(
     model_name: str,
     parallel_setup: ParallelSetup,

View File

@@ -17,7 +17,7 @@ from vllm.config import TaskOption
 from vllm.logger import init_logger
 from ..models.registry import HF_EXAMPLE_MODELS
-from ..utils import compare_two_settings, fork_new_process_for_each_test
+from ..utils import compare_two_settings, create_new_process_for_each_test
 logger = init_logger("test_pipeline_parallel")
@@ -402,7 +402,7 @@ def _compare_tp(
         for params in settings.iter_params(model_id) if model_id in TEST_MODELS
     ],
 )
-@fork_new_process_for_each_test
+@create_new_process_for_each_test()
 def test_tp_language_generation(
     model_id: str,
     parallel_setup: ParallelSetup,
@@ -431,7 +431,7 @@ def test_tp_language_generation(
         for params in settings.iter_params(model_id) if model_id in TEST_MODELS
     ],
 )
-@fork_new_process_for_each_test
+@create_new_process_for_each_test()
 def test_tp_language_embedding(
     model_id: str,
     parallel_setup: ParallelSetup,
@@ -460,7 +460,7 @@ def test_tp_language_embedding(
         for params in settings.iter_params(model_id) if model_id in TEST_MODELS
     ],
 )
-@fork_new_process_for_each_test
+@create_new_process_for_each_test()
 def test_tp_multimodal_generation(
     model_id: str,
     parallel_setup: ParallelSetup,

View File

@@ -5,7 +5,7 @@ from typing import TYPE_CHECKING
 import pytest
-from ..utils import compare_two_settings, fork_new_process_for_each_test
+from ..utils import compare_two_settings, create_new_process_for_each_test
 if TYPE_CHECKING:
     from typing_extensions import LiteralString
@@ -18,7 +18,7 @@ if TYPE_CHECKING:
     "FLASH_ATTN",
     "FLASHINFER",
 ])
-@fork_new_process_for_each_test
+@create_new_process_for_each_test()
 def test_pp_cudagraph(
     monkeypatch: pytest.MonkeyPatch,
     PP_SIZE: int,

View File

@@ -4,12 +4,12 @@ import pytest
 from vllm import LLM
-from ...utils import fork_new_process_for_each_test
+from ...utils import create_new_process_for_each_test
 @pytest.mark.parametrize("tp_size", [1, 2])
 @pytest.mark.parametrize("backend", ["mp", "ray"])
-@fork_new_process_for_each_test
+@create_new_process_for_each_test()
 def test_collective_rpc(tp_size, backend):
     if tp_size == 1 and backend == "ray":
         pytest.skip("Skip duplicate test case")

View File

@@ -3,10 +3,9 @@
 import pytest
 import vllm
-from tests.utils import fork_new_process_for_each_test
 from vllm.lora.request import LoRARequest
-from ..utils import multi_gpu_test
+from ..utils import create_new_process_for_each_test, multi_gpu_test
 MODEL_PATH = "THUDM/chatglm3-6b"
@@ -55,7 +54,7 @@ def v1(run_with_both_engines_lora):
     pass
-@fork_new_process_for_each_test
+@create_new_process_for_each_test()
 def test_chatglm3_lora(chatglm3_lora_files):
     llm = vllm.LLM(MODEL_PATH,
                    max_model_len=1024,
@@ -75,7 +74,7 @@ def test_chatglm3_lora(chatglm3_lora_files):
 @multi_gpu_test(num_gpus=4)
-@fork_new_process_for_each_test
+@create_new_process_for_each_test()
 def test_chatglm3_lora_tp4(chatglm3_lora_files):
     llm = vllm.LLM(MODEL_PATH,
                    max_model_len=1024,
@@ -96,7 +95,7 @@ def test_chatglm3_lora_tp4(chatglm3_lora_files):
 @multi_gpu_test(num_gpus=4)
-@fork_new_process_for_each_test
+@create_new_process_for_each_test()
 def test_chatglm3_lora_tp4_fully_sharded_loras(chatglm3_lora_files):
     llm = vllm.LLM(MODEL_PATH,
                    max_model_len=1024,

View File

@@ -4,10 +4,9 @@ import pytest
 import ray
 import vllm
-from tests.utils import fork_new_process_for_each_test
 from vllm.lora.request import LoRARequest
-from ..utils import multi_gpu_test
+from ..utils import create_new_process_for_each_test, multi_gpu_test
 MODEL_PATH = "meta-llama/Llama-2-7b-hf"
@@ -82,7 +81,7 @@ def v1(run_with_both_engines_lora):
 # V1 Test: Failing due to numerics on V1.
 @pytest.mark.skip_v1
-@fork_new_process_for_each_test
+@create_new_process_for_each_test()
 def test_llama_lora(sql_lora_files):
     llm = vllm.LLM(MODEL_PATH,
@@ -97,7 +96,7 @@ def test_llama_lora(sql_lora_files):
 # Skipping for v1 as v1 doesn't have a good way to expose the num_gpu_blocks
 # used by the engine yet.
 @pytest.mark.skip_v1
-@fork_new_process_for_each_test
+@create_new_process_for_each_test()
 def test_llama_lora_warmup(sql_lora_files):
     """Test that the LLM initialization works with a warmup LORA path and
     is more conservative"""
@@ -128,7 +127,7 @@ def test_llama_lora_warmup(sql_lora_files):
 # V1 Test: Failing due to numerics on V1.
 @pytest.mark.skip_v1
 @multi_gpu_test(num_gpus=4)
-@fork_new_process_for_each_test
+@create_new_process_for_each_test()
 def test_llama_lora_tp4(sql_lora_files):
     llm = vllm.LLM(
@@ -143,7 +142,7 @@ def test_llama_lora_tp4(sql_lora_files):
 @multi_gpu_test(num_gpus=4)
-@fork_new_process_for_each_test
+@create_new_process_for_each_test()
 def test_llama_lora_tp4_fully_sharded_loras(sql_lora_files):
     llm = vllm.LLM(
@@ -159,7 +158,7 @@ def test_llama_lora_tp4_fully_sharded_loras(sql_lora_files):
 @multi_gpu_test(num_gpus=4)
-@fork_new_process_for_each_test
+@create_new_process_for_each_test()
 def test_llama_lora_tp4_fully_sharded_enable_bias(sql_lora_files):
     llm = vllm.LLM(

View File

@@ -3,11 +3,12 @@
 import pytest
 import vllm
-from tests.utils import fork_new_process_for_each_test
 from vllm.assets.image import ImageAsset
 from vllm.lora.request import LoRARequest
 from vllm.platforms import current_platform
+from ..utils import create_new_process_for_each_test
 MODEL_PATH = "openbmb/MiniCPM-Llama3-V-2_5"
 PROMPT_TEMPLATE = (
@@ -57,7 +58,7 @@ def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> list[str]:
 @pytest.mark.xfail(
     current_platform.is_rocm(),
     reason="MiniCPM-V dependency xformers incompatible with ROCm")
-@fork_new_process_for_each_test
+@create_new_process_for_each_test()
 def test_minicpmv_lora(minicpmv_lora_files):
     llm = vllm.LLM(
         MODEL_PATH,
@@ -80,7 +81,7 @@ def test_minicpmv_lora(minicpmv_lora_files):
 @pytest.mark.xfail(
     current_platform.is_rocm(),
     reason="MiniCPM-V dependency xformers incompatible with ROCm")
-@fork_new_process_for_each_test
+@create_new_process_for_each_test()
 def test_minicpmv_tp4_wo_fully_sharded_loras(minicpmv_lora_files):
     llm = vllm.LLM(
         MODEL_PATH,
@@ -101,7 +102,7 @@ def test_minicpmv_tp4_wo_fully_sharded_loras(minicpmv_lora_files):
 @pytest.mark.xfail(
     current_platform.is_rocm(),
     reason="MiniCPM-V dependency xformers incompatible with ROCm")
-@fork_new_process_for_each_test
+@create_new_process_for_each_test()
 def test_minicpmv_tp4_fully_sharded_loras(minicpmv_lora_files):
     llm = vllm.LLM(
         MODEL_PATH,

View File

@@ -3,10 +3,9 @@
 import pytest
 import vllm
-from tests.utils import fork_new_process_for_each_test
 from vllm.lora.request import LoRARequest
-from ..utils import multi_gpu_test
+from ..utils import create_new_process_for_each_test, multi_gpu_test
 MODEL_PATH = "ArthurZ/ilama-3.2-1B"
@@ -56,7 +55,7 @@ def v1(run_with_both_engines_lora):
 @pytest.mark.skip_v1
-@fork_new_process_for_each_test
+@create_new_process_for_each_test()
 def test_ilama_lora(ilama_lora_files):
     llm = vllm.LLM(MODEL_PATH,
                    max_model_len=1024,
@@ -77,7 +76,7 @@ def test_ilama_lora(ilama_lora_files):
 @pytest.mark.skip_v1
 @multi_gpu_test(num_gpus=4)
-@fork_new_process_for_each_test
+@create_new_process_for_each_test()
 def test_ilama_lora_tp4(ilama_lora_files):
     llm = vllm.LLM(MODEL_PATH,
                    max_model_len=1024,
@@ -99,7 +98,7 @@ def test_ilama_lora_tp4(ilama_lora_files):
 @pytest.mark.skip_v1
 @multi_gpu_test(num_gpus=4)
-@fork_new_process_for_each_test
+@create_new_process_for_each_test()
 def test_ilama_lora_tp4_fully_sharded_loras(ilama_lora_files):
     llm = vllm.LLM(MODEL_PATH,
                    max_model_len=1024,

View File

@@ -17,7 +17,7 @@ from vllm.utils import identity
 from ....conftest import (IMAGE_ASSETS, HfRunner, VllmRunner, _ImageAssets,
                           _VideoAssets)
-from ....utils import (fork_new_process_for_each_test, large_gpu_mark,
+from ....utils import (create_new_process_for_each_test, large_gpu_mark,
                        multi_gpu_marks)
 from ...utils import check_outputs_equal
 from .vlm_utils import custom_inputs, model_utils, runners
@@ -592,7 +592,7 @@ VLM_TEST_SETTINGS = _mark_splits(VLM_TEST_SETTINGS, num_groups=2)
     get_parametrized_options(
         VLM_TEST_SETTINGS,
         test_type=VLMTestType.IMAGE,
-        fork_new_process_for_each_test=False,
+        create_new_process_for_each_test=False,
     ))
 def test_single_image_models(tmp_path: PosixPath, model_type: str,
                              test_case: ExpandableVLMTestArgs,
@@ -617,7 +617,7 @@ def test_single_image_models(tmp_path: PosixPath, model_type: str,
     get_parametrized_options(
         VLM_TEST_SETTINGS,
         test_type=VLMTestType.MULTI_IMAGE,
-        fork_new_process_for_each_test=False,
+        create_new_process_for_each_test=False,
     ))
 def test_multi_image_models(tmp_path: PosixPath, model_type: str,
                             test_case: ExpandableVLMTestArgs,
@@ -642,7 +642,7 @@ def test_multi_image_models(tmp_path: PosixPath, model_type: str,
     get_parametrized_options(
         VLM_TEST_SETTINGS,
         test_type=VLMTestType.EMBEDDING,
-        fork_new_process_for_each_test=False,
+        create_new_process_for_each_test=False,
     ))
 def test_image_embedding_models(model_type: str,
                                 test_case: ExpandableVLMTestArgs,
@@ -666,7 +666,7 @@ def test_image_embedding_models(model_type: str,
     get_parametrized_options(
         VLM_TEST_SETTINGS,
         test_type=VLMTestType.VIDEO,
-        fork_new_process_for_each_test=False,
+        create_new_process_for_each_test=False,
     ))
 def test_video_models(model_type: str, test_case: ExpandableVLMTestArgs,
                       hf_runner: type[HfRunner], vllm_runner: type[VllmRunner],
@@ -688,7 +688,7 @@ def test_video_models(model_type: str, test_case: ExpandableVLMTestArgs,
     get_parametrized_options(
         VLM_TEST_SETTINGS,
         test_type=VLMTestType.CUSTOM_INPUTS,
-        fork_new_process_for_each_test=False,
+        create_new_process_for_each_test=False,
     ))
 def test_custom_inputs_models(
     model_type: str,
@@ -714,9 +714,9 @@ def test_custom_inputs_models(
     get_parametrized_options(
         VLM_TEST_SETTINGS,
         test_type=VLMTestType.IMAGE,
-        fork_new_process_for_each_test=True,
+        create_new_process_for_each_test=True,
     ))
-@fork_new_process_for_each_test
+@create_new_process_for_each_test()
 def test_single_image_models_heavy(tmp_path: PosixPath, model_type: str,
                                    test_case: ExpandableVLMTestArgs,
                                    hf_runner: type[HfRunner],
@@ -740,9 +740,9 @@ def test_single_image_models_heavy(tmp_path: PosixPath, model_type: str,
     get_parametrized_options(
         VLM_TEST_SETTINGS,
         test_type=VLMTestType.MULTI_IMAGE,
-        fork_new_process_for_each_test=True,
+        create_new_process_for_each_test=True,
     ))
-@fork_new_process_for_each_test
+@create_new_process_for_each_test()
 def test_multi_image_models_heavy(tmp_path: PosixPath, model_type: str,
                                   test_case: ExpandableVLMTestArgs,
                                   hf_runner: type[HfRunner],
@@ -766,9 +766,9 @@ def test_multi_image_models_heavy(tmp_path: PosixPath, model_type: str,
     get_parametrized_options(
         VLM_TEST_SETTINGS,
         test_type=VLMTestType.EMBEDDING,
-        fork_new_process_for_each_test=True,
+        create_new_process_for_each_test=True,
     ))
-@fork_new_process_for_each_test
+@create_new_process_for_each_test()
 def test_image_embedding_models_heavy(model_type: str,
                                       test_case: ExpandableVLMTestArgs,
                                       hf_runner: type[HfRunner],
@@ -791,7 +791,7 @@ def test_image_embedding_models_heavy(model_type: str,
     get_parametrized_options(
         VLM_TEST_SETTINGS,
         test_type=VLMTestType.VIDEO,
-        fork_new_process_for_each_test=True,
+        create_new_process_for_each_test=True,
     ))
 def test_video_models_heavy(model_type: str, test_case: ExpandableVLMTestArgs,
                             hf_runner: type[HfRunner],
@@ -814,9 +814,9 @@ def test_video_models_heavy(model_type: str, test_case: ExpandableVLMTestArgs,
     get_parametrized_options(
         VLM_TEST_SETTINGS,
         test_type=VLMTestType.CUSTOM_INPUTS,
-        fork_new_process_for_each_test=True,
+        create_new_process_for_each_test=True,
     ))
-@fork_new_process_for_each_test
+@create_new_process_for_each_test()
 def test_custom_inputs_models_heavy(
     model_type: str,
     test_case: ExpandableVLMTestArgs,

View File

@@ -13,9 +13,9 @@ from .types import (EMBEDDING_SIZE_FACTORS, ExpandableVLMTestArgs,
                     ImageSizeWrapper, SizeType, VLMTestInfo, VLMTestType)
-def get_filtered_test_settings(test_settings: dict[str, VLMTestInfo],
-                               test_type: VLMTestType,
-                               fork_per_test: bool) -> dict[str, VLMTestInfo]:
+def get_filtered_test_settings(
+        test_settings: dict[str, VLMTestInfo], test_type: VLMTestType,
+        new_proc_per_test: bool) -> dict[str, VLMTestInfo]:
     """Given the dict of potential test settings to run, return a subdict
     of tests who have the current test type enabled with the matching val for
     fork_per_test.
@@ -43,7 +43,7 @@ def get_filtered_test_settings(test_settings: dict[str, VLMTestInfo],
         # Everything looks okay; keep if this is has correct proc handling
         if (test_info.distributed_executor_backend
-                is not None) == fork_per_test:
+                is not None) == new_proc_per_test:
             matching_tests[test_name] = test_info
     return matching_tests
@@ -51,14 +51,14 @@ def get_filtered_test_settings(test_settings: dict[str, VLMTestInfo],
 def get_parametrized_options(test_settings: dict[str, VLMTestInfo],
                              test_type: VLMTestType,
-                             fork_new_process_for_each_test: bool):
+                             create_new_process_for_each_test: bool):
     """Converts all of our VLMTestInfo into an expanded list of parameters.
     This is similar to nesting pytest parametrize calls, but done directly
     through an itertools product so that each test can set things like
     size factors etc, while still running in isolated test cases.
     """
     matching_tests = get_filtered_test_settings(
-        test_settings, test_type, fork_new_process_for_each_test)
+        test_settings, test_type, create_new_process_for_each_test)
     # Ensure that something is wrapped as an iterable it's not already
     ensure_wrapped = lambda e: e if isinstance(e, (list, tuple)) else (e, )
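
The predicate in the `@@ -43,7 +43,7 @@` hunk above decides isolation by whether a distributed executor backend is configured; renaming `fork_per_test` to `new_proc_per_test` keeps that logic intact. A hedged standalone sketch of the same check (the dataclass below is a simplified stand-in for `VLMTestInfo`, not the real type):

from dataclasses import dataclass
from typing import Optional

@dataclass
class _FakeTestInfo:  # simplified stand-in for VLMTestInfo
    distributed_executor_backend: Optional[str] = None

def _keep(test_info: _FakeTestInfo, new_proc_per_test: bool) -> bool:
    # Distributed tests are selected only when per-test process isolation is
    # requested; single-process tests only when it is not.
    return (test_info.distributed_executor_backend is not None) == new_proc_per_test

assert _keep(_FakeTestInfo("ray"), new_proc_per_test=True)
assert _keep(_FakeTestInfo(None), new_proc_per_test=False)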

View File

@@ -10,7 +10,7 @@ import pytest
 from vllm import LLM, SamplingParams
 from vllm.assets.audio import AudioAsset
-from ....utils import fork_new_process_for_each_test, multi_gpu_test
+from ....utils import create_new_process_for_each_test, multi_gpu_test
 PROMPTS = [
     {
@@ -119,7 +119,7 @@ def run_test(
     assert output.outputs[0].text == expected
-@fork_new_process_for_each_test
+@create_new_process_for_each_test()
 @pytest.mark.core_model
 @pytest.mark.parametrize(
     "model", ["openai/whisper-small", "openai/whisper-large-v3-turbo"])

View File

@@ -5,10 +5,10 @@ import pytest
 from vllm import LLM, SamplingParams
 from vllm.assets.image import ImageAsset
-from ..utils import fork_new_process_for_each_test
+from ..utils import create_new_process_for_each_test
-@fork_new_process_for_each_test
+@create_new_process_for_each_test()
 def test_plugin(
     monkeypatch: pytest.MonkeyPatch,
     dummy_opt_path: str,
@@ -24,7 +24,7 @@ def test_plugin(
     assert (error_msg in str(excinfo.value))
-@fork_new_process_for_each_test
+@create_new_process_for_each_test()
 def test_oot_registration_text_generation(
     monkeypatch: pytest.MonkeyPatch,
     dummy_opt_path: str,
@@ -44,7 +44,7 @@ def test_oot_registration_text_generation(
     assert rest == ""
-@fork_new_process_for_each_test
+@create_new_process_for_each_test()
 def test_oot_registration_embedding(
     monkeypatch: pytest.MonkeyPatch,
     dummy_gemma2_embedding_path: str,
@@ -62,7 +62,7 @@ def test_oot_registration_embedding(
 image = ImageAsset("cherry_blossom").pil_image.convert("RGB")
-@fork_new_process_for_each_test
+@create_new_process_for_each_test()
 def test_oot_registration_multimodal(
     monkeypatch: pytest.MonkeyPatch,
     dummy_llava_path: str,

View File

@@ -17,7 +17,7 @@ from vllm.model_executor.models.registry import (_MULTIMODAL_MODELS,
                                                  ModelRegistry)
 from vllm.platforms import current_platform
-from ..utils import fork_new_process_for_each_test
+from ..utils import create_new_process_for_each_test
 from .registry import HF_EXAMPLE_MODELS
@@ -45,7 +45,7 @@ def test_registry_imports(model_arch):
         assert supports_multimodal(model_cls)
-@fork_new_process_for_each_test
+@create_new_process_for_each_test()
 @pytest.mark.parametrize("model_arch,is_mm,init_cuda,is_ce", [
     ("LlamaForCausalLM", False, False, False),
     ("MllamaForConditionalGeneration", True, False, False),
@@ -70,7 +70,7 @@ def test_registry_model_property(model_arch, is_mm, init_cuda, is_ce):
                       stacklevel=2)
-@fork_new_process_for_each_test
+@create_new_process_for_each_test()
 @pytest.mark.parametrize("model_arch,is_pp,init_cuda", [
     ("MLPSpeculatorPreTrainedModel", False, False),
     ("DeepseekV2ForCausalLM", True, False),

View File

@@ -10,7 +10,8 @@ import pytest
 import torch
 from tests.quantization.utils import is_quant_method_supported
-from tests.utils import compare_two_settings, fork_new_process_for_each_test
+from ..utils import compare_two_settings, create_new_process_for_each_test
 models_4bit_to_test = [
     ("facebook/opt-125m", "quantize opt model inflight"),
@@ -32,7 +33,7 @@ models_pre_quant_8bit_to_test = [
 @pytest.mark.skipif(not is_quant_method_supported("bitsandbytes"),
                     reason='bitsandbytes is not supported on this GPU type.')
 @pytest.mark.parametrize("model_name, description", models_4bit_to_test)
-@fork_new_process_for_each_test
+@create_new_process_for_each_test()
 def test_load_4bit_bnb_model(hf_runner, vllm_runner, example_prompts,
                              model_name, description) -> None:
@@ -45,7 +46,7 @@ def test_load_4bit_bnb_model(hf_runner, vllm_runner, example_prompts,
                     reason='bitsandbytes is not supported on this GPU type.')
 @pytest.mark.parametrize("model_name, description",
                          models_pre_qaunt_4bit_to_test)
-@fork_new_process_for_each_test
+@create_new_process_for_each_test()
 def test_load_pre_quant_4bit_bnb_model(hf_runner, vllm_runner, example_prompts,
                                        model_name, description) -> None:
@@ -57,7 +58,7 @@ def test_load_pre_quant_4bit_bnb_model(hf_runner, vllm_runner, example_prompts,
                     reason='bitsandbytes is not supported on this GPU type.')
 @pytest.mark.parametrize("model_name, description",
                          models_pre_quant_8bit_to_test)
-@fork_new_process_for_each_test
+@create_new_process_for_each_test()
 def test_load_8bit_bnb_model(hf_runner, vllm_runner, example_prompts,
                              model_name, description) -> None:
@@ -70,7 +71,7 @@ def test_load_8bit_bnb_model(hf_runner, vllm_runner, example_prompts,
 @pytest.mark.skipif(not is_quant_method_supported("bitsandbytes"),
                     reason='bitsandbytes is not supported on this GPU type.')
 @pytest.mark.parametrize("model_name, description", models_4bit_to_test)
-@fork_new_process_for_each_test
+@create_new_process_for_each_test()
 def test_load_tp_4bit_bnb_model(hf_runner, vllm_runner, example_prompts,
                                 model_name, description) -> None:
@@ -88,7 +89,7 @@ def test_load_tp_4bit_bnb_model(hf_runner, vllm_runner, example_prompts,
 @pytest.mark.skipif(not is_quant_method_supported("bitsandbytes"),
                     reason='bitsandbytes is not supported on this GPU type.')
 @pytest.mark.parametrize("model_name, description", models_4bit_to_test)
-@fork_new_process_for_each_test
+@create_new_process_for_each_test()
 def test_load_pp_4bit_bnb_model(model_name, description) -> None:
     common_args = [
         "--disable-log-stats",

View File

@@ -42,7 +42,7 @@ from transformers import AutoTokenizer
 from vllm import SamplingParams
-from ...utils import fork_new_process_for_each_test
+from ...utils import create_new_process_for_each_test
 from .conftest import (get_output_from_llm_generator,
                        run_equality_correctness_test)
@@ -82,7 +82,7 @@ from .conftest import (get_output_from_llm_generator,
 @pytest.mark.parametrize("test_llm_kwargs", [{}])
 @pytest.mark.parametrize("batch_size", [1, 32])
 @pytest.mark.parametrize("seed", [1])
-@fork_new_process_for_each_test
+@create_new_process_for_each_test()
 def test_spec_decode_e2e_with_detokenization(test_llm_generator,
                                              batch_size: int):
     """Run generation with speculative decoding on a batch. Verify the engine
@@ -170,7 +170,7 @@ def test_spec_decode_e2e_with_detokenization(test_llm_generator,
 ])
 @pytest.mark.parametrize("batch_size", [1])
 @pytest.mark.parametrize("seed", [1])
-@fork_new_process_for_each_test
+@create_new_process_for_each_test()
 def test_spec_decode_e2e_greedy_correctness_tiny_model_bs1(
         vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs,
         baseline_llm_kwargs, test_llm_kwargs, batch_size: int, output_len: int,
@@ -244,7 +244,7 @@ def test_spec_decode_e2e_greedy_correctness_tiny_model_bs1(
 ])
 @pytest.mark.parametrize("batch_size", [64])
 @pytest.mark.parametrize("seed", [1])
-@fork_new_process_for_each_test
+@create_new_process_for_each_test()
 def test_spec_decode_e2e_greedy_correctness_tiny_model_large_bs(
         vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs,
         baseline_llm_kwargs, test_llm_kwargs, batch_size: int, output_len: int,
@@ -300,7 +300,7 @@ def test_spec_decode_e2e_greedy_correctness_tiny_model_large_bs(
 ])
 @pytest.mark.parametrize("batch_size", [32])
 @pytest.mark.parametrize("seed", [1])
-@fork_new_process_for_each_test
+@create_new_process_for_each_test()
 def test_spec_decode_e2e_greedy_correctness_tiny_model_large_bs_diff_output_len(
         vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs,
         baseline_llm_kwargs, test_llm_kwargs, batch_size: int,
@@ -356,7 +356,7 @@ def test_spec_decode_e2e_greedy_correctness_tiny_model_large_bs_diff_output_len(
     256,
 ])
 @pytest.mark.parametrize("seed", [1])
-@fork_new_process_for_each_test
+@create_new_process_for_each_test()
 def test_spec_decode_e2e_greedy_correctness_real_model_bs1(
         vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs,
         baseline_llm_kwargs, test_llm_kwargs, batch_size: int, output_len: int,
@@ -411,7 +411,7 @@ def test_spec_decode_e2e_greedy_correctness_real_model_bs1(
     64,
 ])
 @pytest.mark.parametrize("seed", [1])
-@fork_new_process_for_each_test
+@create_new_process_for_each_test()
 def test_spec_decode_e2e_greedy_correctness_real_model_large_bs(
         vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs,
         baseline_llm_kwargs, test_llm_kwargs, batch_size: int, output_len: int,
@@ -469,7 +469,7 @@ def test_spec_decode_e2e_greedy_correctness_real_model_large_bs(
 ])
 @pytest.mark.parametrize("batch_size", [4])
 @pytest.mark.parametrize("seed", [1])
-@fork_new_process_for_each_test
+@create_new_process_for_each_test()
 def test_spec_decode_e2e_greedy_correctness_with_preemption(
         vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs,
         baseline_llm_kwargs, test_llm_kwargs, batch_size: int, output_len: int,
@@ -534,7 +534,7 @@ def test_spec_decode_e2e_greedy_correctness_with_preemption(
     32,
 ])
 @pytest.mark.parametrize("seed", [1])
-@fork_new_process_for_each_test
+@create_new_process_for_each_test()
 def test_spec_decode_different_block_size(vllm_runner, common_llm_kwargs,
                                           per_test_common_llm_kwargs,
                                           baseline_llm_kwargs, test_llm_kwargs,
@@ -594,7 +594,7 @@ def test_spec_decode_different_block_size(vllm_runner, common_llm_kwargs,
     64,
 ])
 @pytest.mark.parametrize("seed", [1])
-@fork_new_process_for_each_test
+@create_new_process_for_each_test()
 def test_skip_speculation(vllm_runner, common_llm_kwargs,
                           per_test_common_llm_kwargs, baseline_llm_kwargs,
                           test_llm_kwargs, batch_size: int, output_len: int,
@@ -644,7 +644,7 @@ def test_skip_speculation(vllm_runner, common_llm_kwargs,
 @pytest.mark.parametrize("batch_size", [8])
 @pytest.mark.parametrize("output_len", [10])
 @pytest.mark.parametrize("seed", [1])
-@fork_new_process_for_each_test
+@create_new_process_for_each_test()
 def test_disable_speculation(vllm_runner, common_llm_kwargs,
                              per_test_common_llm_kwargs, baseline_llm_kwargs,
                              test_llm_kwargs, batch_size: int, output_len: int,
@@ -697,7 +697,7 @@ def test_disable_speculation(vllm_runner, common_llm_kwargs,
     32,
 ])
 @pytest.mark.parametrize("seed", [1])
-@fork_new_process_for_each_test
+@create_new_process_for_each_test()
 def test_many_k(vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs,
                 baseline_llm_kwargs, test_llm_kwargs, batch_size: int,
                 output_len: int, seed: int):
@@ -752,7 +752,7 @@ def test_many_k(vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs,
     32,
 ])
 @pytest.mark.parametrize("seed", [1])
-@fork_new_process_for_each_test
+@create_new_process_for_each_test()
 def test_typical_acceptance_sampling(vllm_runner, common_llm_kwargs,
                                      per_test_common_llm_kwargs,
                                      baseline_llm_kwargs, test_llm_kwargs,

View File

@@ -16,7 +16,7 @@ from vllm.utils import (FlexibleArgumentParser, MemorySnapshot,
                         deprecate_kwargs, get_open_port, memory_profiling,
                         merge_async_iterators, supports_kw, swap_dict_values)
-from .utils import error_on_warning, fork_new_process_for_each_test
+from .utils import create_new_process_for_each_test, error_on_warning
 @pytest.mark.asyncio
@@ -276,7 +276,7 @@ def test_supports_kw(callable,kw_name,requires_kw_only,
     ) == is_supported
-@fork_new_process_for_each_test
+@create_new_process_for_each_test()
 def test_memory_profiling():
     # Fake out some model loading + inference memory usage to test profiling
     # Memory used by other processes will show up as cuda usage outside of torch

View File

@@ -7,12 +7,14 @@ import os
 import signal
 import subprocess
 import sys
+import tempfile
 import time
 import warnings
-from contextlib import contextmanager
+from contextlib import contextmanager, suppress
 from pathlib import Path
-from typing import Any, Callable, Optional, Union
+from typing import Any, Callable, Literal, Optional, Union
+import cloudpickle
 import openai
 import pytest
 import requests
@@ -703,6 +705,78 @@ def fork_new_process_for_each_test(
     return wrapper
+def spawn_new_process_for_each_test(
+        f: Callable[_P, None]) -> Callable[_P, None]:
+    """Decorator to spawn a new process for each test function.
+    """
+
+    @functools.wraps(f)
+    def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> None:
+        # Check if we're already in a subprocess
+        if os.environ.get('RUNNING_IN_SUBPROCESS') == '1':
+            # If we are, just run the function directly
+            return f(*args, **kwargs)
+
+        import torch.multiprocessing as mp
+        with suppress(RuntimeError):
+            mp.set_start_method('spawn')
+
+        # Get the module
+        module_name = f.__module__
+
+        # Create a process with environment variable set
+        env = os.environ.copy()
+        env['RUNNING_IN_SUBPROCESS'] = '1'
+
+        with tempfile.TemporaryDirectory() as tempdir:
+            output_filepath = os.path.join(tempdir, "new_process.tmp")
+
+            # `cloudpickle` allows pickling complex functions directly
+            input_bytes = cloudpickle.dumps((f, output_filepath))
+
+            cmd = [sys.executable, "-m", f"{module_name}"]
+
+            returned = subprocess.run(cmd,
+                                      input=input_bytes,
+                                      capture_output=True,
+                                      env=env)
+
+            # check if the subprocess is successful
+            try:
+                returned.check_returncode()
+            except Exception as e:
+                # wrap raised exception to provide more information
+                raise RuntimeError(f"Error raised in subprocess:\n"
+                                   f"{returned.stderr.decode()}") from e
+
+    return wrapper
+
+
+def create_new_process_for_each_test(
+    method: Optional[Literal["spawn", "fork"]] = None
+) -> Callable[[Callable[_P, None]], Callable[_P, None]]:
+    """Creates a decorator that runs each test function in a new process.
+
+    Args:
+        method: The process creation method. Can be either "spawn" or "fork".
+            If not specified,
+            it defaults to "spawn" on ROCm platforms and "fork" otherwise.
+
+    Returns:
+        A decorator to run test functions in separate processes.
+    """
+    if method is None:
+        method = "spawn" if current_platform.is_rocm() else "fork"
+
+    assert method in ["spawn",
+                      "fork"], "Method must be either 'spawn' or 'fork'"
+
+    if method == "fork":
+        return fork_new_process_for_each_test
+
+    return spawn_new_process_for_each_test
+
+
 def large_gpu_mark(min_gb: int) -> pytest.MarkDecorator:
     """
     Get a pytest mark, which skips the test if the GPU doesn't meet
@@ -762,7 +836,7 @@ def multi_gpu_test(*, num_gpus: int):
     marks = multi_gpu_marks(num_gpus=num_gpus)
     def wrapper(f: Callable[_P, None]) -> Callable[_P, None]:
-        func = fork_new_process_for_each_test(f)
+        func = create_new_process_for_each_test()(f)
         for mark in reversed(marks):
            func = mark(func)
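
Taken together, the new helpers let a test opt into either start method, and `multi_gpu_test` now isolates each test via `create_new_process_for_each_test()` by default. A usage sketch (test names and bodies are illustrative only; the "spawn" path presumably matters on ROCm because a forked child cannot safely reuse an already-initialized GPU runtime):

from tests.utils import create_new_process_for_each_test, multi_gpu_test

# Platform-dependent default: "spawn" on ROCm, "fork" elsewhere.
@create_new_process_for_each_test()
def test_default_isolation():
    assert True  # placeholder body

# The start method can also be pinned explicitly.
@create_new_process_for_each_test("spawn")
def test_always_spawn():
    assert True  # placeholder body

# multi_gpu_test wraps the function with create_new_process_for_each_test()
# before applying its GPU-count marks.
@multi_gpu_test(num_gpus=2)
def test_two_gpus():
    assert True  # placeholder body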

View File

@@ -9,7 +9,6 @@ from concurrent.futures import Future
 import pytest
 from transformers import AutoTokenizer
-from tests.utils import fork_new_process_for_each_test
 from vllm import SamplingParams
 from vllm.engine.arg_utils import EngineArgs
 from vllm.platforms import current_platform
@@ -19,6 +18,8 @@ from vllm.v1.executor.abstract import Executor, UniProcExecutor
 from vllm.v1.kv_cache_interface import KVCacheConfig
 from vllm.v1.outputs import ModelRunnerOutput
+from ...utils import create_new_process_for_each_test
 if not current_platform.is_cuda():
     pytest.skip(reason="V1 currently only supported on CUDA.",
                 allow_module_level=True)
@@ -44,7 +45,7 @@ def make_request() -> EngineCoreRequest:
     )
-@fork_new_process_for_each_test
+@create_new_process_for_each_test()
 def test_engine_core(monkeypatch: pytest.MonkeyPatch):
     with monkeypatch.context() as m:
@@ -158,7 +159,7 @@ def test_engine_core(monkeypatch: pytest.MonkeyPatch):
     assert len(engine_core.scheduler.running) == 0
-@fork_new_process_for_each_test
+@create_new_process_for_each_test()
 def test_engine_core_advanced_sampling(monkeypatch: pytest.MonkeyPatch):
     """
     A basic end-to-end test to verify that the engine functions correctly
@@ -208,7 +209,7 @@ def test_engine_core_advanced_sampling(monkeypatch: pytest.MonkeyPatch):
     _check_engine_state()
-@fork_new_process_for_each_test
+@create_new_process_for_each_test()
 def test_engine_core_concurrent_batches(monkeypatch: pytest.MonkeyPatch):
     """
     Test that the engine can handle multiple concurrent batches.

View File

@@ -8,7 +8,6 @@ from typing import Optional
 import pytest
 from transformers import AutoTokenizer
-from tests.utils import fork_new_process_for_each_test
 from vllm import SamplingParams
 from vllm.engine.arg_utils import EngineArgs
 from vllm.platforms import current_platform
@@ -19,6 +18,8 @@ from vllm.v1.engine.core_client import (AsyncMPClient, EngineCoreClient,
                                         SyncMPClient)
 from vllm.v1.executor.abstract import Executor
+from ...utils import create_new_process_for_each_test
 if not current_platform.is_cuda():
     pytest.skip(reason="V1 currently only supported on CUDA.",
                 allow_module_level=True)
@@ -88,7 +89,7 @@ def echo(self, msg: str, err_msg: Optional[str] = None) -> str:
     return msg
-@fork_new_process_for_each_test
+@create_new_process_for_each_test()
 @pytest.mark.parametrize("multiprocessing_mode", [True, False])
 def test_engine_core_client(monkeypatch: pytest.MonkeyPatch,
                             multiprocessing_mode: bool):