[Core] Loading model from S3 using RunAI Model Streamer as optional loader (#10192)
Signed-off-by: OmerD <omer@run.ai>
parent 7c7aa37c69
commit 995f56236b
@@ -240,9 +240,9 @@ FROM vllm-base AS vllm-openai
# install additional dependencies for openai api server
RUN --mount=type=cache,target=/root/.cache/pip \
    if [ "$TARGETPLATFORM" = "linux/arm64" ]; then \
-        pip install accelerate hf_transfer 'modelscope!=1.15.0' 'bitsandbytes>=0.42.0' 'timm==0.9.10'; \
+        pip install accelerate hf_transfer 'modelscope!=1.15.0' 'bitsandbytes>=0.42.0' 'timm==0.9.10' boto3 runai-model-streamer runai-model-streamer[s3]; \
    else \
-        pip install accelerate hf_transfer 'modelscope!=1.15.0' 'bitsandbytes>=0.45.0' 'timm==0.9.10'; \
+        pip install accelerate hf_transfer 'modelscope!=1.15.0' 'bitsandbytes>=0.45.0' 'timm==0.9.10' boto3 runai-model-streamer runai-model-streamer[s3]; \
    fi

ENV VLLM_USAGE_SOURCE production-docker-image
@@ -88,6 +88,7 @@ Documentation
   serving/metrics
   serving/integrations
   serving/tensorizer
+   serving/runai_model_streamer

.. toctree::
   :maxdepth: 1
53  docs/source/serving/runai_model_streamer.rst  Normal file
@@ -0,0 +1,53 @@
.. _runai_model_streamer:

Loading Models with Run:ai Model Streamer
=========================================
Run:ai Model Streamer is a library that reads tensors concurrently and streams them to GPU memory.
Further reading can be found in the `Run:ai Model Streamer Documentation <https://github.com/run-ai/runai-model-streamer/blob/master/docs/README.md>`_.

vLLM supports loading weights in Safetensors format using the Run:ai Model Streamer.
You first need to install the vLLM RunAI optional dependency:

.. code-block:: console

    $ pip3 install vllm[runai]

To run it as an OpenAI-compatible server, add the `--load-format runai_streamer` flag:

.. code-block:: console

    $ vllm serve /home/meta-llama/Llama-3.2-3B-Instruct --load-format runai_streamer

To load a model from an AWS S3 object store, run:

.. code-block:: console

    $ vllm serve s3://core-llm/Llama-3-8b --load-format runai_streamer

To load a model from an S3-compatible object store, run:

.. code-block:: console

    $ RUNAI_STREAMER_S3_USE_VIRTUAL_ADDRESSING=0 AWS_EC2_METADATA_DISABLED=true AWS_ENDPOINT_URL=https://storage.googleapis.com vllm serve s3://core-llm/Llama-3-8b --load-format runai_streamer

Tunable parameters
------------------
You can tune parameters using `--model-loader-extra-config`.

You can tune `concurrency`, which controls the level of concurrency and the number of OS threads reading tensors from the file into the CPU buffer.
For reading from S3, it is the number of client instances the host opens to the S3 server.

.. code-block:: console

    $ vllm serve /home/meta-llama/Llama-3.2-3B-Instruct --load-format runai_streamer --model-loader-extra-config '{"concurrency":16}'

You can control the size of the CPU memory buffer to which tensors are read from the file, and limit this size.
You can read further about CPU buffer memory limiting `here <https://github.com/run-ai/runai-model-streamer/blob/master/docs/src/env-vars.md#runai_streamer_memory_limit>`_.

.. code-block:: console

    $ vllm serve /home/meta-llama/Llama-3.2-3B-Instruct --load-format runai_streamer --model-loader-extra-config '{"memory_limit":5368709120}'

.. note::
    For further instructions about tunable parameters and additional parameters configurable through environment variables, read the `Environment Variables Documentation <https://github.com/run-ai/runai-model-streamer/blob/master/docs/src/env-vars.md>`_.
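The same load format can also be used from the offline inference API. A minimal sketch, assuming vllm[runai] is installed; the model name and prompt are illustrative:

from vllm import LLM, SamplingParams

# load_format="runai_streamer" selects the Run:ai Model Streamer loader.
llm = LLM(model="openai-community/gpt2", load_format="runai_streamer")
outputs = llm.generate(["The capital of France is"],
                       SamplingParams(temperature=0.8, top_p=0.95))
print(outputs[0].outputs[0].text)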
1  setup.py
@@ -630,6 +630,7 @@ setup(
    ext_modules=ext_modules,
    extras_require={
        "tensorizer": ["tensorizer>=2.9.0"],
+        "runai": ["runai-model-streamer", "runai-model-streamer-s3", "boto3"],
        "audio": ["librosa", "soundfile"],  # Required for audio processing
        "video": ["decord"]  # Required for video processing
    },
0  tests/runai_model_streamer/__init__.py  Normal file
@@ -0,0 +1,31 @@
from vllm import SamplingParams
from vllm.config import LoadConfig, LoadFormat
from vllm.model_executor.model_loader.loader import (RunaiModelStreamerLoader,
                                                     get_model_loader)

test_model = "openai-community/gpt2"

prompts = [
    "Hello, my name is",
    "The president of the United States is",
    "The capital of France is",
    "The future of AI is",
]
# Create a sampling params object.
sampling_params = SamplingParams(temperature=0.8, top_p=0.95, seed=0)


def get_runai_model_loader():
    load_config = LoadConfig(load_format=LoadFormat.RUNAI_STREAMER)
    return get_model_loader(load_config)


def test_get_model_loader_with_runai_flag():
    model_loader = get_runai_model_loader()
    assert isinstance(model_loader, RunaiModelStreamerLoader)


def test_runai_model_loader_download_files(vllm_runner):
    with vllm_runner(test_model, load_format=LoadFormat.RUNAI_STREAMER) as llm:
        deserialized_outputs = llm.generate(prompts, sampling_params)
        assert deserialized_outputs
39  tests/runai_model_streamer/test_weight_utils.py  Normal file
@@ -0,0 +1,39 @@
import glob
import tempfile

import huggingface_hub.constants
import torch

from vllm.model_executor.model_loader.weight_utils import (
    download_weights_from_hf, runai_safetensors_weights_iterator,
    safetensors_weights_iterator)


def test_runai_model_loader():
    with tempfile.TemporaryDirectory() as tmpdir:
        huggingface_hub.constants.HF_HUB_OFFLINE = False
        download_weights_from_hf("openai-community/gpt2",
                                 allow_patterns=["*.safetensors"],
                                 cache_dir=tmpdir)
        safetensors = glob.glob(f"{tmpdir}/**/*.safetensors", recursive=True)
        assert len(safetensors) > 0

        runai_model_streamer_tensors = {}
        hf_safetensors_tensors = {}

        for name, tensor in runai_safetensors_weights_iterator(safetensors):
            runai_model_streamer_tensors[name] = tensor

        for name, tensor in safetensors_weights_iterator(safetensors):
            hf_safetensors_tensors[name] = tensor

        assert len(runai_model_streamer_tensors) == len(hf_safetensors_tensors)

        for name, runai_tensor in runai_model_streamer_tensors.items():
            assert runai_tensor.dtype == hf_safetensors_tensors[name].dtype
            assert runai_tensor.shape == hf_safetensors_tensors[name].shape
            assert torch.all(runai_tensor.eq(hf_safetensors_tensors[name]))


if __name__ == "__main__":
    test_runai_model_loader()
@@ -29,6 +29,7 @@ from vllm.transformers_utils.config import (
    get_hf_text_config, get_pooling_config,
    get_sentence_transformer_tokenizer_config, is_encoder_decoder,
    try_get_generation_config, uses_mrope)
+from vllm.transformers_utils.utils import is_s3
from vllm.utils import (GiB_bytes, LayerBlockType, cuda_device_count_stateless,
                        get_cpu_memory, print_warning_once, random_uuid,
                        resolve_obj_by_qualname)
@@ -256,6 +257,8 @@ class ModelConfig:
                f"'Please instead use `--hf-overrides '{hf_override!r}'`")
            warnings.warn(DeprecationWarning(msg), stacklevel=2)

+        self.maybe_pull_model_tokenizer_for_s3(model, tokenizer)
+
        # The tokenizer version is consistent with the model version by default.
        if tokenizer_revision is None:
            self.tokenizer_revision = revision
@@ -357,6 +360,39 @@
        self._verify_cuda_graph()
        self._verify_bnb_config()

    def maybe_pull_model_tokenizer_for_s3(self, model: str,
                                          tokenizer: str) -> None:
        """
        Pull the model config or tokenizer to a temporary
        directory in case of S3.

        Args:
            model: The model name or path.
            tokenizer: The tokenizer name or path.
        """
        if is_s3(model) or is_s3(tokenizer):
            try:
                from vllm.transformers_utils.s3_utils import S3Model
            except ImportError as err:
                raise ImportError(
                    "Please install Run:ai optional dependency "
                    "to use the S3 capabilities. "
                    "You can install it with: pip install vllm[runai]"
                ) from err

            if is_s3(model):
                self.s3_model = S3Model()
                self.s3_model.pull_files(model, allow_pattern=["*config.json"])
                self.model_weights = self.model
                self.model = self.s3_model.dir

            if is_s3(tokenizer):
                self.s3_tokenizer = S3Model()
                self.s3_tokenizer.pull_files(
                    tokenizer, ignore_pattern=["*.pt", "*.safetensors", "*.bin"])
                self.tokenizer = self.s3_tokenizer.dir

    def _init_multimodal_config(
            self, limit_mm_per_prompt: Optional[Mapping[str, int]]
    ) -> Optional["MultiModalConfig"]:
@@ -1099,6 +1135,7 @@ class LoadFormat(str, enum.Enum):
    GGUF = "gguf"
    BITSANDBYTES = "bitsandbytes"
    MISTRAL = "mistral"
+    RUNAI_STREAMER = "runai_streamer"


@dataclass
@@ -316,6 +316,8 @@ class EngineArgs:
            '* "tensorizer" will load the weights using tensorizer from '
            'CoreWeave. See the Tensorize vLLM Model script in the Examples '
            'section for more information.\n'
+            '* "runai_streamer" will load the Safetensors weights using Run:ai '
+            'Model Streamer.\n'
            '* "bitsandbytes" will load the weights using bitsandbytes '
            'quantization.\n')
        parser.add_argument(
@@ -45,9 +45,10 @@ from vllm.model_executor.model_loader.weight_utils import (
    filter_duplicate_safetensors_files, filter_files_not_needed_for_inference,
    get_gguf_extra_tensor_names, gguf_quant_weights_iterator,
    initialize_dummy_weights, np_cache_weights_iterator, pt_weights_iterator,
-    safetensors_weights_iterator)
+    runai_safetensors_weights_iterator, safetensors_weights_iterator)
from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
+from vllm.transformers_utils.utils import is_s3
from vllm.utils import is_pin_memory_available

@@ -1234,6 +1235,118 @@ class GGUFModelLoader(BaseModelLoader):
        return model


class RunaiModelStreamerLoader(BaseModelLoader):
    """
    Model loader that can load safetensors
    files from local FS or S3 bucket.
    """

    def __init__(self, load_config: LoadConfig):
        super().__init__(load_config)
        if load_config.model_loader_extra_config:
            extra_config = load_config.model_loader_extra_config

            if ("concurrency" in extra_config
                    and isinstance(extra_config.get("concurrency"), int)):
                os.environ["RUNAI_STREAMER_CONCURRENCY"] = str(
                    extra_config.get("concurrency"))

            if ("memory_limit" in extra_config
                    and isinstance(extra_config.get("memory_limit"), int)):
                os.environ["RUNAI_STREAMER_MEMORY_LIMIT"] = str(
                    extra_config.get("memory_limit"))

            runai_streamer_s3_endpoint = os.getenv(
                'RUNAI_STREAMER_S3_ENDPOINT')
            aws_endpoint_url = os.getenv('AWS_ENDPOINT_URL')
            if (runai_streamer_s3_endpoint is None
                    and aws_endpoint_url is not None):
                os.environ["RUNAI_STREAMER_S3_ENDPOINT"] = aws_endpoint_url

    def _prepare_weights(self, model_name_or_path: str,
                         revision: Optional[str]) -> List[str]:
        """Prepare weights for the model.

        If the model is not local, it will be downloaded."""
        is_s3_path = is_s3(model_name_or_path)
        if is_s3_path:
            try:
                from vllm.transformers_utils.s3_utils import glob as s3_glob
            except ImportError as err:
                raise ImportError(
                    "Please install Run:ai optional dependency "
                    "to use the S3 capabilities. "
                    "You can install it with: pip install vllm[runai]"
                ) from err

        is_local = os.path.isdir(model_name_or_path)
        safetensors_pattern = "*.safetensors"
        index_file = SAFE_WEIGHTS_INDEX_NAME

        hf_folder = (model_name_or_path if
                     (is_local or is_s3_path) else download_weights_from_hf(
                         model_name_or_path,
                         self.load_config.download_dir,
                         [safetensors_pattern],
                         revision,
                         ignore_patterns=self.load_config.ignore_patterns,
                     ))

        if is_s3_path:
            hf_weights_files = s3_glob(path=hf_folder,
                                       allow_pattern=[safetensors_pattern])
        else:
            hf_weights_files = glob.glob(
                os.path.join(hf_folder, safetensors_pattern))

        if not is_local and not is_s3_path:
            download_safetensors_index_file_from_hf(
                model_name_or_path, index_file, self.load_config.download_dir,
                revision)

        if not hf_weights_files:
            raise RuntimeError(
                f"Cannot find any safetensors model weights with "
                f"`{model_name_or_path}`")

        return hf_weights_files

    def _get_weights_iterator(
            self, model_or_path: str,
            revision: str) -> Generator[Tuple[str, torch.Tensor], None, None]:
        """Get an iterator for the model weights based on the load format."""
        hf_weights_files = self._prepare_weights(model_or_path, revision)
        return runai_safetensors_weights_iterator(hf_weights_files)

    def download_model(self, model_config: ModelConfig) -> None:
        """Download model if necessary."""
        self._prepare_weights(model_config.model, model_config.revision)

    def load_model(self, vllm_config: VllmConfig) -> nn.Module:
        """Perform streaming of the model to destination."""
        device_config = vllm_config.device_config
        model_config = vllm_config.model_config

        target_device = torch.device(device_config.device)
        with set_default_torch_dtype(model_config.dtype):
            with target_device:
                model = _initialize_model(vllm_config=vllm_config)

            model_weights = model_config.model
            if hasattr(model_config, "model_weights"):
                model_weights = model_config.model_weights
            model.load_weights(
                self._get_weights_iterator(model_weights,
                                           model_config.revision))

            for _, module in model.named_modules():
                quant_method = getattr(module, "quant_method", None)
                if quant_method is not None:
                    with device_loading_context(module, target_device):
                        quant_method.process_weights_after_loading(module)
        return model.eval()


def get_model_loader(load_config: LoadConfig) -> BaseModelLoader:
    """Get a model loader based on the load format."""
@@ -1255,4 +1368,7 @@ def get_model_loader(load_config: LoadConfig) -> BaseModelLoader:
    if load_config.load_format == LoadFormat.GGUF:
        return GGUFModelLoader(load_config)

+    if load_config.load_format == LoadFormat.RUNAI_STREAMER:
+        return RunaiModelStreamerLoader(load_config)
+
    return DefaultModelLoader(load_config)
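A minimal sketch of how the loader is selected and how `model_loader_extra_config` is exported to the streamer's environment variables; the concurrency value is illustrative:

import os

from vllm.config import LoadConfig, LoadFormat
from vllm.model_executor.model_loader.loader import (RunaiModelStreamerLoader,
                                                     get_model_loader)

# Selecting the runai_streamer load format returns the RunaiModelStreamerLoader.
load_config = LoadConfig(load_format=LoadFormat.RUNAI_STREAMER,
                         model_loader_extra_config={"concurrency": 16})
loader = get_model_loader(load_config)
assert isinstance(loader, RunaiModelStreamerLoader)

# The loader's __init__ exposes the tunable for the streamer library to pick up.
assert os.environ["RUNAI_STREAMER_CONCURRENCY"] == "16"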
@@ -410,6 +410,30 @@ def safetensors_weights_iterator(
        yield name, param


def runai_safetensors_weights_iterator(
    hf_weights_files: List[str]
) -> Generator[Tuple[str, torch.Tensor], None, None]:
    """Iterate over the weights in the model safetensor files."""
    try:
        from runai_model_streamer import SafetensorsStreamer
    except ImportError as err:
        raise ImportError(
            "Please install Run:ai optional dependency. "
            "You can install it with: pip install vllm[runai]") from err

    enable_tqdm = not torch.distributed.is_initialized(
    ) or torch.distributed.get_rank() == 0
    with SafetensorsStreamer() as streamer:
        for st_file in tqdm(
                hf_weights_files,
                desc="Loading safetensors using Runai Model Streamer",
                disable=not enable_tqdm,
                bar_format=_BAR_FORMAT,
        ):
            streamer.stream_file(st_file)
            yield from streamer.get_tensors()


def pt_weights_iterator(
    hf_weights_files: List[str]
) -> Generator[Tuple[str, torch.Tensor], None, None]:
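A minimal sketch of using the new iterator directly; the file path is hypothetical and the runai extra must be installed:

from vllm.model_executor.model_loader.weight_utils import (
    runai_safetensors_weights_iterator)

# Stream tensors from local (or s3://) safetensors files.
weight_files = ["/models/gpt2/model.safetensors"]  # illustrative path
for name, tensor in runai_safetensors_weights_iterator(weight_files):
    print(name, tuple(tensor.shape), tensor.dtype)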
146  vllm/transformers_utils/s3_utils.py  Normal file
@@ -0,0 +1,146 @@
import fnmatch
import os
import shutil
import signal
import tempfile
from pathlib import Path
from typing import Optional

import boto3


def _filter_allow(paths: list[str], patterns: list[str]) -> list[str]:
    return [
        path for path in paths if any(
            fnmatch.fnmatch(path, pattern) for pattern in patterns)
    ]


def _filter_ignore(paths: list[str], patterns: list[str]) -> list[str]:
    return [
        path for path in paths
        if not any(fnmatch.fnmatch(path, pattern) for pattern in patterns)
    ]


def glob(s3=None,
         path: str = "",
         allow_pattern: Optional[list[str]] = None) -> list[str]:
    """
    List full file names from S3 path and filter by allow pattern.

    Args:
        s3: S3 client to use.
        path: The S3 path to list from.
        allow_pattern: A list of patterns of which files to pull.

    Returns:
        list[str]: List of full S3 paths allowed by the pattern.
    """
    if s3 is None:
        s3 = boto3.client("s3")
    bucket_name, _, paths = list_files(s3,
                                       path=path,
                                       allow_pattern=allow_pattern)
    return [f"s3://{bucket_name}/{path}" for path in paths]


def list_files(
        s3,
        path: str,
        allow_pattern: Optional[list[str]] = None,
        ignore_pattern: Optional[list[str]] = None
) -> tuple[str, str, list[str]]:
    """
    List files from S3 path and filter by pattern.

    Args:
        s3: S3 client to use.
        path: The S3 path to list from.
        allow_pattern: A list of patterns of which files to pull.
        ignore_pattern: A list of patterns of which files not to pull.

    Returns:
        tuple[str, str, list[str]]: A tuple where:
            - The first element is the bucket name.
            - The second element is a string representing the bucket
              and the prefix as a dir-like string.
            - The third element is a list of files allowed or
              disallowed by pattern.
    """
    parts = path.removeprefix('s3://').split('/')
    prefix = '/'.join(parts[1:])
    bucket_name = parts[0]

    objects = s3.list_objects_v2(Bucket=bucket_name, Prefix=prefix)
    paths = [obj['Key'] for obj in objects.get('Contents', [])]

    paths = _filter_ignore(paths, ["*/"])
    if allow_pattern is not None:
        paths = _filter_allow(paths, allow_pattern)

    if ignore_pattern is not None:
        paths = _filter_ignore(paths, ignore_pattern)

    return bucket_name, prefix, paths


class S3Model:
    """
    A class representing an S3 model mirrored into a temporary directory.

    Attributes:
        s3: S3 client.
        dir: The temporary created directory.

    Methods:
        pull_files(): Pull model from S3 to the temporary directory.
    """

    def __init__(self) -> None:
        self.s3 = boto3.client('s3')
        for sig in (signal.SIGINT, signal.SIGTERM):
            existing_handler = signal.getsignal(sig)
            signal.signal(sig, self._close_by_signal(existing_handler))
        self.dir = tempfile.mkdtemp()

    def __del__(self):
        self._close()

    def _close(self) -> None:
        if os.path.exists(self.dir):
            shutil.rmtree(self.dir)

    def _close_by_signal(self, existing_handler=None):

        def new_handler(signum, frame):
            self._close()
            if existing_handler:
                existing_handler(signum, frame)

        return new_handler

    def pull_files(self,
                   s3_model_path: str = "",
                   allow_pattern: Optional[list[str]] = None,
                   ignore_pattern: Optional[list[str]] = None) -> None:
        """
        Pull files from S3 storage into the temporary directory.

        Args:
            s3_model_path: The S3 path of the model.
            allow_pattern: A list of patterns of which files to pull.
            ignore_pattern: A list of patterns of which files not to pull.
        """
        bucket_name, base_dir, files = list_files(self.s3, s3_model_path,
                                                  allow_pattern,
                                                  ignore_pattern)
        if len(files) == 0:
            return

        for file in files:
            destination_file = self.dir + file.removeprefix(base_dir)
            local_dir = Path(destination_file).parent
            os.makedirs(local_dir, exist_ok=True)
            self.s3.download_file(bucket_name, file, destination_file)
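A minimal sketch of the helpers above; the bucket is illustrative and boto3 credentials are assumed to be configured:

from vllm.transformers_utils.s3_utils import S3Model, glob as s3_glob

# List safetensors files under an S3 prefix.
weight_files = s3_glob(path="s3://core-llm/Llama-3-8b",
                       allow_pattern=["*.safetensors"])

# Mirror everything except the weights into a temporary local directory,
# the same way ModelConfig pulls config/tokenizer files for S3 models.
local_model = S3Model()
local_model.pull_files("s3://core-llm/Llama-3-8b",
                       ignore_pattern=["*.pt", "*.safetensors", "*.bin"])
print(local_model.dir)  # temporary directory holding the pulled files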
@@ -3,6 +3,10 @@ from pathlib import Path
from typing import Union


+def is_s3(model_or_path: str) -> bool:
+    return model_or_path.lower().startswith('s3://')
+
+
def check_gguf_file(model: Union[str, PathLike]) -> bool:
    """Check if the file is a GGUF model."""
    model = Path(model)
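A quick sketch of the helper's behavior; the paths are illustrative:

from vllm.transformers_utils.utils import is_s3

assert is_s3("s3://core-llm/Llama-3-8b")  # S3 URIs are detected (case-insensitive)
assert not is_s3("/home/meta-llama/Llama-3.2-3B-Instruct")  # local paths are not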