From 761702fd19d15f472a7d57a3adbe0e9b5ae0bf98 Mon Sep 17 00:00:00 2001
From: Manish Sethi
Date: Mon, 24 Mar 2025 11:08:02 -0400
Subject: [PATCH] [Core] Integrate `fastsafetensors` loader for loading model
 weights (#10647)

Signed-off-by: Manish Sethi
---
 .../models/extensions/fastsafetensor.md       |  5 ++
 docs/source/models/extensions/index.md        |  1 +
 requirements/test.in                          |  1 +
 requirements/test.txt                         | 13 ++++-
 setup.py                                      |  1 +
 tests/fastsafetensors_loader/__init__.py      |  0
 .../test_fastsafetensors_loader.py            | 22 +++++++++
 .../test_weight_utils.py                      | 46 ++++++++++++++++++
 vllm/config.py                                |  1 +
 vllm/model_executor/model_loader/loader.py    | 24 ++++++----
 .../model_loader/weight_utils.py              | 47 +++++++++++++++++++
 11 files changed, 152 insertions(+), 9 deletions(-)
 create mode 100644 docs/source/models/extensions/fastsafetensor.md
 create mode 100644 tests/fastsafetensors_loader/__init__.py
 create mode 100644 tests/fastsafetensors_loader/test_fastsafetensors_loader.py
 create mode 100644 tests/fastsafetensors_loader/test_weight_utils.py

diff --git a/docs/source/models/extensions/fastsafetensor.md b/docs/source/models/extensions/fastsafetensor.md
new file mode 100644
index 00000000..66cd710c
--- /dev/null
+++ b/docs/source/models/extensions/fastsafetensor.md
@@ -0,0 +1,5 @@
+Loading model weights with fastsafetensors
+===================================================================
+
+Using the fastsafetensors library enables loading model weights to GPU memory by leveraging GPU Direct Storage. See https://github.com/foundation-model-stack/fastsafetensors for more details.
+To enable this feature, set the environment variable ``USE_FASTSAFETENSOR`` to ``true``.
diff --git a/docs/source/models/extensions/index.md b/docs/source/models/extensions/index.md
index 69faf472..cdcdaa5b 100644
--- a/docs/source/models/extensions/index.md
+++ b/docs/source/models/extensions/index.md
@@ -5,4 +5,5 @@
 
 runai_model_streamer
 tensorizer
+fastsafetensor
 :::
diff --git a/requirements/test.in b/requirements/test.in
index e75f15c0..5c59bbd1 100644
--- a/requirements/test.in
+++ b/requirements/test.in
@@ -41,3 +41,4 @@ tritonclient==2.51.0
 numpy < 2.0.0
 runai-model-streamer==0.11.0
 runai-model-streamer-s3==0.11.0
+fastsafetensors>=0.1.10
diff --git a/requirements/test.txt b/requirements/test.txt
index c733364f..b0ae4796 100644
--- a/requirements/test.txt
+++ b/requirements/test.txt
@@ -67,6 +67,7 @@ click==8.1.7
     #   jiwer
     #   nltk
     #   ray
+    #   typer
 colorama==0.4.6
     # via
     #   awscli
@@ -122,6 +123,8 @@ fastparquet==2024.11.0
     # via genai-perf
 fastrlock==0.8.2
     # via cupy-cuda12x
+fastsafetensors==0.1.10
+    # via -r requirements/test.in
 filelock==3.16.1
     # via
     #   datasets
@@ -505,7 +508,9 @@ requests==2.32.3
 responses==0.25.3
     # via genai-perf
 rich==13.9.4
-    # via genai-perf
+    # via
+    #   genai-perf
+    #   typer
 rouge-score==0.1.2
     # via lm-eval
 rpds-py==0.20.1
@@ -550,6 +555,8 @@ setuptools==75.8.0
     # via
     #   pytablewriter
     #   torch
+shellingham==1.5.4
+    # via typer
 six==1.16.0
     # via
     #   python-dateutil
@@ -600,6 +607,7 @@ torch==2.6.0
     #   accelerate
     #   bitsandbytes
     #   encodec
+    #   fastsafetensors
     #   lm-eval
     #   peft
     #   runai-model-streamer
@@ -654,6 +662,8 @@ typepy==1.3.2
     #   dataproperty
     #   pytablewriter
     #   tabledata
+typer==0.15.2
+    # via fastsafetensors
 typing-extensions==4.12.2
     # via
     #   huggingface-hub
@@ -663,6 +673,7 @@ typing-extensions==4.12.2
     #   pydantic
     #   pydantic-core
     #   torch
+    #   typer
 tzdata==2024.2
     # via pandas
 urllib3==2.2.3
diff --git a/setup.py b/setup.py
index 6c45413c..37f3e789 100755
--- a/setup.py
+++ b/setup.py
@@ -680,6 +680,7 @@ setup(
     install_requires=get_requirements(),
     extras_require={
         "tensorizer": ["tensorizer>=2.9.0"],
+        "fastsafetensors": ["fastsafetensors >= 0.1.10"],
         "runai": ["runai-model-streamer", "runai-model-streamer-s3", "boto3"],
         "audio": ["librosa", "soundfile"],  # Required for audio processing
         "video": ["decord"]  # Required for video processing
diff --git a/tests/fastsafetensors_loader/__init__.py b/tests/fastsafetensors_loader/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/tests/fastsafetensors_loader/test_fastsafetensors_loader.py b/tests/fastsafetensors_loader/test_fastsafetensors_loader.py
new file mode 100644
index 00000000..184bee2a
--- /dev/null
+++ b/tests/fastsafetensors_loader/test_fastsafetensors_loader.py
@@ -0,0 +1,22 @@
+# SPDX-License-Identifier: Apache-2.0
+
+from vllm import SamplingParams
+from vllm.config import LoadFormat
+
+test_model = "openai-community/gpt2"
+
+prompts = [
+    "Hello, my name is",
+    "The president of the United States is",
+    "The capital of France is",
+    "The future of AI is",
+]
+# Create a sampling params object.
+sampling_params = SamplingParams(temperature=0.8, top_p=0.95, seed=0)
+
+
+def test_model_loader_download_files(vllm_runner):
+    with vllm_runner(test_model,
+                     load_format=LoadFormat.FASTSAFETENSORS) as llm:
+        deserialized_outputs = llm.generate(prompts, sampling_params)
+        assert deserialized_outputs
diff --git a/tests/fastsafetensors_loader/test_weight_utils.py b/tests/fastsafetensors_loader/test_weight_utils.py
new file mode 100644
index 00000000..8772035a
--- /dev/null
+++ b/tests/fastsafetensors_loader/test_weight_utils.py
@@ -0,0 +1,46 @@
+# SPDX-License-Identifier: Apache-2.0
+
+import glob
+import tempfile
+
+import huggingface_hub.constants
+import torch
+
+from vllm.model_executor.model_loader.weight_utils import (
+    download_weights_from_hf, fastsafetensors_weights_iterator,
+    safetensors_weights_iterator)
+
+
+def test_fastsafetensors_model_loader():
+    with tempfile.TemporaryDirectory() as tmpdir:
+        huggingface_hub.constants.HF_HUB_OFFLINE = False
+        download_weights_from_hf("openai-community/gpt2",
+                                 allow_patterns=["*.safetensors"],
+                                 cache_dir=tmpdir)
+        safetensors = glob.glob(f"{tmpdir}/**/*.safetensors", recursive=True)
+        assert len(safetensors) > 0
+
+        fastsafetensors_tensors = {}
+        hf_safetensors_tensors = {}
+
+        for name, tensor in fastsafetensors_weights_iterator(
+                safetensors, True):
+            fastsafetensors_tensors[name] = tensor
+
+        for name, tensor in safetensors_weights_iterator(safetensors, True):
+            hf_safetensors_tensors[name] = tensor
+
+        assert len(fastsafetensors_tensors) == len(hf_safetensors_tensors)
+
+        for name, fastsafetensors_tensor in fastsafetensors_tensors.items():
+            fastsafetensors_tensor = fastsafetensors_tensor.to('cpu')
+            assert fastsafetensors_tensor.dtype == hf_safetensors_tensors[
+                name].dtype
+            assert fastsafetensors_tensor.shape == hf_safetensors_tensors[
+                name].shape
+            assert torch.all(
+                fastsafetensors_tensor.eq(hf_safetensors_tensors[name]))
+
+
+if __name__ == "__main__":
+    test_fastsafetensors_model_loader()
diff --git a/vllm/config.py b/vllm/config.py
index 2fd0db4e..989e5b47 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -1277,6 +1277,7 @@ class LoadFormat(str, enum.Enum):
     BITSANDBYTES = "bitsandbytes"
     MISTRAL = "mistral"
     RUNAI_STREAMER = "runai_streamer"
+    FASTSAFETENSORS = "fastsafetensors"
 
 
 @dataclass
diff --git a/vllm/model_executor/model_loader/loader.py b/vllm/model_executor/model_loader/loader.py
index d3f7a26e..de04c6f8 100644
--- a/vllm/model_executor/model_loader/loader.py
+++ b/vllm/model_executor/model_loader/loader.py
@@ -49,9 +49,10 @@ from vllm.model_executor.model_loader.utils import (ParamMapping,
                                                      set_default_torch_dtype)
 from vllm.model_executor.model_loader.weight_utils import (
     download_safetensors_index_file_from_hf, download_weights_from_hf,
-    filter_duplicate_safetensors_files, filter_files_not_needed_for_inference,
-    get_gguf_extra_tensor_names, get_lock, gguf_quant_weights_iterator,
-    initialize_dummy_weights, np_cache_weights_iterator, pt_weights_iterator,
+    fastsafetensors_weights_iterator, filter_duplicate_safetensors_files,
+    filter_files_not_needed_for_inference, get_gguf_extra_tensor_names,
+    get_lock, gguf_quant_weights_iterator, initialize_dummy_weights,
+    np_cache_weights_iterator, pt_weights_iterator,
     runai_safetensors_weights_iterator, safetensors_weights_iterator)
 from vllm.model_executor.utils import set_weight_attrs
 from vllm.platforms import current_platform
@@ -275,7 +276,8 @@ class DefaultModelLoader(BaseModelLoader):
         # Some quantized models use .pt files for storing the weights.
         if load_format == LoadFormat.AUTO:
             allow_patterns = ["*.safetensors", "*.bin"]
-        elif load_format == LoadFormat.SAFETENSORS:
+        elif (load_format == LoadFormat.SAFETENSORS
+              or load_format == LoadFormat.FASTSAFETENSORS):
             use_safetensors = True
             allow_patterns = ["*.safetensors"]
         elif load_format == LoadFormat.MISTRAL:
@@ -357,10 +359,16 @@ class DefaultModelLoader(BaseModelLoader):
                 self.load_config.use_tqdm_on_load,
             )
         elif use_safetensors:
-            weights_iterator = safetensors_weights_iterator(
-                hf_weights_files,
-                self.load_config.use_tqdm_on_load,
-            )
+            if self.load_config.load_format == LoadFormat.FASTSAFETENSORS:
+                weights_iterator = fastsafetensors_weights_iterator(
+                    hf_weights_files,
+                    self.load_config.use_tqdm_on_load,
+                )
+            else:
+                weights_iterator = safetensors_weights_iterator(
+                    hf_weights_files,
+                    self.load_config.use_tqdm_on_load,
+                )
         else:
             weights_iterator = pt_weights_iterator(
                 hf_weights_files,
diff --git a/vllm/model_executor/model_loader/weight_utils.py b/vllm/model_executor/model_loader/weight_utils.py
index 926172a1..a7475941 100644
--- a/vllm/model_executor/model_loader/weight_utils.py
+++ b/vllm/model_executor/model_loader/weight_utils.py
@@ -38,6 +38,14 @@ except (ImportError, OSError):
     SafetensorsStreamer = runai_model_streamer.placeholder_attr(
         "SafetensorsStreamer")
 
+try:
+    from fastsafetensors import SafeTensorsFileLoader, SingleGroup
+except ImportError:
+    fastsafetensors = PlaceholderModule("fastsafetensors")
+    SafeTensorsFileLoader = fastsafetensors.placeholder_attr(
+        "SafeTensorsFileLoader")
+    SingleGroup = fastsafetensors.placeholder_attr("SingleGroup")
+
 logger = init_logger(__name__)
 
 # use system-level temp directory for file locks, so that multiple users
@@ -452,6 +460,45 @@ def runai_safetensors_weights_iterator(
         yield from streamer.get_tensors()
 
 
+def fastsafetensors_weights_iterator(
+    hf_weights_files: List[str],
+    use_tqdm_on_load: bool,
+) -> Generator[Tuple[str, torch.Tensor], None, None]:
+    """Iterate over the weights in the model safetensor files
+    using the fastsafetensors library."""
+    if torch.distributed.is_initialized():
+        pg = torch.distributed.group.WORLD
+    else:
+        pg = SingleGroup()
+
+    device = torch.device(f'cuda:{pg.rank()}')
+    weight_files_sub_lists = [
+        hf_weights_files[i:i + pg.size()]
+        for i in range(0, len(hf_weights_files), pg.size())
+    ]
+
+    for f_list in tqdm(
+            weight_files_sub_lists,
+            desc="Loading safetensors using Fastsafetensor loader",
+            disable=not enable_tqdm(use_tqdm_on_load),
+            bar_format=_BAR_FORMAT,
+    ):
+        loader = SafeTensorsFileLoader(pg, device)
+        rank_file_map = {i: [f] for i, f in enumerate(f_list)}
+        loader.add_filenames(rank_file_map)
+        try:
+            fb = loader.copy_files_to_device()
+            try:
+                keys = list(fb.key_to_rank_lidx.keys())
+                for k in keys:
+                    t = fb.get_tensor(k)
+                    yield k, t
+            finally:
+                fb.close()
+        finally:
+            loader.close()
+
+
 def pt_weights_iterator(
     hf_weights_files: List[str],
     use_tqdm_on_load: bool,
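
For reference, a minimal usage sketch of the load format this patch adds. It
assumes the patch is applied, a CUDA device is available, and
fastsafetensors >= 0.1.10 is installed (for example via the new
`fastsafetensors` extra in setup.py above). The model and sampling values
mirror tests/fastsafetensors_loader/test_fastsafetensors_loader.py; passing
`load_format` through the `LLM` constructor is pre-existing vLLM behavior, not
something introduced here.

    # Generate with model weights loaded through the fastsafetensors path.
    from vllm import LLM, SamplingParams
    from vllm.config import LoadFormat

    llm = LLM(model="openai-community/gpt2",
              load_format=LoadFormat.FASTSAFETENSORS)  # requires a CUDA device
    outputs = llm.generate(["Hello, my name is"],
                           SamplingParams(temperature=0.8, top_p=0.95, seed=0))
    print(outputs[0].outputs[0].text)

Equivalently, the loader can be selected from the command line with
`vllm serve <model> --load-format fastsafetensors`, since `LoadFormat` values
map directly onto the `--load-format` choices.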