[Misc] Log time consumption on weight downloading (#12926)

This commit is contained in:
Jun Duan 2025-02-08 04:16:42 -05:00 committed by GitHub
parent 7e1837676a
commit 011e612d92
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -6,6 +6,7 @@ import hashlib
import json import json
import os import os
import tempfile import tempfile
import time
from collections import defaultdict from collections import defaultdict
from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Union from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Union
@ -14,7 +15,8 @@ import gguf
import huggingface_hub.constants import huggingface_hub.constants
import numpy as np import numpy as np
import torch import torch
from huggingface_hub import HfFileSystem, hf_hub_download, snapshot_download from huggingface_hub import (HfFileSystem, hf_hub_download, scan_cache_dir,
snapshot_download)
from safetensors.torch import load_file, safe_open, save_file from safetensors.torch import load_file, safe_open, save_file
from tqdm.auto import tqdm from tqdm.auto import tqdm
@ -253,6 +255,8 @@ def download_weights_from_hf(
# Use file lock to prevent multiple processes from # Use file lock to prevent multiple processes from
# downloading the same model weights at the same time. # downloading the same model weights at the same time.
with get_lock(model_name_or_path, cache_dir): with get_lock(model_name_or_path, cache_dir):
start_size = scan_cache_dir().size_on_disk
start_time = time.perf_counter()
hf_folder = snapshot_download( hf_folder = snapshot_download(
model_name_or_path, model_name_or_path,
allow_patterns=allow_patterns, allow_patterns=allow_patterns,
@ -262,6 +266,11 @@ def download_weights_from_hf(
revision=revision, revision=revision,
local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE, local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE,
) )
end_time = time.perf_counter()
end_size = scan_cache_dir().size_on_disk
if end_size != start_size:
logger.info("Time took to download weights for %s: %.6f seconds",
model_name_or_path, end_time - start_time)
return hf_folder return hf_folder