
Signed-off-by: <> Co-authored-by: EC2 Default User <ec2-user@ip-172-31-20-117.us-west-2.compute.internal>
42 lines
1.4 KiB
Python
42 lines
1.4 KiB
Python
# SPDX-License-Identifier: Apache-2.0
|
|
|
|
import glob
|
|
import tempfile
|
|
|
|
import huggingface_hub.constants
|
|
import torch
|
|
|
|
from vllm.model_executor.model_loader.weight_utils import (
|
|
download_weights_from_hf, runai_safetensors_weights_iterator,
|
|
safetensors_weights_iterator)
|
|
|
|
|
|
def test_runai_model_loader():
    """Check that the Run:ai streamer yields the same weights as safetensors.

    Downloads the ``*.safetensors`` files of ``openai-community/gpt2`` into a
    temporary cache directory. It then reads them once through
    ``runai_safetensors_weights_iterator`` and once through
    ``safetensors_weights_iterator``. Finally it asserts that both produce
    the same tensor names, with matching dtype, shape and values per tensor.

    NOTE(review): presumably needs network access to the Hugging Face Hub
    for the download step.
    """
    with tempfile.TemporaryDirectory() as tmpdir:
        # Force online mode for the download, then restore the previous
        # setting so the override does not leak into other tests that run
        # in the same process.
        previous_offline = huggingface_hub.constants.HF_HUB_OFFLINE
        huggingface_hub.constants.HF_HUB_OFFLINE = False
        try:
            download_weights_from_hf("openai-community/gpt2",
                                     allow_patterns=["*.safetensors"],
                                     cache_dir=tmpdir)
        finally:
            huggingface_hub.constants.HF_HUB_OFFLINE = previous_offline

        safetensors_files = glob.glob(f"{tmpdir}/**/*.safetensors",
                                      recursive=True)
        assert len(safetensors_files) > 0

        # Both iterators yield (name, tensor) pairs, so dict() collects them
        # exactly like the explicit name -> tensor assignment loop would.
        runai_model_streamer_tensors = dict(
            runai_safetensors_weights_iterator(safetensors_files))
        hf_safetensors_tensors = dict(
            safetensors_weights_iterator(safetensors_files))

        # Compare the name sets, not just their sizes. A mismatch then fails
        # with a readable assertion instead of a KeyError in the loop below.
        assert (runai_model_streamer_tensors.keys() ==
                hf_safetensors_tensors.keys())

        for name, runai_tensor in runai_model_streamer_tensors.items():
            hf_tensor = hf_safetensors_tensors[name]
            assert runai_tensor.dtype == hf_tensor.dtype
            assert runai_tensor.shape == hf_tensor.shape
            assert torch.all(runai_tensor.eq(hf_tensor))
|
|
|
|
|
|
if __name__ == "__main__":
    # Executing this file directly runs the loader comparison once, without
    # relying on pytest to discover and collect the test function.
    test_runai_model_loader()
|