[Bugfix] Offline mode fix (#8376)
Signed-off-by: Joe Runde <Joseph.Runde@ibm.com>
parent 1f0c75afa9
commit f2e263b801
@@ -91,6 +91,7 @@ steps:
   - pytest -v -s entrypoints/llm/test_lazy_outlines.py # it needs a clean process
   - pytest -v -s entrypoints/openai
   - pytest -v -s entrypoints/test_chat_utils.py
+  - pytest -v -s entrypoints/offline_mode # Needs to avoid interference with other tests
 
 
 - label: Distributed Tests (4 GPUs) # 10min
 tests/entrypoints/offline_mode/__init__.py          |  0  (new file)
 tests/entrypoints/offline_mode/test_offline_mode.py | 77  (new file)
tests/entrypoints/offline_mode/test_offline_mode.py (new file)
@@ -0,0 +1,77 @@
+"""Tests for HF_HUB_OFFLINE mode"""
+import importlib
+import sys
+import weakref
+
+import pytest
+
+from vllm import LLM
+
+from ...conftest import cleanup
+
+MODEL_NAME = "facebook/opt-125m"
+
+
+@pytest.fixture(scope="module")
+def llm():
+    # pytest caches the fixture so we use weakref.proxy to
+    # enable garbage collection
+    llm = LLM(model=MODEL_NAME,
+              max_num_batched_tokens=4096,
+              tensor_parallel_size=1,
+              gpu_memory_utilization=0.10,
+              enforce_eager=True)
+
+    with llm.deprecate_legacy_api():
+        yield weakref.proxy(llm)
+
+        del llm
+
+    cleanup()
+
+
+@pytest.mark.skip_global_cleanup
+def test_offline_mode(llm: LLM, monkeypatch):
+    # we use the llm fixture to ensure the model files are in-cache
+    del llm
+
+    # Set HF to offline mode and ensure we can still construct an LLM
+    try:
+        monkeypatch.setenv("HF_HUB_OFFLINE", "1")
+        # Need to re-import huggingface_hub and friends to setup offline mode
+        _re_import_modules()
+        # Cached model files should be used in offline mode
+        LLM(model=MODEL_NAME,
+            max_num_batched_tokens=4096,
+            tensor_parallel_size=1,
+            gpu_memory_utilization=0.10,
+            enforce_eager=True)
+    finally:
+        # Reset the environment after the test
+        # NB: Assuming tests are run in online mode
+        monkeypatch.delenv("HF_HUB_OFFLINE")
+        _re_import_modules()
+        pass
+
+
+def _re_import_modules():
+    hf_hub_module_names = [
+        k for k in sys.modules if k.startswith("huggingface_hub")
+    ]
+    transformers_module_names = [
+        k for k in sys.modules if k.startswith("transformers")
+        and not k.startswith("transformers_modules")
+    ]
+
+    reload_exception = None
+    for module_name in hf_hub_module_names + transformers_module_names:
+        try:
+            importlib.reload(sys.modules[module_name])
+        except Exception as e:
+            reload_exception = e
+            # Try to continue clean up so that other tests are less likely to
+            # be affected
+
+    # Error this test if reloading a module failed
+    if reload_exception is not None:
+        raise reload_exception
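The _re_import_modules() dance above is needed because huggingface_hub reads HF_HUB_OFFLINE once at import time, so changing the environment variable in an already-running process has no effect until the relevant modules are reloaded. A minimal sketch of that behaviour, assuming the parsed flag lives in huggingface_hub.constants (illustration only, not part of this diff):

    import importlib
    import os
    import sys

    import huggingface_hub.constants

    # Value parsed when the process started (normally False in an online run).
    print(huggingface_hub.constants.HF_HUB_OFFLINE)

    os.environ["HF_HUB_OFFLINE"] = "1"
    # Still the old value: the module keeps whatever it parsed at import time.
    print(huggingface_hub.constants.HF_HUB_OFFLINE)

    importlib.reload(sys.modules["huggingface_hub.constants"])
    # Reloading re-executes the module, so the flag is now True.
    print(huggingface_hub.constants.HF_HUB_OFFLINE)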
@@ -4,7 +4,9 @@ import json
 from pathlib import Path
 from typing import Any, Dict, Optional, Type, Union
 
-from huggingface_hub import file_exists, hf_hub_download
+import huggingface_hub
+from huggingface_hub import (file_exists, hf_hub_download,
+                             try_to_load_from_cache)
 from transformers import GenerationConfig, PretrainedConfig
 from transformers.models.auto.image_processing_auto import (
     get_image_processor_config)
@@ -70,7 +72,22 @@ def file_or_path_exists(model: Union[str, Path], config_name, revision,
     if Path(model).exists():
         return (Path(model) / config_name).is_file()
 
+    # Offline mode support: Check if config file is cached already
+    cached_filepath = try_to_load_from_cache(repo_id=model,
+                                             filename=config_name,
+                                             revision=revision)
+    if isinstance(cached_filepath, str):
+        # The config file exists in cache - we can continue trying to load
+        return True
+
+    # NB: file_exists will only check for the existence of the config file on
+    # hf_hub. This will fail in offline mode.
+    try:
-    return file_exists(model, config_name, revision=revision, token=token)
+        return file_exists(model, config_name, revision=revision, token=token)
+    except huggingface_hub.errors.OfflineModeIsEnabled:
+        # Don't raise in offline mode, all we know is that we don't have this
+        # file cached.
+        return False
 
 
 def get_config(
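The cache check added here relies on huggingface_hub.try_to_load_from_cache, which answers from the local cache only: it returns a filesystem path (str) when the file is cached, the _CACHED_NO_EXIST sentinel when the cache records that the file is absent upstream, and None when nothing is known locally. A rough illustration of that contract (the repo name is just an example):

    from huggingface_hub import _CACHED_NO_EXIST, try_to_load_from_cache

    result = try_to_load_from_cache(repo_id="facebook/opt-125m",
                                    filename="config.json")
    if isinstance(result, str):
        print(f"cached at {result}")                  # usable offline
    elif result is _CACHED_NO_EXIST:
        print("cache says the file does not exist")   # no Hub call needed
    else:
        print("unknown locally, a Hub lookup would be required")  # None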
@@ -102,6 +119,15 @@ def get_config(
                                  token=kwargs.get("token")):
             config_format = ConfigFormat.MISTRAL
         else:
+            # If we're in offline mode and found no valid config format, then
+            # raise an offline mode error to indicate to the user that they
+            # don't have files cached and may need to go online.
+            # This is conveniently triggered by calling file_exists().
+            file_exists(model,
+                        HF_CONFIG_NAME,
+                        revision=revision,
+                        token=kwargs.get("token"))
+
             raise ValueError(f"No supported config format found in {model}")
 
     if config_format == ConfigFormat.HF:
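With both pieces in place, a model whose files are already in the local Hugging Face cache can be loaded with the Hub forced offline. A hedged usage sketch (the model name and prompt are only examples; the variable has to be set before huggingface_hub is first imported):

    import os

    # Force offline mode before vllm (and huggingface_hub) are imported.
    os.environ["HF_HUB_OFFLINE"] = "1"

    from vllm import LLM

    # Assumes facebook/opt-125m was downloaded in an earlier, online run.
    llm = LLM(model="facebook/opt-125m", enforce_eager=True)
    outputs = llm.generate("Hello, my name is")
    print(outputs[0].outputs[0].text)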