[MISC] Consolidate cleanup() and refactor offline_inference_with_prefix.py (#9510)
commit d11bf435a0 (parent 9bb10a7d27)
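The change is mechanical: the per-file cleanup() helpers in the tests (which tore down model parallel groups, the distributed environment, and CUDA caches by hand) are replaced by a single shared helper, vllm.distributed.cleanup_dist_env_and_memory(), defined at the bottom of this diff. A minimal sketch of the resulting teardown pattern; the script below is illustrative, not part of the commit:

# Minimal sketch of the consolidated teardown pattern (illustrative only).
from vllm import LLM
from vllm.distributed import cleanup_dist_env_and_memory

llm = LLM(model="facebook/opt-125m", gpu_memory_utilization=0.4)
outputs = llm.generate(["Hello, my name is"])

# Instead of a hand-rolled per-file cleanup() helper, callers now use the
# shared utility to free distributed state and GPU memory.
del llm
cleanup_dist_env_and_memory()  # pass shutdown_ray=True where Ray is involved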
@@ -1,4 +1,5 @@
 from vllm import LLM, SamplingParams
+from vllm.distributed import cleanup_dist_env_and_memory

 # NOTE: This is just a running example. For benchmarking purpose,
 # please see benchmarks/benchmark_prefix_caching.py
@@ -28,14 +29,9 @@ generating_prompts = [prefix + prompt for prompt in prompts]
 # Create a sampling params object.
 sampling_params = SamplingParams(temperature=0.0)

-# Create an LLM.
-regular_llm = LLM(model="facebook/opt-125m", gpu_memory_utilization=0.3)
+# Create an LLM without prefix caching as a baseline.
+regular_llm = LLM(model="facebook/opt-125m", gpu_memory_utilization=0.4)

-# The second LLM needs to request a higher gpu_memory_utilization because
-# the first LLM has already allocated a full 30% of the gpu memory.
-prefix_cached_llm = LLM(model="facebook/opt-125m",
-                        enable_prefix_caching=True,
-                        gpu_memory_utilization=0.6)
 print("Results without `enable_prefix_caching`")

 # Generate texts from the prompts. The output is a list of RequestOutput objects
@@ -52,6 +48,15 @@ for output in outputs:

 print("-" * 80)

+# Destroy the LLM object and free up the GPU memory.
+del regular_llm
+cleanup_dist_env_and_memory()
+
+# Create an LLM with prefix caching enabled.
+prefix_cached_llm = LLM(model="facebook/opt-125m",
+                        enable_prefix_caching=True,
+                        gpu_memory_utilization=0.4)
+
 # Warmup so that the shared prompt's KV cache is computed.
 prefix_cached_llm.generate(generating_prompts[0], sampling_params)
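In the refactored example above, the two LLM instances no longer coexist: the baseline model is deleted and its memory reclaimed before the prefix-caching model is created, which is why the old 0.3/0.6 split of gpu_memory_utilization becomes 0.4 for both. A condensed sketch of that flow, with values taken from the hunks above (the prompt list is illustrative):

# Condensed flow of the refactored example (prompts are illustrative).
from vllm import LLM, SamplingParams
from vllm.distributed import cleanup_dist_env_and_memory

sampling_params = SamplingParams(temperature=0.0)
prompts = ["Hello, my name is"]

# Baseline LLM without prefix caching.
regular_llm = LLM(model="facebook/opt-125m", gpu_memory_utilization=0.4)
baseline_outputs = regular_llm.generate(prompts, sampling_params)

# Free the baseline before creating the second model, so both can use the
# same 40% GPU memory budget instead of splitting 30%/60%.
del regular_llm
cleanup_dist_env_and_memory()

prefix_cached_llm = LLM(model="facebook/opt-125m",
                        enable_prefix_caching=True,
                        gpu_memory_utilization=0.4)
cached_outputs = prefix_cached_llm.generate(prompts, sampling_params)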
@@ -12,11 +12,11 @@ import torch

 from vllm import SamplingParams
 from vllm.config import ParallelConfig
+from vllm.distributed import cleanup_dist_env_and_memory
 from vllm.engine.async_llm_engine import AsyncEngineArgs, AsyncLLMEngine
 from vllm.outputs import RequestOutput as RealRequestOutput
 from vllm.sampling_params import RequestOutputKind

-from ..conftest import cleanup
 from ..utils import wait_for_gpu_memory_to_clear

@@ -157,7 +157,7 @@ async def async_engine():
     engine.shutdown_background_loop()
     del engine
     await asyncio.sleep(0.1)
-    cleanup()
+    cleanup_dist_env_and_memory()


 @pytest.fixture()
@@ -1,5 +1,3 @@
-import contextlib
-import gc
 import json
 import os
 import sys
@@ -27,8 +25,7 @@ from vllm.assets.image import ImageAsset
 from vllm.assets.video import VideoAsset
 from vllm.config import TaskOption, TokenizerPoolConfig
 from vllm.connections import global_http_connection
-from vllm.distributed import (destroy_distributed_environment,
-                              destroy_model_parallel,
+from vllm.distributed import (cleanup_dist_env_and_memory,
                               init_distributed_environment,
                               initialize_model_parallel)
 from vllm.inputs import (ExplicitEncoderDecoderPrompt, TextPrompt,
@@ -140,17 +137,7 @@ def dist_init():
     )
     initialize_model_parallel(1, 1)
     yield
-    cleanup()
-
-
-def cleanup():
-    destroy_model_parallel()
-    destroy_distributed_environment()
-    with contextlib.suppress(AssertionError):
-        torch.distributed.destroy_process_group()
-    gc.collect()
-    if not is_cpu():
-        torch.cuda.empty_cache()
+    cleanup_dist_env_and_memory()


 @pytest.fixture()
@@ -167,7 +154,7 @@ def should_do_global_cleanup_after_test(request) -> bool:
 def cleanup_fixture(should_do_global_cleanup_after_test: bool):
     yield
     if should_do_global_cleanup_after_test:
-        cleanup()
+        cleanup_dist_env_and_memory()


 @pytest.fixture(autouse=True)
@@ -606,7 +593,7 @@ class HfRunner:

     def __exit__(self, exc_type, exc_value, traceback):
         del self.model
-        cleanup()
+        cleanup_dist_env_and_memory()


 @pytest.fixture(scope="session")
@@ -861,7 +848,7 @@ class VllmRunner:

     def __exit__(self, exc_type, exc_value, traceback):
         del self.model
-        cleanup()
+        cleanup_dist_env_and_memory()


 @pytest.fixture(scope="session")
@@ -3,10 +3,9 @@ from typing import Callable, Iterable, Optional
 import pytest

 from vllm import LLM
+from vllm.distributed import cleanup_dist_env_and_memory
 from vllm.model_executor.utils import set_random_seed

-from ....conftest import cleanup
-

 @pytest.fixture
 def baseline_llm_generator(common_llm_kwargs, per_test_common_llm_kwargs,
@@ -37,7 +36,7 @@ def create_llm_generator(common_llm_kwargs, per_test_common_llm_kwargs,

         yield llm
         del llm
-        cleanup()
+        cleanup_dist_env_and_memory()

     for llm in generator_inner():
         yield llm
@@ -4,8 +4,7 @@ from typing import List
 import pytest

 from vllm import LLM, EmbeddingRequestOutput, PoolingParams
-
-from ...conftest import cleanup
+from vllm.distributed import cleanup_dist_env_and_memory

 MODEL_NAME = "intfloat/e5-mistral-7b-instruct"

@@ -41,7 +40,7 @@ def llm():

     del llm

-    cleanup()
+    cleanup_dist_env_and_memory()


 def assert_outputs_equal(o1: List[EmbeddingRequestOutput],
@@ -4,8 +4,7 @@ from typing import List
 import pytest

 from vllm import LLM, RequestOutput, SamplingParams
-
-from ...conftest import cleanup
+from vllm.distributed import cleanup_dist_env_and_memory

 MODEL_NAME = "facebook/opt-125m"

@@ -39,7 +38,7 @@ def llm():

     del llm

-    cleanup()
+    cleanup_dist_env_and_memory()


 def assert_outputs_equal(o1: List[RequestOutput], o2: List[RequestOutput]):
@@ -5,10 +5,9 @@ import pytest
 from huggingface_hub import snapshot_download

 from vllm import LLM
+from vllm.distributed import cleanup_dist_env_and_memory
 from vllm.lora.request import LoRARequest

-from ...conftest import cleanup
-
 MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta"

 PROMPTS = [
@@ -39,7 +38,7 @@ def llm():

     del llm

-    cleanup()
+    cleanup_dist_env_and_memory()


 @pytest.fixture(scope="module")
@@ -5,12 +5,11 @@ import weakref
 import jsonschema
 import pytest

+from vllm.distributed import cleanup_dist_env_and_memory
 from vllm.entrypoints.llm import LLM
 from vllm.outputs import RequestOutput
 from vllm.sampling_params import GuidedDecodingParams, SamplingParams

-from ...conftest import cleanup
-
 MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta"


@@ -23,7 +22,7 @@ def llm():
     with llm.deprecate_legacy_api():
         yield weakref.proxy(llm)
         del llm
-        cleanup()
+        cleanup_dist_env_and_memory()


 @pytest.mark.skip_global_cleanup
@@ -1,6 +1,7 @@
 import sys

 from vllm import LLM, SamplingParams
+from vllm.distributed import cleanup_dist_env_and_memory


 def test_lazy_outlines(sample_regex):
@@ -14,6 +15,7 @@ def test_lazy_outlines(sample_regex):
     ]
     sampling_params = SamplingParams(temperature=0.8, top_p=0.95)

+    # Create an LLM without guided decoding as a baseline.
     llm = LLM(model="facebook/opt-125m",
               enforce_eager=True,
               gpu_memory_utilization=0.3)
@@ -26,8 +28,11 @@ def test_lazy_outlines(sample_regex):
     # make sure outlines is not imported
     assert 'outlines' not in sys.modules

-    # The second LLM needs to request a higher gpu_memory_utilization because
-    # the first LLM has already allocated a full 30% of the gpu memory.
+    # Destroy the LLM object and free up the GPU memory.
+    del llm
+    cleanup_dist_env_and_memory()
+
+    # Create an LLM with guided decoding enabled.
     llm = LLM(model="facebook/opt-125m",
               enforce_eager=True,
               guided_decoding_backend="lm-format-enforcer",
@@ -6,8 +6,7 @@ import weakref
 import pytest

 from vllm import LLM
-
-from ...conftest import cleanup
+from vllm.distributed import cleanup_dist_env_and_memory

 MODEL_NAME = "facebook/opt-125m"

@@ -27,7 +26,7 @@ def llm():

     del llm

-    cleanup()
+    cleanup_dist_env_and_memory()


 @pytest.mark.skip_global_cleanup
@@ -1,20 +1,16 @@
-import contextlib
-import gc
 import tempfile
 from collections import OrderedDict
 from typing import Dict, List, TypedDict
 from unittest.mock import MagicMock, patch

 import pytest
-import ray
 import torch
 import torch.nn as nn
 from huggingface_hub import snapshot_download

 import vllm
 from vllm.config import LoRAConfig
-from vllm.distributed import (destroy_distributed_environment,
-                              destroy_model_parallel,
+from vllm.distributed import (cleanup_dist_env_and_memory,
                               init_distributed_environment,
                               initialize_model_parallel)
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
@@ -48,16 +44,6 @@ LONG_LORA_INFOS: List[ContextIDInfo] = [{
 }]


-def cleanup():
-    destroy_model_parallel()
-    destroy_distributed_environment()
-    with contextlib.suppress(AssertionError):
-        torch.distributed.destroy_process_group()
-    gc.collect()
-    torch.cuda.empty_cache()
-    ray.shutdown()
-
-
 @pytest.fixture()
 def should_do_global_cleanup_after_test(request) -> bool:
     """Allow subdirectories to skip global cleanup by overriding this fixture.
@@ -72,7 +58,7 @@ def should_do_global_cleanup_after_test(request) -> bool:
 def cleanup_fixture(should_do_global_cleanup_after_test: bool):
     yield
     if should_do_global_cleanup_after_test:
-        cleanup()
+        cleanup_dist_env_and_memory(shutdown_ray=True)


 @pytest.fixture
@@ -87,7 +73,7 @@ def dist_init():
     )
     initialize_model_parallel(1, 1)
     yield
-    cleanup()
+    cleanup_dist_env_and_memory(shutdown_ray=True)


 @pytest.fixture
@@ -238,7 +224,7 @@ def long_context_lora_files_32k():
 def long_context_infos(long_context_lora_files_16k_1,
                        long_context_lora_files_16k_2,
                        long_context_lora_files_32k):
-    cleanup()
+    cleanup_dist_env_and_memory(shutdown_ray=True)
     infos: Dict[int, ContextInfo] = {}
     for lora_checkpoint_info in LONG_LORA_INFOS:
         lora_id = lora_checkpoint_info["lora_id"]
@@ -259,7 +245,7 @@ def long_context_infos(long_context_lora_files_16k_1,

 @pytest.fixture
 def llama_2_7b_engine_extra_embeddings():
-    cleanup()
+    cleanup_dist_env_and_memory(shutdown_ray=True)
     get_model_old = get_model

     def get_model_patched(*, model_config, device_config, **kwargs):
@@ -272,7 +258,7 @@ def llama_2_7b_engine_extra_embeddings():
     engine = vllm.LLM("meta-llama/Llama-2-7b-hf", enable_lora=False)
     yield engine.llm_engine
     del engine
-    cleanup()
+    cleanup_dist_env_and_memory(shutdown_ray=True)


 @pytest.fixture
@@ -3,10 +3,9 @@ from typing import List
 import pytest

 import vllm
+from vllm.distributed import cleanup_dist_env_and_memory
 from vllm.lora.request import LoRARequest

-from .conftest import cleanup
-
 MODEL_PATH = "baichuan-inc/Baichuan-7B"

 PROMPT_TEMPLATE = """I want you to act as a SQL terminal in front of an example database, you need only to return the sql command to me.Below is an instruction that describes a task, Write a response that appropriately completes the request.\n"\n##Instruction:\nconcert_singer contains tables such as stadium, singer, concert, singer_in_concert. Table stadium has columns such as Stadium_ID, Location, Name, Capacity, Highest, Lowest, Average. Stadium_ID is the primary key.\nTable singer has columns such as Singer_ID, Name, Country, Song_Name, Song_release_year, Age, Is_male. Singer_ID is the primary key.\nTable concert has columns such as concert_ID, concert_Name, Theme, Stadium_ID, Year. concert_ID is the primary key.\nTable singer_in_concert has columns such as concert_ID, Singer_ID. concert_ID is the primary key.\nThe Stadium_ID of concert is the foreign key of Stadium_ID of stadium.\nThe Singer_ID of singer_in_concert is the foreign key of Singer_ID of singer.\nThe concert_ID of singer_in_concert is the foreign key of concert_ID of concert.\n\n###Input:\n{query}\n\n###Response:""" # noqa: E501
@@ -80,7 +79,7 @@ def test_baichuan_tensor_parallel_equality(baichuan_lora_files,
     output_tp1 = do_sample(llm_tp1, baichuan_lora_files, lora_id=1)

     del llm_tp1
-    cleanup()
+    cleanup_dist_env_and_memory()

     llm_tp2 = vllm.LLM(MODEL_PATH,
                        enable_lora=True,
@@ -93,7 +92,7 @@ def test_baichuan_tensor_parallel_equality(baichuan_lora_files,
     output_tp2 = do_sample(llm_tp2, baichuan_lora_files, lora_id=2)

     del llm_tp2
-    cleanup()
+    cleanup_dist_env_and_memory()

     assert output_tp1 == output_tp2

@@ -108,6 +107,6 @@ def test_baichuan_tensor_parallel_equality(baichuan_lora_files,
     output_tp4 = do_sample(llm_tp4, baichuan_lora_files, lora_id=2)

     del llm_tp4
-    cleanup()
+    cleanup_dist_env_and_memory()

     assert output_tp1 == output_tp4
@@ -4,10 +4,9 @@ import pytest
 import ray

 import vllm
+from vllm.distributed import cleanup_dist_env_and_memory
 from vllm.lora.request import LoRARequest

-from .conftest import cleanup
-
 MODEL_PATH = "meta-llama/Llama-2-7b-hf"


@@ -93,7 +92,7 @@ def test_llama_tensor_parallel_equality(sql_lora_files, num_gpus_available):
     output_tp1 = do_sample(llm_tp1, sql_lora_files, lora_id=1)

     del llm_tp1
-    cleanup()
+    cleanup_dist_env_and_memory()

     llm_tp2 = vllm.LLM(MODEL_PATH,
                        enable_lora=True,
@@ -103,7 +102,7 @@ def test_llama_tensor_parallel_equality(sql_lora_files, num_gpus_available):
     output_tp2 = do_sample(llm_tp2, sql_lora_files, lora_id=1)

     del llm_tp2
-    cleanup()
+    cleanup_dist_env_and_memory()

     assert output_tp1 == output_tp2

@@ -115,7 +114,7 @@ def test_llama_tensor_parallel_equality(sql_lora_files, num_gpus_available):
     output_tp4 = do_sample(llm_tp4, sql_lora_files, lora_id=1)

     del llm_tp4
-    cleanup()
+    cleanup_dist_env_and_memory()

     assert output_tp1 == output_tp4
@@ -6,11 +6,10 @@ from typing import List
 import pytest

 import vllm
+from vllm.distributed import cleanup_dist_env_and_memory
 from vllm.lora.request import LoRARequest
 from vllm.utils import is_hip

-from .conftest import cleanup
-

 @dataclass
 class ModelWithQuantization:
@@ -160,7 +159,7 @@ def test_quant_model_lora(tinyllama_lora_files, num_gpus_available, model,
     print("removing lora")

     del llm
-    cleanup()
+    cleanup_dist_env_and_memory()


 @pytest.mark.parametrize("model", MODELS)
@@ -181,7 +180,7 @@ def test_quant_model_tp_equality(tinyllama_lora_files, num_gpus_available,
     output_tp1 = do_sample(llm_tp1, tinyllama_lora_files, lora_id=1)

     del llm_tp1
-    cleanup()
+    cleanup_dist_env_and_memory()

     llm_tp2 = vllm.LLM(
         model=model.model_path,
@@ -194,6 +193,6 @@ def test_quant_model_tp_equality(tinyllama_lora_files, num_gpus_available,
     output_tp2 = do_sample(llm_tp2, tinyllama_lora_files, lora_id=1)

     del llm_tp2
-    cleanup()
+    cleanup_dist_env_and_memory()

     assert output_tp1 == output_tp2
@@ -6,13 +6,12 @@ import ray
 from prometheus_client import REGISTRY

 from vllm import EngineArgs, LLMEngine
+from vllm.distributed import cleanup_dist_env_and_memory
 from vllm.engine.arg_utils import AsyncEngineArgs
 from vllm.engine.async_llm_engine import AsyncLLMEngine
 from vllm.engine.metrics import RayPrometheusStatLogger
 from vllm.sampling_params import SamplingParams

-from ..conftest import cleanup
-
 MODELS = [
     "facebook/opt-125m",
 ]
@@ -307,7 +306,7 @@ def test_metric_spec_decode_interval(

     finally:
         del engine
-        cleanup()
+        cleanup_dist_env_and_memory()


 def assert_metrics(engine: LLMEngine, disable_log_stats: bool,
@@ -6,7 +6,7 @@ import torch.nn as nn
 from huggingface_hub import snapshot_download
 from transformers import AutoConfig, AutoModel, CLIPImageProcessor

-from ....conftest import _ImageAssets, cleanup
+from ....conftest import _ImageAssets

 # we use snapshot_download to prevent conflicts between
 # dynamic_module and trust_remote_code for hf_runner
@@ -45,12 +45,13 @@ def run_intern_vit_test(
         for pixel_value in pixel_values
     ]

+    from vllm.distributed import cleanup_dist_env_and_memory
     from vllm.model_executor.models.intern_vit import InternVisionModel
     vllm_model = InternVisionModel(config)
     vllm_model.load_weights(hf_model.state_dict().items())

     del hf_model
-    cleanup()
+    cleanup_dist_env_and_memory()

     vllm_model = vllm_model.to("cuda", dtype)
     vllm_outputs_per_image = [
@@ -58,7 +59,7 @@ def run_intern_vit_test(
         for pixel_value in pixel_values
     ]
     del vllm_model
-    cleanup()
+    cleanup_dist_env_and_memory()

     cos_similar = nn.CosineSimilarity(dim=-1)
     for vllm_output, hf_output in zip(vllm_outputs_per_image,
@@ -4,8 +4,8 @@ Run `pytest tests/prefix_caching/test_prefix_caching.py`.
 """
 import pytest

-from tests.conftest import cleanup
 from vllm import LLM
+from vllm.distributed import cleanup_dist_env_and_memory

 MODEL_LEN_LEN = [
     # Example models with sliding window.
@@ -31,7 +31,7 @@ def test_disable_sliding_window(model_len_len, ):
                  model_config.max_model_len)

     del vllm_disabled_model
-    cleanup()
+    cleanup_dist_env_and_memory()

     vllm_enabled_model = LLM(model, disable_sliding_window=False)
     vllm_enabled_model.generate("Hi my name is")
@@ -41,4 +41,4 @@
                  model_config.max_model_len)

     del vllm_enabled_model
-    cleanup()
+    cleanup_dist_env_and_memory()
@@ -4,10 +4,10 @@ from typing import List, Optional, Sequence, Tuple, Union
 import pytest

 from vllm import LLM, SamplingParams
+from vllm.distributed import cleanup_dist_env_and_memory
 from vllm.model_executor.utils import set_random_seed
 from vllm.sequence import PromptLogprobs, SampleLogprobs

-from ...conftest import cleanup
 from ...models.utils import (TokensTextLogprobs,
                              TokensTextLogprobsPromptLogprobs,
                              check_logprobs_close, check_outputs_equal)
@@ -44,7 +44,7 @@ def test_llm_generator(common_llm_kwargs, per_test_common_llm_kwargs,
         yield llm

         del llm
-        cleanup()
+        cleanup_dist_env_and_memory()

     return generate
@@ -1,27 +1,18 @@
-import contextlib
 import functools
 import gc
 from typing import Callable, TypeVar

 import pytest
-import ray
 import torch
 from typing_extensions import ParamSpec

-from vllm.distributed import (destroy_distributed_environment,
-                              destroy_model_parallel)
+from vllm.distributed import cleanup_dist_env_and_memory
 from vllm.model_executor.model_loader.tensorizer import TensorizerConfig


 @pytest.fixture(autouse=True)
 def cleanup():
-    destroy_model_parallel()
-    destroy_distributed_environment()
-    with contextlib.suppress(AssertionError):
-        torch.distributed.destroy_process_group()
-    ray.shutdown()
-    gc.collect()
-    torch.cuda.empty_cache()
+    cleanup_dist_env_and_memory(shutdown_ray=True)


 _P = ParamSpec("_P")
@@ -20,6 +20,7 @@ If you only need to use the distributed environment without model/pipeline
 steps.
 """
 import contextlib
+import gc
 import pickle
 import weakref
 from collections import namedtuple
@@ -36,7 +37,7 @@ from torch.distributed import Backend, ProcessGroup
 import vllm.envs as envs
 from vllm.logger import init_logger
 from vllm.platforms import current_platform
-from vllm.utils import supports_custom_op
+from vllm.utils import is_cpu, supports_custom_op


 @dataclass
@@ -1129,6 +1130,19 @@ def destroy_distributed_environment():
         torch.distributed.destroy_process_group()


+def cleanup_dist_env_and_memory(shutdown_ray: bool = False):
+    destroy_model_parallel()
+    destroy_distributed_environment()
+    with contextlib.suppress(AssertionError):
+        torch.distributed.destroy_process_group()
+    if shutdown_ray:
+        import ray  # Lazy import Ray
+        ray.shutdown()
+    gc.collect()
+    if not is_cpu():
+        torch.cuda.empty_cache()
+
+
 def in_the_same_node_as(pg: ProcessGroup, source_rank: int = 0) -> List[bool]:
     """
     This is a collective operation that returns if each rank is in the same node
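The hunk above is where the consolidation lands: cleanup_dist_env_and_memory() now owns the whole teardown sequence (model parallel groups, distributed environment, optional Ray shutdown, garbage collection, CUDA cache). A sketch of a pytest fixture built on it, modeled on the fixtures changed earlier in this diff; the fixture name is illustrative:

# Sketch of an autouse teardown fixture using the consolidated helper
# (modeled on the fixtures in this commit; the fixture name is illustrative).
import pytest

from vllm.distributed import cleanup_dist_env_and_memory


@pytest.fixture(autouse=True)
def cleanup_after_test():
    yield
    # Tear down model parallel state, the distributed environment, and CUDA
    # caches; shutdown_ray=True also stops Ray, as the LoRA and tensorizer
    # conftests do.
    cleanup_dist_env_and_memory(shutdown_ray=True)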