[Core] Integrate fastsafetensors loader for loading model weights (#10647)

Signed-off-by: Manish Sethi <Manish.sethi1@ibm.com>
Manish Sethi 2025-03-24 11:08:02 -04:00 committed by GitHub
parent 9606d572ed
commit 761702fd19
11 changed files with 152 additions and 9 deletions

View File

@ -0,0 +1,5 @@
Loading model weights with fastsafetensors
===================================================================
Using the fastsafetensors library enables loading model weights to GPU memory by leveraging GPU Direct Storage. See https://github.com/foundation-model-stack/fastsafetensors for more details.
To enable this feature, set the load format to ``fastsafetensors``, e.g. by passing ``--load-format fastsafetensors`` on the command line or ``load_format=LoadFormat.FASTSAFETENSORS`` in Python.
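For example, a minimal offline-inference sketch (the model name and prompt are illustrative):

.. code-block:: python

    from vllm import LLM, SamplingParams

    # Select the fastsafetensors loader via the load format.
    llm = LLM(model="openai-community/gpt2", load_format="fastsafetensors")
    outputs = llm.generate(["Hello, my name is"],
                           SamplingParams(temperature=0.8, top_p=0.95))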

View File

@ -5,4 +5,5 @@
runai_model_streamer
tensorizer
fastsafetensor
:::

View File

@ -41,3 +41,4 @@ tritonclient==2.51.0
numpy < 2.0.0
runai-model-streamer==0.11.0
runai-model-streamer-s3==0.11.0
fastsafetensors>=0.1.10

View File

@ -67,6 +67,7 @@ click==8.1.7
# jiwer
# nltk
# ray
# typer
colorama==0.4.6
# via
# awscli
@ -122,6 +123,8 @@ fastparquet==2024.11.0
# via genai-perf
fastrlock==0.8.2
# via cupy-cuda12x
fastsafetensors==0.1.10
# via -r requirements/test.in
filelock==3.16.1
# via
# datasets
@ -505,7 +508,9 @@ requests==2.32.3
responses==0.25.3
# via genai-perf
rich==13.9.4
# via genai-perf
# via
# genai-perf
# typer
rouge-score==0.1.2
# via lm-eval
rpds-py==0.20.1
@ -550,6 +555,8 @@ setuptools==75.8.0
# via
# pytablewriter
# torch
shellingham==1.5.4
# via typer
six==1.16.0
# via
# python-dateutil
@ -600,6 +607,7 @@ torch==2.6.0
# accelerate
# bitsandbytes
# encodec
# fastsafetensors
# lm-eval
# peft
# runai-model-streamer
@ -654,6 +662,8 @@ typepy==1.3.2
# dataproperty
# pytablewriter
# tabledata
typer==0.15.2
# via fastsafetensors
typing-extensions==4.12.2
# via
# huggingface-hub
@ -663,6 +673,7 @@ typing-extensions==4.12.2
# pydantic
# pydantic-core
# torch
# typer
tzdata==2024.2
# via pandas
urllib3==2.2.3

View File

@ -680,6 +680,7 @@ setup(
install_requires=get_requirements(),
extras_require={
"tensorizer": ["tensorizer>=2.9.0"],
"fastsafetensors": ["fastsafetensors >= 0.1.10"],
"runai": ["runai-model-streamer", "runai-model-streamer-s3", "boto3"],
"audio": ["librosa", "soundfile"], # Required for audio processing
"video": ["decord"] # Required for video processing

View File

View File

@ -0,0 +1,22 @@
# SPDX-License-Identifier: Apache-2.0

from vllm import SamplingParams
from vllm.config import LoadFormat

test_model = "openai-community/gpt2"

prompts = [
    "Hello, my name is",
    "The president of the United States is",
    "The capital of France is",
    "The future of AI is",
]
# Create a sampling params object.
sampling_params = SamplingParams(temperature=0.8, top_p=0.95, seed=0)


def test_model_loader_download_files(vllm_runner):
    with vllm_runner(test_model,
                     load_format=LoadFormat.FASTSAFETENSORS) as llm:
        deserialized_outputs = llm.generate(prompts, sampling_params)
        assert deserialized_outputs

View File

@ -0,0 +1,46 @@
# SPDX-License-Identifier: Apache-2.0

import glob
import tempfile

import huggingface_hub.constants
import torch

from vllm.model_executor.model_loader.weight_utils import (
    download_weights_from_hf, fastsafetensors_weights_iterator,
    safetensors_weights_iterator)


def test_fastsafetensors_model_loader():
    with tempfile.TemporaryDirectory() as tmpdir:
        huggingface_hub.constants.HF_HUB_OFFLINE = False
        download_weights_from_hf("openai-community/gpt2",
                                 allow_patterns=["*.safetensors"],
                                 cache_dir=tmpdir)
        safetensors = glob.glob(f"{tmpdir}/**/*.safetensors", recursive=True)
        assert len(safetensors) > 0

        fastsafetensors_tensors = {}
        hf_safetensors_tensors = {}

        for name, tensor in fastsafetensors_weights_iterator(
                safetensors, True):
            fastsafetensors_tensors[name] = tensor

        for name, tensor in safetensors_weights_iterator(safetensors, True):
            hf_safetensors_tensors[name] = tensor

        assert len(fastsafetensors_tensors) == len(hf_safetensors_tensors)

        # Each tensor loaded via fastsafetensors must match the reference
        # tensor produced by the stock safetensors iterator.
        for name, fastsafetensors_tensor in fastsafetensors_tensors.items():
            fastsafetensors_tensor = fastsafetensors_tensor.to('cpu')
            assert fastsafetensors_tensor.dtype == hf_safetensors_tensors[
                name].dtype
            assert fastsafetensors_tensor.shape == hf_safetensors_tensors[
                name].shape
            assert torch.all(
                fastsafetensors_tensor.eq(hf_safetensors_tensors[name]))


if __name__ == "__main__":
    test_fastsafetensors_model_loader()

View File

@ -1277,6 +1277,7 @@ class LoadFormat(str, enum.Enum):
BITSANDBYTES = "bitsandbytes"
MISTRAL = "mistral"
RUNAI_STREAMER = "runai_streamer"
FASTSAFETENSORS = "fastsafetensors"
@dataclass
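Because this enum backs vLLM's --load-format CLI option, the new loader can also be selected when launching a server; a sketch (the model name is illustrative):

vllm serve openai-community/gpt2 --load-format fastsafetensors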

View File

@ -49,9 +49,10 @@ from vllm.model_executor.model_loader.utils import (ParamMapping,
set_default_torch_dtype)
from vllm.model_executor.model_loader.weight_utils import (
download_safetensors_index_file_from_hf, download_weights_from_hf,
filter_duplicate_safetensors_files, filter_files_not_needed_for_inference,
get_gguf_extra_tensor_names, get_lock, gguf_quant_weights_iterator,
initialize_dummy_weights, np_cache_weights_iterator, pt_weights_iterator,
fastsafetensors_weights_iterator, filter_duplicate_safetensors_files,
filter_files_not_needed_for_inference, get_gguf_extra_tensor_names,
get_lock, gguf_quant_weights_iterator, initialize_dummy_weights,
np_cache_weights_iterator, pt_weights_iterator,
runai_safetensors_weights_iterator, safetensors_weights_iterator)
from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
@ -275,7 +276,8 @@ class DefaultModelLoader(BaseModelLoader):
# Some quantized models use .pt files for storing the weights.
if load_format == LoadFormat.AUTO:
allow_patterns = ["*.safetensors", "*.bin"]
elif load_format == LoadFormat.SAFETENSORS:
elif (load_format == LoadFormat.SAFETENSORS
or load_format == LoadFormat.FASTSAFETENSORS):
use_safetensors = True
allow_patterns = ["*.safetensors"]
elif load_format == LoadFormat.MISTRAL:
@ -357,10 +359,16 @@ class DefaultModelLoader(BaseModelLoader):
self.load_config.use_tqdm_on_load,
)
elif use_safetensors:
weights_iterator = safetensors_weights_iterator(
hf_weights_files,
self.load_config.use_tqdm_on_load,
)
if self.load_config.load_format == LoadFormat.FASTSAFETENSORS:
weights_iterator = fastsafetensors_weights_iterator(
hf_weights_files,
self.load_config.use_tqdm_on_load,
)
else:
weights_iterator = safetensors_weights_iterator(
hf_weights_files,
self.load_config.use_tqdm_on_load,
)
else:
weights_iterator = pt_weights_iterator(
hf_weights_files,

View File

@ -38,6 +38,14 @@ except (ImportError, OSError):
SafetensorsStreamer = runai_model_streamer.placeholder_attr(
"SafetensorsStreamer")
try:
from fastsafetensors import SafeTensorsFileLoader, SingleGroup
except ImportError:
fastsafetensors = PlaceholderModule("fastsafetensors")
SafeTensorsFileLoader = fastsafetensors.placeholder_attr(
"SafeTensorsFileLoader")
SingleGroup = fastsafetensors.placeholder_attr("SingleGroup")
logger = init_logger(__name__)
# use system-level temp directory for file locks, so that multiple users
@ -452,6 +460,45 @@ def runai_safetensors_weights_iterator(
yield from streamer.get_tensors()
def fastsafetensors_weights_iterator(
    hf_weights_files: List[str],
    use_tqdm_on_load: bool,
) -> Generator[Tuple[str, torch.Tensor], None, None]:
    """Iterate over the weights in the model safetensors files
    using the fastsafetensors library."""
    if torch.distributed.is_initialized():
        pg = torch.distributed.group.WORLD
    else:
        pg = SingleGroup()

    device = torch.device(f'cuda:{pg.rank()}')
    # Split the files into chunks of pg.size() so that, within each chunk,
    # every rank loads one file in parallel.
    weight_files_sub_lists = [
        hf_weights_files[i:i + pg.size()]
        for i in range(0, len(hf_weights_files), pg.size())
    ]

    for f_list in tqdm(
            weight_files_sub_lists,
            desc="Loading safetensors using Fastsafetensor loader",
            disable=not enable_tqdm(use_tqdm_on_load),
            bar_format=_BAR_FORMAT,
    ):
        loader = SafeTensorsFileLoader(pg, device)
        rank_file_map = {i: [f] for i, f in enumerate(f_list)}
        loader.add_filenames(rank_file_map)
        try:
            fb = loader.copy_files_to_device()
            try:
                keys = list(fb.key_to_rank_lidx.keys())
                for k in keys:
                    t = fb.get_tensor(k)
                    yield k, t
            finally:
                fb.close()
        finally:
            loader.close()
def pt_weights_iterator(
hf_weights_files: List[str],
use_tqdm_on_load: bool,
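For reference, a minimal single-process sketch of the fastsafetensors API pattern used by fastsafetensors_weights_iterator above; the file path is a placeholder, and SingleGroup stands in for a torch.distributed process group:

from fastsafetensors import SafeTensorsFileLoader, SingleGroup
import torch

pg = SingleGroup()  # single-process stand-in for a distributed group
device = torch.device(f'cuda:{pg.rank()}')
loader = SafeTensorsFileLoader(pg, device)
loader.add_filenames({0: ["model.safetensors"]})  # rank -> file list
try:
    fb = loader.copy_files_to_device()  # stage tensors in GPU memory
    try:
        for name in fb.key_to_rank_lidx:
            tensor = fb.get_tensor(name)
            print(name, tuple(tensor.shape), tensor.dtype)
    finally:
        fb.close()
finally:
    loader.close()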