[Frontend] Remove custom_cache_manager (#13791)
Signed-off-by: fulvius31 <asangior@redhat.com>
This commit is contained in:
parent
a4d83661d7
commit
374ee287d8
@ -16,12 +16,8 @@ import torch
|
|||||||
|
|
||||||
from vllm.config import VllmConfig
|
from vllm.config import VllmConfig
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.triton_utils.importing import HAS_TRITON
|
|
||||||
from vllm.utils import _check_multiproc_method, get_mp_context, run_method
|
from vllm.utils import _check_multiproc_method, get_mp_context, run_method
|
||||||
|
|
||||||
if HAS_TRITON:
|
|
||||||
from vllm.triton_utils import maybe_set_triton_cache_manager
|
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
T = TypeVar('T')
|
T = TypeVar('T')
|
||||||
@ -314,7 +310,3 @@ def set_multiprocessing_worker_envs(parallel_config):
|
|||||||
current_parallelism, default_omp_num_threads)
|
current_parallelism, default_omp_num_threads)
|
||||||
os.environ["OMP_NUM_THREADS"] = str(default_omp_num_threads)
|
os.environ["OMP_NUM_THREADS"] = str(default_omp_num_threads)
|
||||||
torch.set_num_threads(default_omp_num_threads)
|
torch.set_num_threads(default_omp_num_threads)
|
||||||
|
|
||||||
# workaround for https://github.com/vllm-project/vllm/issues/6103
|
|
||||||
if HAS_TRITON and parallel_config.world_size > 1:
|
|
||||||
maybe_set_triton_cache_manager()
|
|
||||||
|
@ -3,10 +3,3 @@
|
|||||||
from vllm.triton_utils.importing import HAS_TRITON
|
from vllm.triton_utils.importing import HAS_TRITON
|
||||||
|
|
||||||
__all__ = ["HAS_TRITON"]
|
__all__ = ["HAS_TRITON"]
|
||||||
|
|
||||||
if HAS_TRITON:
|
|
||||||
|
|
||||||
from vllm.triton_utils.custom_cache_manager import (
|
|
||||||
maybe_set_triton_cache_manager)
|
|
||||||
|
|
||||||
__all__ += ["maybe_set_triton_cache_manager"]
|
|
||||||
|
@ -1,55 +0,0 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
|
||||||
|
|
||||||
import os
|
|
||||||
|
|
||||||
from triton.runtime.cache import (FileCacheManager, default_cache_dir,
|
|
||||||
default_dump_dir, default_override_dir)
|
|
||||||
|
|
||||||
from vllm.logger import init_logger
|
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
def maybe_set_triton_cache_manager() -> None:
    """Select vLLM's custom Triton cache manager unless one is already set.

    If ``TRITON_CACHE_MANAGER`` is already present in the environment it is
    left untouched. Otherwise it is set to the import path of
    ``CustomCacheManager`` and the choice is logged.
    """
    # Respect a cache manager that has already been configured.
    if "TRITON_CACHE_MANAGER" in os.environ:
        return

    manager = "vllm.triton_utils.custom_cache_manager:CustomCacheManager"
    logger.info("Setting Triton cache manager to: %s", manager)
    os.environ["TRITON_CACHE_MANAGER"] = manager
|
|
||||||
|
|
||||||
|
|
||||||
class CustomCacheManager(FileCacheManager):
    """Triton file cache manager that uses a per-process cache directory.

    Re-implements the constructor of Triton's cache manager so that the
    default cache directory gets the current process id appended. This is
    needed to avoid collisions when running with tp>1 and using
    multi-processing as the distributed backend.

    Note this issue was fixed by triton-lang/triton/pull/4295, but the fix
    is not yet included in triton==v3.0.0. However, it should be included
    in the subsequent version.
    """

    def __init__(self, key, override=False, dump=False):
        """Resolve the cache directory (and lock path) for ``key``.

        Args:
            key: Name of the sub-directory appended to the chosen base
                directory.
            override: When true (and ``dump`` is false), use the override
                directory; no lock path is set and the directory is not
                created here.
            dump: When true, use the dump directory; takes precedence over
                ``override``.

        Raises:
            RuntimeError: If neither ``TRITON_CACHE_DIR`` nor the default
                cache directory yields a non-empty path.
        """
        self.key = key
        self.lock_path = None

        if dump:
            base_dir = default_dump_dir()
            lock_and_create = True
        elif override:
            base_dir = default_override_dir()
            lock_and_create = False
        else:
            # Prefer a non-blank TRITON_CACHE_DIR, else Triton's default.
            base_dir = (os.getenv("TRITON_CACHE_DIR", "").strip()
                        or default_cache_dir())
            if not base_dir:
                self.cache_dir = base_dir
                raise RuntimeError("Could not create or locate cache dir")
            # Only the default cache location is made unique per process.
            base_dir = f"{base_dir}_{os.getpid()}"
            lock_and_create = True

        self.cache_dir = os.path.join(base_dir, self.key)
        if lock_and_create:
            self.lock_path = os.path.join(self.cache_dir, "lock")
            os.makedirs(self.cache_dir, exist_ok=True)
|
|
Loading…
x
Reference in New Issue
Block a user