[Core] Integrate fastsafetensors
loader for loading model weights (#10647)
Signed-off-by: Manish Sethi <Manish.sethi1@ibm.com>
This commit is contained in:
parent
9606d572ed
commit
761702fd19
5
docs/source/models/extensions/fastsafetensor.md
Normal file
5
docs/source/models/extensions/fastsafetensor.md
Normal file
@ -0,0 +1,5 @@
|
|||||||
|
Loading Model weights with fastsafetensors
|
||||||
|
===================================================================
|
||||||
|
|
||||||
|
Using fastsafetensor library enables loading model weights to GPU memory by leveraging GPU direct storage. See https://github.com/foundation-model-stack/fastsafetensors for more details.
|
||||||
|
For enabling this feature, set the environment variable ``USE_FASTSAFETENSOR`` to ``true``
|
@ -5,4 +5,5 @@
|
|||||||
|
|
||||||
runai_model_streamer
|
runai_model_streamer
|
||||||
tensorizer
|
tensorizer
|
||||||
|
fastsafetensor
|
||||||
:::
|
:::
|
||||||
|
@ -41,3 +41,4 @@ tritonclient==2.51.0
|
|||||||
numpy < 2.0.0
|
numpy < 2.0.0
|
||||||
runai-model-streamer==0.11.0
|
runai-model-streamer==0.11.0
|
||||||
runai-model-streamer-s3==0.11.0
|
runai-model-streamer-s3==0.11.0
|
||||||
|
fastsafetensors>=0.1.10
|
||||||
|
@ -67,6 +67,7 @@ click==8.1.7
|
|||||||
# jiwer
|
# jiwer
|
||||||
# nltk
|
# nltk
|
||||||
# ray
|
# ray
|
||||||
|
# typer
|
||||||
colorama==0.4.6
|
colorama==0.4.6
|
||||||
# via
|
# via
|
||||||
# awscli
|
# awscli
|
||||||
@ -122,6 +123,8 @@ fastparquet==2024.11.0
|
|||||||
# via genai-perf
|
# via genai-perf
|
||||||
fastrlock==0.8.2
|
fastrlock==0.8.2
|
||||||
# via cupy-cuda12x
|
# via cupy-cuda12x
|
||||||
|
fastsafetensors==0.1.10
|
||||||
|
# via -r requirements/test.in
|
||||||
filelock==3.16.1
|
filelock==3.16.1
|
||||||
# via
|
# via
|
||||||
# datasets
|
# datasets
|
||||||
@ -505,7 +508,9 @@ requests==2.32.3
|
|||||||
responses==0.25.3
|
responses==0.25.3
|
||||||
# via genai-perf
|
# via genai-perf
|
||||||
rich==13.9.4
|
rich==13.9.4
|
||||||
# via genai-perf
|
# via
|
||||||
|
# genai-perf
|
||||||
|
# typer
|
||||||
rouge-score==0.1.2
|
rouge-score==0.1.2
|
||||||
# via lm-eval
|
# via lm-eval
|
||||||
rpds-py==0.20.1
|
rpds-py==0.20.1
|
||||||
@ -550,6 +555,8 @@ setuptools==75.8.0
|
|||||||
# via
|
# via
|
||||||
# pytablewriter
|
# pytablewriter
|
||||||
# torch
|
# torch
|
||||||
|
shellingham==1.5.4
|
||||||
|
# via typer
|
||||||
six==1.16.0
|
six==1.16.0
|
||||||
# via
|
# via
|
||||||
# python-dateutil
|
# python-dateutil
|
||||||
@ -600,6 +607,7 @@ torch==2.6.0
|
|||||||
# accelerate
|
# accelerate
|
||||||
# bitsandbytes
|
# bitsandbytes
|
||||||
# encodec
|
# encodec
|
||||||
|
# fastsafetensors
|
||||||
# lm-eval
|
# lm-eval
|
||||||
# peft
|
# peft
|
||||||
# runai-model-streamer
|
# runai-model-streamer
|
||||||
@ -654,6 +662,8 @@ typepy==1.3.2
|
|||||||
# dataproperty
|
# dataproperty
|
||||||
# pytablewriter
|
# pytablewriter
|
||||||
# tabledata
|
# tabledata
|
||||||
|
typer==0.15.2
|
||||||
|
# via fastsafetensors
|
||||||
typing-extensions==4.12.2
|
typing-extensions==4.12.2
|
||||||
# via
|
# via
|
||||||
# huggingface-hub
|
# huggingface-hub
|
||||||
@ -663,6 +673,7 @@ typing-extensions==4.12.2
|
|||||||
# pydantic
|
# pydantic
|
||||||
# pydantic-core
|
# pydantic-core
|
||||||
# torch
|
# torch
|
||||||
|
# typer
|
||||||
tzdata==2024.2
|
tzdata==2024.2
|
||||||
# via pandas
|
# via pandas
|
||||||
urllib3==2.2.3
|
urllib3==2.2.3
|
||||||
|
1
setup.py
1
setup.py
@ -680,6 +680,7 @@ setup(
|
|||||||
install_requires=get_requirements(),
|
install_requires=get_requirements(),
|
||||||
extras_require={
|
extras_require={
|
||||||
"tensorizer": ["tensorizer>=2.9.0"],
|
"tensorizer": ["tensorizer>=2.9.0"],
|
||||||
|
"fastsafetensors": ["fastsafetensors >= 0.1.10"],
|
||||||
"runai": ["runai-model-streamer", "runai-model-streamer-s3", "boto3"],
|
"runai": ["runai-model-streamer", "runai-model-streamer-s3", "boto3"],
|
||||||
"audio": ["librosa", "soundfile"], # Required for audio processing
|
"audio": ["librosa", "soundfile"], # Required for audio processing
|
||||||
"video": ["decord"] # Required for video processing
|
"video": ["decord"] # Required for video processing
|
||||||
|
0
tests/fastsafetensors_loader/__init__.py
Normal file
0
tests/fastsafetensors_loader/__init__.py
Normal file
22
tests/fastsafetensors_loader/test_fastsafetensors_loader.py
Normal file
22
tests/fastsafetensors_loader/test_fastsafetensors_loader.py
Normal file
@ -0,0 +1,22 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
|
from vllm import SamplingParams
|
||||||
|
from vllm.config import LoadFormat
|
||||||
|
|
||||||
|
test_model = "openai-community/gpt2"
|
||||||
|
|
||||||
|
prompts = [
|
||||||
|
"Hello, my name is",
|
||||||
|
"The president of the United States is",
|
||||||
|
"The capital of France is",
|
||||||
|
"The future of AI is",
|
||||||
|
]
|
||||||
|
# Create a sampling params object.
|
||||||
|
sampling_params = SamplingParams(temperature=0.8, top_p=0.95, seed=0)
|
||||||
|
|
||||||
|
|
||||||
|
def test_model_loader_download_files(vllm_runner):
|
||||||
|
with vllm_runner(test_model,
|
||||||
|
load_format=LoadFormat.FASTSAFETENSORS) as llm:
|
||||||
|
deserialized_outputs = llm.generate(prompts, sampling_params)
|
||||||
|
assert deserialized_outputs
|
46
tests/fastsafetensors_loader/test_weight_utils.py
Normal file
46
tests/fastsafetensors_loader/test_weight_utils.py
Normal file
@ -0,0 +1,46 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
|
import glob
|
||||||
|
import tempfile
|
||||||
|
|
||||||
|
import huggingface_hub.constants
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from vllm.model_executor.model_loader.weight_utils import (
|
||||||
|
download_weights_from_hf, fastsafetensors_weights_iterator,
|
||||||
|
safetensors_weights_iterator)
|
||||||
|
|
||||||
|
|
||||||
|
def test_fastsafetensors_model_loader():
|
||||||
|
with tempfile.TemporaryDirectory() as tmpdir:
|
||||||
|
huggingface_hub.constants.HF_HUB_OFFLINE = False
|
||||||
|
download_weights_from_hf("openai-community/gpt2",
|
||||||
|
allow_patterns=["*.safetensors"],
|
||||||
|
cache_dir=tmpdir)
|
||||||
|
safetensors = glob.glob(f"{tmpdir}/**/*.safetensors", recursive=True)
|
||||||
|
assert len(safetensors) > 0
|
||||||
|
|
||||||
|
fastsafetensors_tensors = {}
|
||||||
|
hf_safetensors_tensors = {}
|
||||||
|
|
||||||
|
for name, tensor in fastsafetensors_weights_iterator(
|
||||||
|
safetensors, True):
|
||||||
|
fastsafetensors_tensors[name] = tensor
|
||||||
|
|
||||||
|
for name, tensor in safetensors_weights_iterator(safetensors, True):
|
||||||
|
hf_safetensors_tensors[name] = tensor
|
||||||
|
|
||||||
|
assert len(fastsafetensors_tensors) == len(hf_safetensors_tensors)
|
||||||
|
|
||||||
|
for name, fastsafetensors_tensor in fastsafetensors_tensors.items():
|
||||||
|
fastsafetensors_tensor = fastsafetensors_tensor.to('cpu')
|
||||||
|
assert fastsafetensors_tensor.dtype == hf_safetensors_tensors[
|
||||||
|
name].dtype
|
||||||
|
assert fastsafetensors_tensor.shape == hf_safetensors_tensors[
|
||||||
|
name].shape
|
||||||
|
assert torch.all(
|
||||||
|
fastsafetensors_tensor.eq(hf_safetensors_tensors[name]))
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
test_fastsafetensors_model_loader()
|
@ -1277,6 +1277,7 @@ class LoadFormat(str, enum.Enum):
|
|||||||
BITSANDBYTES = "bitsandbytes"
|
BITSANDBYTES = "bitsandbytes"
|
||||||
MISTRAL = "mistral"
|
MISTRAL = "mistral"
|
||||||
RUNAI_STREAMER = "runai_streamer"
|
RUNAI_STREAMER = "runai_streamer"
|
||||||
|
FASTSAFETENSORS = "fastsafetensors"
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
@ -49,9 +49,10 @@ from vllm.model_executor.model_loader.utils import (ParamMapping,
|
|||||||
set_default_torch_dtype)
|
set_default_torch_dtype)
|
||||||
from vllm.model_executor.model_loader.weight_utils import (
|
from vllm.model_executor.model_loader.weight_utils import (
|
||||||
download_safetensors_index_file_from_hf, download_weights_from_hf,
|
download_safetensors_index_file_from_hf, download_weights_from_hf,
|
||||||
filter_duplicate_safetensors_files, filter_files_not_needed_for_inference,
|
fastsafetensors_weights_iterator, filter_duplicate_safetensors_files,
|
||||||
get_gguf_extra_tensor_names, get_lock, gguf_quant_weights_iterator,
|
filter_files_not_needed_for_inference, get_gguf_extra_tensor_names,
|
||||||
initialize_dummy_weights, np_cache_weights_iterator, pt_weights_iterator,
|
get_lock, gguf_quant_weights_iterator, initialize_dummy_weights,
|
||||||
|
np_cache_weights_iterator, pt_weights_iterator,
|
||||||
runai_safetensors_weights_iterator, safetensors_weights_iterator)
|
runai_safetensors_weights_iterator, safetensors_weights_iterator)
|
||||||
from vllm.model_executor.utils import set_weight_attrs
|
from vllm.model_executor.utils import set_weight_attrs
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
@ -275,7 +276,8 @@ class DefaultModelLoader(BaseModelLoader):
|
|||||||
# Some quantized models use .pt files for storing the weights.
|
# Some quantized models use .pt files for storing the weights.
|
||||||
if load_format == LoadFormat.AUTO:
|
if load_format == LoadFormat.AUTO:
|
||||||
allow_patterns = ["*.safetensors", "*.bin"]
|
allow_patterns = ["*.safetensors", "*.bin"]
|
||||||
elif load_format == LoadFormat.SAFETENSORS:
|
elif (load_format == LoadFormat.SAFETENSORS
|
||||||
|
or load_format == LoadFormat.FASTSAFETENSORS):
|
||||||
use_safetensors = True
|
use_safetensors = True
|
||||||
allow_patterns = ["*.safetensors"]
|
allow_patterns = ["*.safetensors"]
|
||||||
elif load_format == LoadFormat.MISTRAL:
|
elif load_format == LoadFormat.MISTRAL:
|
||||||
@ -357,10 +359,16 @@ class DefaultModelLoader(BaseModelLoader):
|
|||||||
self.load_config.use_tqdm_on_load,
|
self.load_config.use_tqdm_on_load,
|
||||||
)
|
)
|
||||||
elif use_safetensors:
|
elif use_safetensors:
|
||||||
weights_iterator = safetensors_weights_iterator(
|
if self.load_config.load_format == LoadFormat.FASTSAFETENSORS:
|
||||||
hf_weights_files,
|
weights_iterator = fastsafetensors_weights_iterator(
|
||||||
self.load_config.use_tqdm_on_load,
|
hf_weights_files,
|
||||||
)
|
self.load_config.use_tqdm_on_load,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
weights_iterator = safetensors_weights_iterator(
|
||||||
|
hf_weights_files,
|
||||||
|
self.load_config.use_tqdm_on_load,
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
weights_iterator = pt_weights_iterator(
|
weights_iterator = pt_weights_iterator(
|
||||||
hf_weights_files,
|
hf_weights_files,
|
||||||
|
@ -38,6 +38,14 @@ except (ImportError, OSError):
|
|||||||
SafetensorsStreamer = runai_model_streamer.placeholder_attr(
|
SafetensorsStreamer = runai_model_streamer.placeholder_attr(
|
||||||
"SafetensorsStreamer")
|
"SafetensorsStreamer")
|
||||||
|
|
||||||
|
try:
|
||||||
|
from fastsafetensors import SafeTensorsFileLoader, SingleGroup
|
||||||
|
except ImportError:
|
||||||
|
fastsafetensors = PlaceholderModule("fastsafetensors")
|
||||||
|
SafeTensorsFileLoader = fastsafetensors.placeholder_attr(
|
||||||
|
"SafeTensorsFileLoader")
|
||||||
|
SingleGroup = fastsafetensors.placeholder_attr("SingleGroup")
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
# use system-level temp directory for file locks, so that multiple users
|
# use system-level temp directory for file locks, so that multiple users
|
||||||
@ -452,6 +460,45 @@ def runai_safetensors_weights_iterator(
|
|||||||
yield from streamer.get_tensors()
|
yield from streamer.get_tensors()
|
||||||
|
|
||||||
|
|
||||||
|
def fastsafetensors_weights_iterator(
|
||||||
|
hf_weights_files: List[str],
|
||||||
|
use_tqdm_on_load: bool,
|
||||||
|
) -> Generator[Tuple[str, torch.Tensor], None, None]:
|
||||||
|
"""Iterate over the weights in the model safetensor files
|
||||||
|
using fastsafetensor library."""
|
||||||
|
if torch.distributed.is_initialized():
|
||||||
|
pg = torch.distributed.group.WORLD
|
||||||
|
else:
|
||||||
|
pg = SingleGroup()
|
||||||
|
|
||||||
|
device = torch.device(f'cuda:{pg.rank()}')
|
||||||
|
weight_files_sub_lists = [
|
||||||
|
hf_weights_files[i:i + pg.size()]
|
||||||
|
for i in range(0, len(hf_weights_files), pg.size())
|
||||||
|
]
|
||||||
|
|
||||||
|
for f_list in tqdm(
|
||||||
|
weight_files_sub_lists,
|
||||||
|
desc="Loading safetensors using Fastsafetensor loader",
|
||||||
|
disable=not enable_tqdm(use_tqdm_on_load),
|
||||||
|
bar_format=_BAR_FORMAT,
|
||||||
|
):
|
||||||
|
loader = SafeTensorsFileLoader(pg, device)
|
||||||
|
rank_file_map = {i: [f] for i, f in enumerate(f_list)}
|
||||||
|
loader.add_filenames(rank_file_map)
|
||||||
|
try:
|
||||||
|
fb = loader.copy_files_to_device()
|
||||||
|
try:
|
||||||
|
keys = list(fb.key_to_rank_lidx.keys())
|
||||||
|
for k in keys:
|
||||||
|
t = fb.get_tensor(k)
|
||||||
|
yield k, t
|
||||||
|
finally:
|
||||||
|
fb.close()
|
||||||
|
finally:
|
||||||
|
loader.close()
|
||||||
|
|
||||||
|
|
||||||
def pt_weights_iterator(
|
def pt_weights_iterator(
|
||||||
hf_weights_files: List[str],
|
hf_weights_files: List[str],
|
||||||
use_tqdm_on_load: bool,
|
use_tqdm_on_load: bool,
|
||||||
|
Loading…
x
Reference in New Issue
Block a user