[Misc] Log time consumption on weight downloading (#12926)

This commit is contained in:
Jun Duan 2025-02-08 04:16:42 -05:00 committed by GitHub
parent 7e1837676a
commit 011e612d92
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -6,6 +6,7 @@ import hashlib
import json
import os
import tempfile
import time
from collections import defaultdict
from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Union
@ -14,7 +15,8 @@ import gguf
import huggingface_hub.constants
import numpy as np
import torch
from huggingface_hub import HfFileSystem, hf_hub_download, snapshot_download
from huggingface_hub import (HfFileSystem, hf_hub_download, scan_cache_dir,
snapshot_download)
from safetensors.torch import load_file, safe_open, save_file
from tqdm.auto import tqdm
@ -253,6 +255,8 @@ def download_weights_from_hf(
# Use file lock to prevent multiple processes from
# downloading the same model weights at the same time.
with get_lock(model_name_or_path, cache_dir):
start_size = scan_cache_dir().size_on_disk
start_time = time.perf_counter()
hf_folder = snapshot_download(
model_name_or_path,
allow_patterns=allow_patterns,
@ -262,6 +266,11 @@ def download_weights_from_hf(
revision=revision,
local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE,
)
end_time = time.perf_counter()
end_size = scan_cache_dir().size_on_disk
if end_size != start_size:
logger.info("Time took to download weights for %s: %.6f seconds",
model_name_or_path, end_time - start_time)
return hf_folder