[Bugfix] Offline mode fix (#8376)

Signed-off-by: Joe Runde <Joseph.Runde@ibm.com>
This commit is contained in:
Joe Runde 2024-09-12 12:11:57 -06:00 committed by GitHub
parent 1f0c75afa9
commit f2e263b801
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 106 additions and 2 deletions

View File

@@ -91,6 +91,7 @@ steps:
- pytest -v -s entrypoints/llm/test_lazy_outlines.py # it needs a clean process
- pytest -v -s entrypoints/openai
- pytest -v -s entrypoints/test_chat_utils.py
- pytest -v -s entrypoints/offline_mode # Needs to avoid interference with other tests
- label: Distributed Tests (4 GPUs) # 10min

View File

@@ -0,0 +1,77 @@
"""Tests for HF_HUB_OFFLINE mode"""
import importlib
import sys
import weakref
import pytest
from vllm import LLM
from ...conftest import cleanup
MODEL_NAME = "facebook/opt-125m"
@pytest.fixture(scope="module")
def llm():
    """Module-scoped LLM used to warm the HF cache before the offline test.

    pytest keeps a module-scoped fixture alive for the whole module, so a
    weakref proxy is yielded instead of the object itself; dropping the
    local reference during teardown then lets the engine be garbage
    collected before cleanup() runs.
    """
    engine = LLM(
        model=MODEL_NAME,
        max_num_batched_tokens=4096,
        tensor_parallel_size=1,
        gpu_memory_utilization=0.10,
        enforce_eager=True,
    )
    with engine.deprecate_legacy_api():
        yield weakref.proxy(engine)
        del engine
    cleanup()
@pytest.mark.skip_global_cleanup
def test_offline_mode(llm: LLM, monkeypatch):
    """Verify an LLM can be constructed from cache with HF_HUB_OFFLINE=1.

    The ``llm`` fixture is requested only to guarantee the model files are
    already in the local HF cache; the instance itself is not needed, so it
    is dropped immediately to free GPU memory before the second engine is
    built.
    """
    del llm

    # Set HF to offline mode and ensure we can still construct an LLM
    try:
        monkeypatch.setenv("HF_HUB_OFFLINE", "1")
        # huggingface_hub (and transformers) read HF_HUB_OFFLINE at import
        # time, so the modules must be reloaded for the flag to take effect.
        _re_import_modules()
        # Cached model files should be used in offline mode
        LLM(model=MODEL_NAME,
            max_num_batched_tokens=4096,
            tensor_parallel_size=1,
            gpu_memory_utilization=0.10,
            enforce_eager=True)
    finally:
        # Reset the environment after the test
        # NB: Assuming tests are run in online mode
        monkeypatch.delenv("HF_HUB_OFFLINE")
        _re_import_modules()
def _re_import_modules():
hf_hub_module_names = [
k for k in sys.modules if k.startswith("huggingface_hub")
]
transformers_module_names = [
k for k in sys.modules if k.startswith("transformers")
and not k.startswith("transformers_modules")
]
reload_exception = None
for module_name in hf_hub_module_names + transformers_module_names:
try:
importlib.reload(sys.modules[module_name])
except Exception as e:
reload_exception = e
# Try to continue clean up so that other tests are less likely to
# be affected
# Error this test if reloading a module failed
if reload_exception is not None:
raise reload_exception

View File

@@ -4,7 +4,9 @@ import json
from pathlib import Path
from typing import Any, Dict, Optional, Type, Union
from huggingface_hub import file_exists, hf_hub_download
import huggingface_hub
from huggingface_hub import (file_exists, hf_hub_download,
try_to_load_from_cache)
from transformers import GenerationConfig, PretrainedConfig
from transformers.models.auto.image_processing_auto import (
get_image_processor_config)
@@ -70,7 +72,22 @@ def file_or_path_exists(model: Union[str, Path], config_name, revision,
if Path(model).exists():
return (Path(model) / config_name).is_file()
return file_exists(model, config_name, revision=revision, token=token)
# Offline mode support: Check if config file is cached already
cached_filepath = try_to_load_from_cache(repo_id=model,
filename=config_name,
revision=revision)
if isinstance(cached_filepath, str):
# The config file exists in cache- we can continue trying to load
return True
# NB: file_exists will only check for the existence of the config file on
# hf_hub. This will fail in offline mode.
try:
return file_exists(model, config_name, revision=revision, token=token)
except huggingface_hub.errors.OfflineModeIsEnabled:
# Don't raise in offline mode, all we know is that we don't have this
# file cached.
return False
def get_config(
@@ -102,6 +119,15 @@ def get_config(
token=kwargs.get("token")):
config_format = ConfigFormat.MISTRAL
else:
# If we're in offline mode and found no valid config format, then
# raise an offline mode error to indicate to the user that they
# don't have files cached and may need to go online.
# This is conveniently triggered by calling file_exists().
file_exists(model,
HF_CONFIG_NAME,
revision=revision,
token=kwargs.get("token"))
raise ValueError(f"No supported config format found in {model}")
if config_format == ConfigFormat.HF: