vllm/tests/model_executor/weight_utils.py
Prashant Gupta d6e520e170
[Core] Support offline use of local cache for models (#4374)
Signed-off-by: Prashant Gupta <prashantgupta@us.ibm.com>
Co-authored-by: Travis Johnson <tjohnson31415@gmail.com>
2024-04-27 09:59:55 -07:00

55 lines
1.9 KiB
Python

import os
import tempfile
import huggingface_hub.constants
import pytest
from huggingface_hub.utils import LocalEntryNotFoundError
from vllm.model_executor.model_loader.weight_utils import (
download_weights_from_hf, enable_hf_transfer)
def test_hf_transfer_auto_activation():
    """Assert that ``enable_hf_transfer`` matches ``hf_transfer`` availability.

    After calling ``enable_hf_transfer()``, the
    ``huggingface_hub.constants.HF_HUB_ENABLE_HF_TRANSFER`` flag is expected
    to equal whether the ``hf_transfer`` package can be imported.

    The test is skipped when ``HF_HUB_ENABLE_HF_TRANSFER`` is already present
    in the environment, since an explicit setting makes the automatic
    behaviour unobservable.
    """
    if "HF_HUB_ENABLE_HF_TRANSFER" in os.environ:
        # An explicit user setting is present, so auto activation cannot be
        # observed in this process.
        pytest.skip(
            "HF_HUB_ENABLE_HF_TRANSFER is set, can't test auto activation")

    enable_hf_transfer()

    # Probe independently whether the optional accelerator package exists.
    try:
        import hf_transfer  # type: ignore # noqa
    except ImportError:
        hf_transfer_available = False
    else:
        hf_transfer_available = True

    flag = huggingface_hub.constants.HF_HUB_ENABLE_HF_TRANSFER
    assert flag == hf_transfer_available
def test_download_weights_from_hf():
    """Check that ``download_weights_from_hf`` honours the HF offline flag.

    Runs three phases against a fresh, empty temporary cache directory:

    1. offline with nothing cached: ``LocalEntryNotFoundError`` is raised;
    2. online: the model weights are downloaded using that cache directory;
    3. offline again: the call succeeds and returns a non-``None`` value.

    The test toggles ``huggingface_hub.constants.HF_HUB_OFFLINE``, which is a
    process-wide setting. Its original value is restored on exit (pass or
    fail) so that other tests running in the same process are not silently
    forced into offline mode.
    """

    def _fetch(cache_dir):
        # Build the pattern list afresh for every call so the three phases
        # never share a list object.
        return download_weights_from_hf(
            "facebook/opt-125m",
            allow_patterns=["*.safetensors", "*.bin"],
            cache_dir=cache_dir)

    original_offline = huggingface_hub.constants.HF_HUB_OFFLINE
    try:
        with tempfile.TemporaryDirectory() as tmpdir:
            # assert LocalEntryNotFoundError error is thrown
            # if offline is set and model is not cached
            huggingface_hub.constants.HF_HUB_OFFLINE = True
            with pytest.raises(LocalEntryNotFoundError):
                _fetch(tmpdir)

            # download the model
            huggingface_hub.constants.HF_HUB_OFFLINE = False
            _fetch(tmpdir)

            # now it should work offline
            huggingface_hub.constants.HF_HUB_OFFLINE = True
            assert _fetch(tmpdir) is not None
    finally:
        # Undo the global mutation regardless of the test outcome.
        huggingface_hub.constants.HF_HUB_OFFLINE = original_offline
if __name__ == "__main__":
    # Allow running this file directly, without the pytest runner; the tests
    # execute in the same order as they are defined above.
    for _test in (test_hf_transfer_auto_activation,
                  test_download_weights_from_hf):
        _test()