[MISC] Consolidate cleanup() and refactor offline_inference_with_prefix.py (#9510)

Cody Yu 2024-10-18 14:30:55 -07:00 committed by GitHub
parent 9bb10a7d27
commit d11bf435a0
20 changed files with 84 additions and 105 deletions
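
In short, this change replaces the cleanup() helpers that individual test suites and examples defined locally with the shared vllm.distributed.cleanup_dist_env_and_memory() helper, and reworks examples/offline_inference_with_prefix.py so that its two LLMs are built one after the other instead of side by side. A minimal sketch of the resulting call pattern, assuming a vLLM build that contains this commit and a GPU with enough free memory for facebook/opt-125m:

from vllm import LLM, SamplingParams
from vllm.distributed import cleanup_dist_env_and_memory

# Build one engine, run a request, then tear down the distributed environment
# and reclaim GPU memory before constructing the next engine in the same process.
llm = LLM(model="facebook/opt-125m", gpu_memory_utilization=0.4)
outputs = llm.generate(["Hello, my name is"], SamplingParams(temperature=0.0))
print(outputs[0].outputs[0].text)

del llm
cleanup_dist_env_and_memory()  # Ray-backed test suites pass shutdown_ray=True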

View File

@@ -1,4 +1,5 @@
 from vllm import LLM, SamplingParams
+from vllm.distributed import cleanup_dist_env_and_memory
 # NOTE: This is just a running example. For benchmarking purpose,
 # please see benchmarks/benchmark_prefix_caching.py
@@ -28,14 +29,9 @@ generating_prompts = [prefix + prompt for prompt in prompts]
 # Create a sampling params object.
 sampling_params = SamplingParams(temperature=0.0)
-# Create an LLM.
-regular_llm = LLM(model="facebook/opt-125m", gpu_memory_utilization=0.3)
-# The second LLM needs to request a higher gpu_memory_utilization because
-# the first LLM has already allocated a full 30% of the gpu memory.
-prefix_cached_llm = LLM(model="facebook/opt-125m",
-                        enable_prefix_caching=True,
-                        gpu_memory_utilization=0.6)
+# Create an LLM without prefix caching as a baseline.
+regular_llm = LLM(model="facebook/opt-125m", gpu_memory_utilization=0.4)
 print("Results without `enable_prefix_caching`")
 # Generate texts from the prompts. The output is a list of RequestOutput objects
@@ -52,6 +48,15 @@ for output in outputs:
 print("-" * 80)
+# Destroy the LLM object and free up the GPU memory.
+del regular_llm
+cleanup_dist_env_and_memory()
+# Create an LLM with prefix caching enabled.
+prefix_cached_llm = LLM(model="facebook/opt-125m",
+                        enable_prefix_caching=True,
+                        gpu_memory_utilization=0.4)
 # Warmup so that the shared prompt's KV cache is computed.
 prefix_cached_llm.generate(generating_prompts[0], sampling_params)
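
Note on the refactor above: because the baseline LLM is destroyed and its memory reclaimed via cleanup_dist_env_and_memory() before the prefix-cached LLM is created, both engines can request the same gpu_memory_utilization=0.4; the previous version had to split the GPU between 0.3 and 0.6 because both engines were alive at once.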

View File

@@ -12,11 +12,11 @@ import torch
 from vllm import SamplingParams
 from vllm.config import ParallelConfig
+from vllm.distributed import cleanup_dist_env_and_memory
 from vllm.engine.async_llm_engine import AsyncEngineArgs, AsyncLLMEngine
 from vllm.outputs import RequestOutput as RealRequestOutput
 from vllm.sampling_params import RequestOutputKind
-from ..conftest import cleanup
 from ..utils import wait_for_gpu_memory_to_clear
@@ -157,7 +157,7 @@ async def async_engine():
     engine.shutdown_background_loop()
     del engine
     await asyncio.sleep(0.1)
-    cleanup()
+    cleanup_dist_env_and_memory()
 @pytest.fixture()

View File

@@ -1,5 +1,3 @@
-import contextlib
-import gc
 import json
 import os
 import sys
@@ -27,8 +25,7 @@ from vllm.assets.image import ImageAsset
 from vllm.assets.video import VideoAsset
 from vllm.config import TaskOption, TokenizerPoolConfig
 from vllm.connections import global_http_connection
-from vllm.distributed import (destroy_distributed_environment,
-                              destroy_model_parallel,
+from vllm.distributed import (cleanup_dist_env_and_memory,
                               init_distributed_environment,
                               initialize_model_parallel)
 from vllm.inputs import (ExplicitEncoderDecoderPrompt, TextPrompt,
@@ -140,17 +137,7 @@ def dist_init():
     )
     initialize_model_parallel(1, 1)
     yield
-    cleanup()
-def cleanup():
-    destroy_model_parallel()
-    destroy_distributed_environment()
-    with contextlib.suppress(AssertionError):
-        torch.distributed.destroy_process_group()
-    gc.collect()
-    if not is_cpu():
-        torch.cuda.empty_cache()
+    cleanup_dist_env_and_memory()
 @pytest.fixture()
@@ -167,7 +154,7 @@ def should_do_global_cleanup_after_test(request) -> bool:
 def cleanup_fixture(should_do_global_cleanup_after_test: bool):
     yield
     if should_do_global_cleanup_after_test:
-        cleanup()
+        cleanup_dist_env_and_memory()
 @pytest.fixture(autouse=True)
@@ -606,7 +593,7 @@ class HfRunner:
     def __exit__(self, exc_type, exc_value, traceback):
         del self.model
-        cleanup()
+        cleanup_dist_env_and_memory()
 @pytest.fixture(scope="session")
@@ -861,7 +848,7 @@ class VllmRunner:
     def __exit__(self, exc_type, exc_value, traceback):
         del self.model
-        cleanup()
+        cleanup_dist_env_and_memory()
 @pytest.fixture(scope="session")

View File

@@ -3,10 +3,9 @@ from typing import Callable, Iterable, Optional
 import pytest
 from vllm import LLM
+from vllm.distributed import cleanup_dist_env_and_memory
 from vllm.model_executor.utils import set_random_seed
-from ....conftest import cleanup
 @pytest.fixture
 def baseline_llm_generator(common_llm_kwargs, per_test_common_llm_kwargs,
@@ -37,7 +36,7 @@ def create_llm_generator(common_llm_kwargs, per_test_common_llm_kwargs,
         yield llm
         del llm
-        cleanup()
+        cleanup_dist_env_and_memory()
     for llm in generator_inner():
         yield llm

View File

@@ -4,8 +4,7 @@ from typing import List
 import pytest
 from vllm import LLM, EmbeddingRequestOutput, PoolingParams
+from vllm.distributed import cleanup_dist_env_and_memory
-from ...conftest import cleanup
 MODEL_NAME = "intfloat/e5-mistral-7b-instruct"
@@ -41,7 +40,7 @@ def llm():
     del llm
-    cleanup()
+    cleanup_dist_env_and_memory()
 def assert_outputs_equal(o1: List[EmbeddingRequestOutput],

View File

@@ -4,8 +4,7 @@ from typing import List
 import pytest
 from vllm import LLM, RequestOutput, SamplingParams
+from vllm.distributed import cleanup_dist_env_and_memory
-from ...conftest import cleanup
 MODEL_NAME = "facebook/opt-125m"
@@ -39,7 +38,7 @@ def llm():
     del llm
-    cleanup()
+    cleanup_dist_env_and_memory()
 def assert_outputs_equal(o1: List[RequestOutput], o2: List[RequestOutput]):

View File

@@ -5,10 +5,9 @@ import pytest
 from huggingface_hub import snapshot_download
 from vllm import LLM
+from vllm.distributed import cleanup_dist_env_and_memory
 from vllm.lora.request import LoRARequest
-from ...conftest import cleanup
 MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta"
 PROMPTS = [
@@ -39,7 +38,7 @@ def llm():
     del llm
-    cleanup()
+    cleanup_dist_env_and_memory()
 @pytest.fixture(scope="module")

View File

@@ -5,12 +5,11 @@ import weakref
 import jsonschema
 import pytest
+from vllm.distributed import cleanup_dist_env_and_memory
 from vllm.entrypoints.llm import LLM
 from vllm.outputs import RequestOutput
 from vllm.sampling_params import GuidedDecodingParams, SamplingParams
-from ...conftest import cleanup
 MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta"
@@ -23,7 +22,7 @@ def llm():
     with llm.deprecate_legacy_api():
         yield weakref.proxy(llm)
     del llm
-    cleanup()
+    cleanup_dist_env_and_memory()
 @pytest.mark.skip_global_cleanup

View File

@@ -1,6 +1,7 @@
 import sys
 from vllm import LLM, SamplingParams
+from vllm.distributed import cleanup_dist_env_and_memory
 def test_lazy_outlines(sample_regex):
@@ -14,6 +15,7 @@ def test_lazy_outlines(sample_regex):
     ]
     sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
+    # Create an LLM without guided decoding as a baseline.
     llm = LLM(model="facebook/opt-125m",
               enforce_eager=True,
               gpu_memory_utilization=0.3)
@@ -26,8 +28,11 @@ def test_lazy_outlines(sample_regex):
     # make sure outlines is not imported
     assert 'outlines' not in sys.modules
-    # The second LLM needs to request a higher gpu_memory_utilization because
-    # the first LLM has already allocated a full 30% of the gpu memory.
+    # Destroy the LLM object and free up the GPU memory.
+    del llm
+    cleanup_dist_env_and_memory()
+    # Create an LLM with guided decoding enabled.
     llm = LLM(model="facebook/opt-125m",
               enforce_eager=True,
               guided_decoding_backend="lm-format-enforcer",

View File

@@ -6,8 +6,7 @@ import weakref
 import pytest
 from vllm import LLM
+from vllm.distributed import cleanup_dist_env_and_memory
-from ...conftest import cleanup
 MODEL_NAME = "facebook/opt-125m"
@@ -27,7 +26,7 @@ def llm():
     del llm
-    cleanup()
+    cleanup_dist_env_and_memory()
 @pytest.mark.skip_global_cleanup

View File

@@ -1,20 +1,16 @@
-import contextlib
-import gc
 import tempfile
 from collections import OrderedDict
 from typing import Dict, List, TypedDict
 from unittest.mock import MagicMock, patch
 import pytest
-import ray
 import torch
 import torch.nn as nn
 from huggingface_hub import snapshot_download
 import vllm
 from vllm.config import LoRAConfig
-from vllm.distributed import (destroy_distributed_environment,
-                              destroy_model_parallel,
+from vllm.distributed import (cleanup_dist_env_and_memory,
                               init_distributed_environment,
                               initialize_model_parallel)
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
@@ -48,16 +44,6 @@ LONG_LORA_INFOS: List[ContextIDInfo] = [{
 }]
-def cleanup():
-    destroy_model_parallel()
-    destroy_distributed_environment()
-    with contextlib.suppress(AssertionError):
-        torch.distributed.destroy_process_group()
-    gc.collect()
-    torch.cuda.empty_cache()
-    ray.shutdown()
 @pytest.fixture()
 def should_do_global_cleanup_after_test(request) -> bool:
     """Allow subdirectories to skip global cleanup by overriding this fixture.
@@ -72,7 +58,7 @@ def should_do_global_cleanup_after_test(request) -> bool:
 def cleanup_fixture(should_do_global_cleanup_after_test: bool):
     yield
     if should_do_global_cleanup_after_test:
-        cleanup()
+        cleanup_dist_env_and_memory(shutdown_ray=True)
 @pytest.fixture
@@ -87,7 +73,7 @@ def dist_init():
     )
     initialize_model_parallel(1, 1)
     yield
-    cleanup()
+    cleanup_dist_env_and_memory(shutdown_ray=True)
 @pytest.fixture
@@ -238,7 +224,7 @@ def long_context_lora_files_32k():
 def long_context_infos(long_context_lora_files_16k_1,
                        long_context_lora_files_16k_2,
                        long_context_lora_files_32k):
-    cleanup()
+    cleanup_dist_env_and_memory(shutdown_ray=True)
     infos: Dict[int, ContextInfo] = {}
     for lora_checkpoint_info in LONG_LORA_INFOS:
         lora_id = lora_checkpoint_info["lora_id"]
@@ -259,7 +245,7 @@ def long_context_infos(long_context_lora_files_16k_1,
 @pytest.fixture
 def llama_2_7b_engine_extra_embeddings():
-    cleanup()
+    cleanup_dist_env_and_memory(shutdown_ray=True)
     get_model_old = get_model
     def get_model_patched(*, model_config, device_config, **kwargs):
@@ -272,7 +258,7 @@ def llama_2_7b_engine_extra_embeddings():
         engine = vllm.LLM("meta-llama/Llama-2-7b-hf", enable_lora=False)
     yield engine.llm_engine
     del engine
-    cleanup()
+    cleanup_dist_env_and_memory(shutdown_ray=True)
 @pytest.fixture

View File

@@ -3,10 +3,9 @@ from typing import List
 import pytest
 import vllm
+from vllm.distributed import cleanup_dist_env_and_memory
 from vllm.lora.request import LoRARequest
-from .conftest import cleanup
 MODEL_PATH = "baichuan-inc/Baichuan-7B"
 PROMPT_TEMPLATE = """I want you to act as a SQL terminal in front of an example database, you need only to return the sql command to me.Below is an instruction that describes a task, Write a response that appropriately completes the request.\n"\n##Instruction:\nconcert_singer contains tables such as stadium, singer, concert, singer_in_concert. Table stadium has columns such as Stadium_ID, Location, Name, Capacity, Highest, Lowest, Average. Stadium_ID is the primary key.\nTable singer has columns such as Singer_ID, Name, Country, Song_Name, Song_release_year, Age, Is_male. Singer_ID is the primary key.\nTable concert has columns such as concert_ID, concert_Name, Theme, Stadium_ID, Year. concert_ID is the primary key.\nTable singer_in_concert has columns such as concert_ID, Singer_ID. concert_ID is the primary key.\nThe Stadium_ID of concert is the foreign key of Stadium_ID of stadium.\nThe Singer_ID of singer_in_concert is the foreign key of Singer_ID of singer.\nThe concert_ID of singer_in_concert is the foreign key of concert_ID of concert.\n\n###Input:\n{query}\n\n###Response:""" # noqa: E501
@@ -80,7 +79,7 @@ def test_baichuan_tensor_parallel_equality(baichuan_lora_files,
     output_tp1 = do_sample(llm_tp1, baichuan_lora_files, lora_id=1)
     del llm_tp1
-    cleanup()
+    cleanup_dist_env_and_memory()
     llm_tp2 = vllm.LLM(MODEL_PATH,
                        enable_lora=True,
@@ -93,7 +92,7 @@ def test_baichuan_tensor_parallel_equality(baichuan_lora_files,
     output_tp2 = do_sample(llm_tp2, baichuan_lora_files, lora_id=2)
     del llm_tp2
-    cleanup()
+    cleanup_dist_env_and_memory()
     assert output_tp1 == output_tp2
@@ -108,6 +107,6 @@ def test_baichuan_tensor_parallel_equality(baichuan_lora_files,
     output_tp4 = do_sample(llm_tp4, baichuan_lora_files, lora_id=2)
     del llm_tp4
-    cleanup()
+    cleanup_dist_env_and_memory()
     assert output_tp1 == output_tp4

View File

@@ -4,10 +4,9 @@ import pytest
 import ray
 import vllm
+from vllm.distributed import cleanup_dist_env_and_memory
 from vllm.lora.request import LoRARequest
-from .conftest import cleanup
 MODEL_PATH = "meta-llama/Llama-2-7b-hf"
@@ -93,7 +92,7 @@ def test_llama_tensor_parallel_equality(sql_lora_files, num_gpus_available):
     output_tp1 = do_sample(llm_tp1, sql_lora_files, lora_id=1)
     del llm_tp1
-    cleanup()
+    cleanup_dist_env_and_memory()
     llm_tp2 = vllm.LLM(MODEL_PATH,
                        enable_lora=True,
@@ -103,7 +102,7 @@ def test_llama_tensor_parallel_equality(sql_lora_files, num_gpus_available):
     output_tp2 = do_sample(llm_tp2, sql_lora_files, lora_id=1)
     del llm_tp2
-    cleanup()
+    cleanup_dist_env_and_memory()
     assert output_tp1 == output_tp2
@@ -115,7 +114,7 @@ def test_llama_tensor_parallel_equality(sql_lora_files, num_gpus_available):
     output_tp4 = do_sample(llm_tp4, sql_lora_files, lora_id=1)
     del llm_tp4
-    cleanup()
+    cleanup_dist_env_and_memory()
     assert output_tp1 == output_tp4

View File

@@ -6,11 +6,10 @@ from typing import List
 import pytest
 import vllm
+from vllm.distributed import cleanup_dist_env_and_memory
 from vllm.lora.request import LoRARequest
 from vllm.utils import is_hip
-from .conftest import cleanup
 @dataclass
 class ModelWithQuantization:
@@ -160,7 +159,7 @@ def test_quant_model_lora(tinyllama_lora_files, num_gpus_available, model,
     print("removing lora")
     del llm
-    cleanup()
+    cleanup_dist_env_and_memory()
 @pytest.mark.parametrize("model", MODELS)
@@ -181,7 +180,7 @@ def test_quant_model_tp_equality(tinyllama_lora_files, num_gpus_available,
     output_tp1 = do_sample(llm_tp1, tinyllama_lora_files, lora_id=1)
     del llm_tp1
-    cleanup()
+    cleanup_dist_env_and_memory()
     llm_tp2 = vllm.LLM(
         model=model.model_path,
@@ -194,6 +193,6 @@ def test_quant_model_tp_equality(tinyllama_lora_files, num_gpus_available,
     output_tp2 = do_sample(llm_tp2, tinyllama_lora_files, lora_id=1)
     del llm_tp2
-    cleanup()
+    cleanup_dist_env_and_memory()
     assert output_tp1 == output_tp2

View File

@@ -6,13 +6,12 @@ import ray
 from prometheus_client import REGISTRY
 from vllm import EngineArgs, LLMEngine
+from vllm.distributed import cleanup_dist_env_and_memory
 from vllm.engine.arg_utils import AsyncEngineArgs
 from vllm.engine.async_llm_engine import AsyncLLMEngine
 from vllm.engine.metrics import RayPrometheusStatLogger
 from vllm.sampling_params import SamplingParams
-from ..conftest import cleanup
 MODELS = [
     "facebook/opt-125m",
 ]
@@ -307,7 +306,7 @@ def test_metric_spec_decode_interval(
     finally:
         del engine
-        cleanup()
+        cleanup_dist_env_and_memory()
 def assert_metrics(engine: LLMEngine, disable_log_stats: bool,

View File

@@ -6,7 +6,7 @@ import torch.nn as nn
 from huggingface_hub import snapshot_download
 from transformers import AutoConfig, AutoModel, CLIPImageProcessor
-from ....conftest import _ImageAssets, cleanup
+from ....conftest import _ImageAssets
 # we use snapshot_download to prevent conflicts between
 # dynamic_module and trust_remote_code for hf_runner
@@ -45,12 +45,13 @@ def run_intern_vit_test(
         for pixel_value in pixel_values
     ]
+    from vllm.distributed import cleanup_dist_env_and_memory
     from vllm.model_executor.models.intern_vit import InternVisionModel
     vllm_model = InternVisionModel(config)
     vllm_model.load_weights(hf_model.state_dict().items())
     del hf_model
-    cleanup()
+    cleanup_dist_env_and_memory()
     vllm_model = vllm_model.to("cuda", dtype)
     vllm_outputs_per_image = [
@@ -58,7 +59,7 @@ def run_intern_vit_test(
         for pixel_value in pixel_values
     ]
     del vllm_model
-    cleanup()
+    cleanup_dist_env_and_memory()
     cos_similar = nn.CosineSimilarity(dim=-1)
     for vllm_output, hf_output in zip(vllm_outputs_per_image,

View File

@@ -4,8 +4,8 @@ Run `pytest tests/prefix_caching/test_prefix_caching.py`.
 """
 import pytest
-from tests.conftest import cleanup
 from vllm import LLM
+from vllm.distributed import cleanup_dist_env_and_memory
 MODEL_LEN_LEN = [
     # Example models with sliding window.
@@ -31,7 +31,7 @@ def test_disable_sliding_window(model_len_len, ):
             model_config.max_model_len)
     del vllm_disabled_model
-    cleanup()
+    cleanup_dist_env_and_memory()
     vllm_enabled_model = LLM(model, disable_sliding_window=False)
     vllm_enabled_model.generate("Hi my name is")
@@ -41,4 +41,4 @@ def test_disable_sliding_window(model_len_len, ):
            model_config.max_model_len)
     del vllm_enabled_model
-    cleanup()
+    cleanup_dist_env_and_memory()

View File

@@ -4,10 +4,10 @@ from typing import List, Optional, Sequence, Tuple, Union
 import pytest
 from vllm import LLM, SamplingParams
+from vllm.distributed import cleanup_dist_env_and_memory
 from vllm.model_executor.utils import set_random_seed
 from vllm.sequence import PromptLogprobs, SampleLogprobs
-from ...conftest import cleanup
 from ...models.utils import (TokensTextLogprobs,
                              TokensTextLogprobsPromptLogprobs,
                              check_logprobs_close, check_outputs_equal)
@@ -44,7 +44,7 @@ def test_llm_generator(common_llm_kwargs, per_test_common_llm_kwargs,
         yield llm
         del llm
-        cleanup()
+        cleanup_dist_env_and_memory()
     return generate

View File

@@ -1,27 +1,18 @@
-import contextlib
 import functools
 import gc
 from typing import Callable, TypeVar
 import pytest
-import ray
 import torch
 from typing_extensions import ParamSpec
-from vllm.distributed import (destroy_distributed_environment,
-                              destroy_model_parallel)
+from vllm.distributed import cleanup_dist_env_and_memory
 from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
 @pytest.fixture(autouse=True)
 def cleanup():
-    destroy_model_parallel()
-    destroy_distributed_environment()
-    with contextlib.suppress(AssertionError):
-        torch.distributed.destroy_process_group()
-    ray.shutdown()
-    gc.collect()
-    torch.cuda.empty_cache()
+    cleanup_dist_env_and_memory(shutdown_ray=True)
 _P = ParamSpec("_P")

View File

@@ -20,6 +20,7 @@ If you only need to use the distributed environment without model/pipeline
 steps.
 """
 import contextlib
+import gc
 import pickle
 import weakref
 from collections import namedtuple
@@ -36,7 +37,7 @@ from torch.distributed import Backend, ProcessGroup
 import vllm.envs as envs
 from vllm.logger import init_logger
 from vllm.platforms import current_platform
-from vllm.utils import supports_custom_op
+from vllm.utils import is_cpu, supports_custom_op
 @dataclass
@@ -1129,6 +1130,19 @@ def destroy_distributed_environment():
         torch.distributed.destroy_process_group()
+def cleanup_dist_env_and_memory(shutdown_ray: bool = False):
+    destroy_model_parallel()
+    destroy_distributed_environment()
+    with contextlib.suppress(AssertionError):
+        torch.distributed.destroy_process_group()
+    if shutdown_ray:
+        import ray  # Lazy import Ray
+        ray.shutdown()
+    gc.collect()
+    if not is_cpu():
+        torch.cuda.empty_cache()
 def in_the_same_node_as(pg: ProcessGroup, source_rank: int = 0) -> List[bool]:
     """
     This is a collective operation that returns if each rank is in the same node
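
Taken together, the new cleanup_dist_env_and_memory() helper above subsumes both of the previous per-suite cleanup() variants: Ray teardown is opt-in via shutdown_ray (with Ray imported lazily so non-Ray callers never pull it in), and the CUDA cache is emptied only when not running on CPU.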