
Signed-off-by: Prashant Gupta <prashantgupta@us.ibm.com> Co-authored-by: Travis Johnson <tjohnson31415@gmail.com>
55 lines
1.9 KiB
Python
55 lines
1.9 KiB
Python
import os
|
|
import tempfile
|
|
|
|
import huggingface_hub.constants
|
|
import pytest
|
|
from huggingface_hub.utils import LocalEntryNotFoundError
|
|
|
|
from vllm.model_executor.model_loader.weight_utils import (
|
|
download_weights_from_hf, enable_hf_transfer)
|
|
|
|
|
|
def test_hf_transfer_auto_activation():
    """Check that enable_hf_transfer() matches hf_transfer availability.

    After calling ``enable_hf_transfer()``, the flag
    ``huggingface_hub.constants.HF_HUB_ENABLE_HF_TRANSFER`` is expected to
    equal whether the ``hf_transfer`` package can be imported here.

    The test is skipped when ``HF_HUB_ENABLE_HF_TRANSFER`` is already present
    in the environment, since auto activation cannot be observed then.
    """
    if "HF_HUB_ENABLE_HF_TRANSFER" in os.environ:
        # in case it is already set, we can't test the auto activation
        pytest.skip(
            "HF_HUB_ENABLE_HF_TRANSFER is set, can't test auto activation")

    enable_hf_transfer()

    # Probe for the optional package the same way: by attempting the import.
    try:
        import hf_transfer  # type: ignore # noqa
    except ImportError:
        transfer_available = False
    else:
        transfer_available = True

    actual_flag = huggingface_hub.constants.HF_HUB_ENABLE_HF_TRANSFER
    assert actual_flag == transfer_available
|
|
|
|
|
|
def test_download_weights_from_hf():
    """Exercise download_weights_from_hf with the offline flag on and off.

    Using a fresh temporary directory as ``cache_dir``, the test toggles
    ``huggingface_hub.constants.HF_HUB_OFFLINE`` and checks that:

    1. offline with nothing cached raises ``LocalEntryNotFoundError``;
    2. with offline disabled, the download call completes;
    3. offline again, the call returns a non-None value.

    The value ``HF_HUB_OFFLINE`` had on entry is restored on exit (pass or
    fail), so this module-level flag does not leak into other tests running
    in the same process.
    """
    model_id = "facebook/opt-125m"

    def fetch(cache_dir):
        # A fresh pattern list per call, matching the original call sites.
        return download_weights_from_hf(
            model_id,
            allow_patterns=["*.safetensors", "*.bin"],
            cache_dir=cache_dir)

    saved_offline = huggingface_hub.constants.HF_HUB_OFFLINE
    try:
        with tempfile.TemporaryDirectory() as tmpdir:
            # assert LocalEntryNotFoundError error is thrown
            # if offline is set and model is not cached
            huggingface_hub.constants.HF_HUB_OFFLINE = True
            with pytest.raises(LocalEntryNotFoundError):
                fetch(tmpdir)

            # download the model
            huggingface_hub.constants.HF_HUB_OFFLINE = False
            fetch(tmpdir)

            # now it should work offline
            huggingface_hub.constants.HF_HUB_OFFLINE = True
            assert fetch(tmpdir) is not None
    finally:
        # Undo the global mutation regardless of the test outcome.
        huggingface_hub.constants.HF_HUB_OFFLINE = saved_offline
|
|
|
|
|
|
if __name__ == "__main__":
    # Allow running this file directly: execute each test in order.
    for _test in (test_hf_transfer_auto_activation,
                  test_download_weights_from_hf):
        _test()
|