[Frontend] [Core] feat: Add model loading using tensorizer (#3476)
parent 989ae2538d
commit 711a000255
@@ -91,6 +91,9 @@ steps:
   command: pytest -v -s lora --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT
   parallelism: 4
 
+- label: Tensorizer Test
+  command: apt-get install curl libsodium23 && pytest -v -s tensorizer
+
 - label: Metrics Test
   command: pytest -v -s metrics
 
@@ -83,6 +83,7 @@ autodoc_mock_imports = [
     "vllm._C",
     "numpy",
     "tqdm",
+    "tensorizer",
 ]
 
 for mock_target in autodoc_mock_imports:
@@ -36,7 +36,7 @@ Below, you can find an explanation of every engine argument for vLLM:
 
     Directory to download and load the weights, default to the default cache dir of huggingface.
 
-.. option:: --load-format {auto,pt,safetensors,npcache,dummy}
+.. option:: --load-format {auto,pt,safetensors,npcache,dummy,tensorizer}
 
     The format of the model weights to load.
 
@@ -45,6 +45,7 @@ Below, you can find an explanation of every engine argument for vLLM:
     * "safetensors" will load the weights in the safetensors format.
     * "npcache" will load the weights in pytorch format and store a numpy cache to speed up the loading.
     * "dummy" will initialize the weights with random values, mainly for profiling.
+    * "tensorizer" will load serialized weights using `CoreWeave's Tensorizer model deserializer <https://github.com/coreweave/tensorizer>`_. See `tensorize_vllm_model.py` in the examples folder for more information on serializing a vLLM model. Tensorizer support for vLLM can be installed with `pip install vllm[tensorizer]`.
 
 .. option:: --dtype {auto,half,float16,bfloat16,float,float32}
 
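To make the new load format concrete, the following rough sketch (not part of the diff) passes the tensorizer-related engine arguments introduced by this change through LLM(...). The keyword names mirror the new EngineArgs fields added below, and the S3 path is a placeholder for a model already serialized with the example script further down.

from vllm import LLM, SamplingParams

# Illustrative sketch only: load a vLLM-serialized model via tensorizer.
llm = LLM(
    model="EleutherAI/gpt-j-6B",
    load_format="tensorizer",
    tensorizer_uri="s3://my-bucket/vllm/EleutherAI/gpt-j-6B/vllm/model.tensors",
    vllm_tensorized=True,  # the tensors were serialized from a vLLM model
    num_readers=1,
)
print(llm.generate(["Hello, my name is"], SamplingParams(temperature=0.8)))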
examples/tensorize_vllm_model.py (new file, 254 lines)
@@ -0,0 +1,254 @@
import argparse
import dataclasses
import os
import time
import uuid
from functools import partial
from typing import Type

import torch
import torch.nn as nn
from tensorizer import (DecryptionParams, EncryptionParams, TensorDeserializer,
                        TensorSerializer, stream_io)
from tensorizer.utils import convert_bytes, get_mem_usage, no_init_or_tensor
from transformers import AutoConfig, PretrainedConfig

from vllm.distributed import initialize_model_parallel
from vllm.engine.arg_utils import EngineArgs
from vllm.engine.llm_engine import LLMEngine
from vllm.model_executor.models import ModelRegistry
from vllm.model_executor.tensorizer_loader import TensorizerArgs

# yapf conflicts with isort for this docstring
# yapf: disable
"""
tensorize_vllm_model.py is a script that can be used to serialize and
deserialize vLLM models. These models can be loaded using tensorizer directly
to the GPU extremely quickly. Tensor encryption and decryption is also
supported, although libsodium must be installed to use it. Install
vllm with tensorizer support using `pip install vllm[tensorizer]`.

To serialize a model, you can run something like this:

python tensorize_vllm_model.py \
    --model EleutherAI/gpt-j-6B \
    --dtype float16 \
    serialize \
    --serialized-directory s3://my-bucket/ \
    --suffix vllm

Which downloads the model from HuggingFace, loads it into vLLM, serializes it,
and saves it to your S3 bucket. A local directory can also be used.

You can also encrypt the model weights with a randomly-generated key by
providing a `--keyfile` argument.

To deserialize a model, you can run something like this:

python tensorize_vllm_model.py \
    --model EleutherAI/gpt-j-6B \
    --dtype float16 \
    deserialize \
    --path-to-tensors s3://my-bucket/vllm/EleutherAI/gpt-j-6B/vllm/model.tensors

Which downloads the model tensors from your S3 bucket and deserializes them.
To provide S3 credentials, you can provide `--s3-access-key-id` and
`--s3-secret-access-key`, as well as `--s3-endpoint` as CLI args to this script,
the OpenAI entrypoint, as arguments for LLM(), or as environment variables
in the form of `S3_ACCESS_KEY_ID`, `S3_SECRET_ACCESS_KEY`, and `S3_ENDPOINT`.


You can also provide a `--keyfile` argument to decrypt the model weights if
they were serialized with encryption.

For more information on the available arguments, run
`python tensorize_vllm_model.py --help`.
"""


def parse_args():
    parser = argparse.ArgumentParser(
        description="An example script that can be used to serialize and "
        "deserialize vLLM models. These models "
        "can be loaded using tensorizer directly to the GPU "
        "extremely quickly. Tensor encryption and decryption is "
        "also supported, although libsodium must be installed to "
        "use it.")
    parser = EngineArgs.add_cli_args(parser)
    subparsers = parser.add_subparsers(dest='command')

    serialize_parser = subparsers.add_parser(
        'serialize', help="Serialize a model to `--serialized-directory`")

    serialize_parser.add_argument(
        "--suffix",
        type=str,
        required=False,
        help=(
            "The suffix to append to the serialized model directory, which is "
            "used to construct the location of the serialized model tensors, "
            "e.g. if `--serialized-directory` is `s3://my-bucket/` and "
            "`--suffix` is `v1`, the serialized model tensors will be "
            "saved to "
            "`s3://my-bucket/vllm/EleutherAI/gpt-j-6B/v1/model.tensors`. "
            "If none is provided, a random UUID will be used."))
    serialize_parser.add_argument(
        "--serialized-directory",
        type=str,
        required=True,
        help="The directory to serialize the model to. "
        "This can be a local directory or S3 URI. The path to where the "
        "tensors are saved is a combination of the supplied `dir` and model "
        "reference ID. For instance, if `dir` is the serialized directory, "
        "and the model HuggingFace ID is `EleutherAI/gpt-j-6B`, tensors will "
        "be saved to `dir/vllm/EleutherAI/gpt-j-6B/suffix/model.tensors`, "
        "where `suffix` is given by `--suffix` or a random UUID if not "
        "provided.")

    serialize_parser.add_argument(
        "--keyfile",
        type=str,
        required=False,
        help=("Encrypt the model weights with a randomly-generated binary key,"
              " and save the key at this path"))

    deserialize_parser = subparsers.add_parser(
        'deserialize',
        help=("Deserialize a model from `--path-to-tensors`"
              " to verify it can be loaded and used."))

    deserialize_parser.add_argument(
        "--path-to-tensors",
        type=str,
        required=True,
        help="The local path or S3 URI to the model tensors to deserialize. ")

    deserialize_parser.add_argument(
        "--keyfile",
        type=str,
        required=False,
        help=("Path to a binary key to use to decrypt the model weights,"
              " if the model was serialized with encryption"))

    return parser.parse_args()


def make_model_contiguous(model):
    # Ensure tensors are saved in memory contiguously
    for param in model.parameters():
        param.data = param.data.contiguous()


def _get_vllm_model_architecture(config: PretrainedConfig) -> Type[nn.Module]:
    architectures = getattr(config, "architectures", [])
    for arch in architectures:
        model_cls = ModelRegistry.load_model_cls(arch)
        if model_cls is not None:
            return model_cls
    raise ValueError(
        f"Model architectures {architectures} are not supported for now. "
        f"Supported architectures: {ModelRegistry.get_supported_archs()}")


def serialize():

    eng_args_dict = {f.name: getattr(args, f.name) for f in
                     dataclasses.fields(EngineArgs)}
    engine_args = EngineArgs.from_cli_args(argparse.Namespace(**eng_args_dict))
    engine = LLMEngine.from_engine_args(engine_args)

    model = (engine.model_executor.driver_worker.
             model_runner.model)

    encryption_params = EncryptionParams.random() if keyfile else None
    if keyfile:
        with _write_stream(keyfile) as stream:
            stream.write(encryption_params.key)

    with _write_stream(model_path) as stream:
        serializer = TensorSerializer(stream, encryption=encryption_params)
        serializer.write_module(model)
        serializer.close()

    print("Serialization complete. Model tensors saved to", model_path)
    if keyfile:
        print("Key saved to", keyfile)


def deserialize():
    config = AutoConfig.from_pretrained(model_ref)

    with no_init_or_tensor():
        model_class = _get_vllm_model_architecture(config)
        model = model_class(config)

    before_mem = get_mem_usage()
    start = time.time()

    if keyfile:
        with _read_stream(keyfile) as stream:
            key = stream.read()
            decryption_params = DecryptionParams.from_key(key)
            tensorizer_args.deserializer_params['encryption'] = \
                decryption_params

    with (_read_stream(model_path)) as stream, TensorDeserializer(
            stream, **tensorizer_args.deserializer_params) as deserializer:
        deserializer.load_into_module(model)
        end = time.time()

    # Brag about how fast we are.
    total_bytes_str = convert_bytes(deserializer.total_tensor_bytes)
    duration = end - start
    per_second = convert_bytes(deserializer.total_tensor_bytes / duration)
    after_mem = get_mem_usage()
    print(
        f"Deserialized {total_bytes_str} in {end - start:0.2f}s, {per_second}/s"
    )
    print(f"Memory usage before: {before_mem}")
    print(f"Memory usage after: {after_mem}")

    return model


args = parse_args()

s3_access_key_id = (args.s3_access_key_id or os.environ.get("S3_ACCESS_KEY_ID")
                    or None)
s3_secret_access_key = (args.s3_secret_access_key
                        or os.environ.get("S3_SECRET_ACCESS_KEY") or None)

s3_endpoint = (args.s3_endpoint or os.environ.get("S3_ENDPOINT_URL") or None)

_read_stream, _write_stream = (partial(
    stream_io.open_stream,
    mode=mode,
    s3_access_key_id=s3_access_key_id,
    s3_secret_access_key=s3_secret_access_key,
    s3_endpoint=s3_endpoint,
) for mode in ("rb", "wb+"))

model_ref = args.model

model_name = model_ref.split("/")[1]

os.environ["MASTER_ADDR"] = "127.0.0.1"
os.environ["MASTER_PORT"] = "8080"

torch.distributed.init_process_group(world_size=1, rank=0)
initialize_model_parallel()

keyfile = args.keyfile if args.keyfile else None

if args.command == "serialize":
    input_dir = args.serialized_directory.rstrip('/')
    suffix = args.suffix if args.suffix else uuid.uuid4().hex
    base_path = f"{input_dir}/vllm/{model_ref}/{suffix}"
    model_path = f"{base_path}/model.tensors"
    serialize()
elif args.command == "deserialize":
    tensorizer_args = TensorizerArgs.from_cli_args(args)
    model_path = args.path_to_tensors
    deserialize()
else:
    raise ValueError("Either serialize or deserialize must be specified.")
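If the model was serialized with --keyfile as described in the docstring above, the same key has to be supplied when the model is loaded again. A minimal sketch (not part of the diff; paths are placeholders) using the encryption_keyfile engine argument exercised by the tests added later in this commit:

from vllm import LLM

# Illustrative sketch only: load weights that were serialized with encryption.
llm = LLM(
    model="EleutherAI/gpt-j-6B",
    load_format="tensorizer",
    tensorizer_uri="/path/to/model.tensors",   # placeholder path
    encryption_keyfile="/path/to/model.key",   # key written by --keyfile
    vllm_tensorized=True,
)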
@@ -14,6 +14,7 @@ types-setuptools
 
 # testing
 pytest
+tensorizer==2.9.0a0
 pytest-forked
 pytest-asyncio
 pytest-rerunfailures
setup.py
@@ -405,6 +405,9 @@ setup(
     python_requires=">=3.8",
     install_requires=get_requirements(),
     ext_modules=ext_modules,
+    extras_require={
+        "optional": ["tensorizer==2.9.0a1"],
+    },
     cmdclass={"build_ext": cmake_build_ext} if not _is_neuron() else {},
     package_data=package_data,
 )
tests/tensorizer/__init__.py (new, empty file)
tests/tensorizer/tensorize_vllm_model_for_testing.py (new file, 245 lines)
@@ -0,0 +1,245 @@
import argparse
import dataclasses
import os
import time
import uuid
from functools import partial
from typing import Type

import torch
import torch.nn as nn
from tensorizer import (DecryptionParams, EncryptionParams, TensorDeserializer,
                        TensorSerializer, stream_io)
from tensorizer.utils import convert_bytes, get_mem_usage, no_init_or_tensor
from transformers import AutoConfig, PretrainedConfig

from vllm.distributed import initialize_model_parallel
from vllm.engine.arg_utils import EngineArgs
from vllm.engine.llm_engine import LLMEngine
from vllm.model_executor.models import ModelRegistry
from vllm.model_executor.tensorizer_loader import TensorizerArgs

# yapf conflicts with isort for this docstring
# yapf: disable
"""
tensorize_vllm_model.py is a script that can be used to serialize and
deserialize vLLM models. These models can be loaded using tensorizer directly
to the GPU extremely quickly. Tensor encryption and decryption is also
supported, although libsodium must be installed to use it. Install
vllm with tensorizer support using `pip install vllm[tensorizer]`.

To serialize a model, you can run something like this:

python tensorize_vllm_model.py \
    --model EleutherAI/gpt-j-6B \
    --dtype float16 \
    serialize \
    --serialized-directory s3://my-bucket/ \
    --suffix vllm

Which downloads the model from HuggingFace, loads it into vLLM, serializes it,
and saves it to your S3 bucket. A local directory can also be used.

You can also encrypt the model weights with a randomly-generated key by
providing a `--keyfile` argument.

To deserialize a model, you can run something like this:

python tensorize_vllm_model.py \
    --model EleutherAI/gpt-j-6B \
    --dtype float16 \
    deserialize \
    --path-to-tensors s3://my-bucket/vllm/EleutherAI/gpt-j-6B/vllm/model.tensors

Which downloads the model tensors from your S3 bucket and deserializes them.
To provide S3 credentials, you can provide `--s3-access-key-id` and
`--s3-secret-access-key`, as well as `--s3-endpoint` as CLI args to this script,
the OpenAI entrypoint, as arguments for LLM(), or as environment variables
in the form of `S3_ACCESS_KEY_ID`, `S3_SECRET_ACCESS_KEY`, and `S3_ENDPOINT`.


You can also provide a `--keyfile` argument to decrypt the model weights if
they were serialized with encryption.

For more information on the available arguments, run
`python tensorize_vllm_model.py --help`.
"""


def parse_args():
    parser = argparse.ArgumentParser(
        description="An example script that can be used to serialize and "
        "deserialize vLLM models. These models "
        "can be loaded using tensorizer directly to the GPU "
        "extremely quickly. Tensor encryption and decryption is "
        "also supported, although libsodium must be installed to "
        "use it.")
    parser = EngineArgs.add_cli_args(parser)
    subparsers = parser.add_subparsers(dest='command')

    serialize_parser = subparsers.add_parser(
        'serialize', help="Serialize a model to `--serialized-directory`")

    serialize_parser.add_argument(
        "--suffix",
        type=str,
        required=False,
        help=(
            "The suffix to append to the serialized model directory, which is "
            "used to construct the location of the serialized model tensors, "
            "e.g. if `--serialized-directory` is `s3://my-bucket/` and "
            "`--suffix` is `v1`, the serialized model tensors will be "
            "saved to "
            "`s3://my-bucket/vllm/EleutherAI/gpt-j-6B/v1/model.tensors`. "
            "If none is provided, a random UUID will be used."))
    serialize_parser.add_argument(
        "--serialized-directory",
        type=str,
        required=True)

    serialize_parser.add_argument(
        "--keyfile",
        type=str,
        required=False,
        help=("Encrypt the model weights with a randomly-generated binary key,"
              " and save the key at this path"))

    deserialize_parser = subparsers.add_parser(
        'deserialize',
        help=("Deserialize a model from `--path-to-tensors`"
              " to verify it can be loaded and used."))

    deserialize_parser.add_argument(
        "--path-to-tensors",
        type=str,
        required=True,
        help="The local path or S3 URI to the model tensors to deserialize. ")

    deserialize_parser.add_argument(
        "--keyfile",
        type=str,
        required=False,
        help=("Path to a binary key to use to decrypt the model weights,"
              " if the model was serialized with encryption"))

    return parser.parse_args()


def make_model_contiguous(model):
    # Ensure tensors are saved in memory contiguously
    for param in model.parameters():
        param.data = param.data.contiguous()


def _get_vllm_model_architecture(config: PretrainedConfig) -> Type[nn.Module]:
    architectures = getattr(config, "architectures", [])
    for arch in architectures:
        model_cls = ModelRegistry.load_model_cls(arch)
        if model_cls is not None:
            return model_cls
    raise ValueError(
        f"Model architectures {architectures} are not supported for now. "
        f"Supported architectures: {ModelRegistry.get_supported_archs()}")


def serialize():
    eng_args_dict = {f.name: getattr(args, f.name) for f in
                     dataclasses.fields(EngineArgs)}
    engine_args = EngineArgs.from_cli_args(argparse.Namespace(**eng_args_dict))
    engine = LLMEngine.from_engine_args(engine_args)

    model = (engine.model_executor.driver_worker.
             model_runner.model)

    encryption_params = EncryptionParams.random() if keyfile else None
    if keyfile:
        with _write_stream(keyfile) as stream:
            stream.write(encryption_params.key)

    with _write_stream(model_path) as stream:
        serializer = TensorSerializer(stream, encryption=encryption_params)
        serializer.write_module(model)
        serializer.close()

    print("Serialization complete. Model tensors saved to", model_path)
    if keyfile:
        print("Key saved to", keyfile)


def deserialize():
    config = AutoConfig.from_pretrained(model_ref)

    with no_init_or_tensor():
        model_class = _get_vllm_model_architecture(config)
        model = model_class(config)

    before_mem = get_mem_usage()
    start = time.time()

    if keyfile:
        with _read_stream(keyfile) as stream:
            key = stream.read()
            decryption_params = DecryptionParams.from_key(key)
            tensorizer_args.deserializer_params['encryption'] = \
                decryption_params

    with (_read_stream(model_path)) as stream, TensorDeserializer(
            stream, **tensorizer_args.deserializer_params) as deserializer:
        deserializer.load_into_module(model)
        end = time.time()

    # Brag about how fast we are.
    total_bytes_str = convert_bytes(deserializer.total_tensor_bytes)
    duration = end - start
    per_second = convert_bytes(deserializer.total_tensor_bytes / duration)
    after_mem = get_mem_usage()
    print(
        f"Deserialized {total_bytes_str} in {end - start:0.2f}s, {per_second}/s"
    )
    print(f"Memory usage before: {before_mem}")
    print(f"Memory usage after: {after_mem}")

    return model


args = parse_args()

s3_access_key_id = (args.s3_access_key_id or os.environ.get("S3_ACCESS_KEY_ID")
                    or None)
s3_secret_access_key = (args.s3_secret_access_key
                        or os.environ.get("S3_SECRET_ACCESS_KEY") or None)

s3_endpoint = (args.s3_endpoint or os.environ.get("S3_ENDPOINT_URL") or None)

_read_stream, _write_stream = (partial(
    stream_io.open_stream,
    mode=mode,
    s3_access_key_id=s3_access_key_id,
    s3_secret_access_key=s3_secret_access_key,
    s3_endpoint=s3_endpoint,
) for mode in ("rb", "wb+"))

model_ref = args.model

model_name = model_ref.split("/")[1]

os.environ["MASTER_ADDR"] = "127.0.0.1"
os.environ["MASTER_PORT"] = "8080"

torch.distributed.init_process_group(world_size=1, rank=0)
initialize_model_parallel()

keyfile = args.keyfile if args.keyfile else None

if args.command == "serialize":
    input_dir = args.serialized_directory.rstrip('/')
    suffix = args.suffix if args.suffix else uuid.uuid4().hex
    base_path = f"{input_dir}/vllm/{model_ref}/{suffix}"
    model_path = f"{base_path}/model.tensors"
    serialize()
elif args.command == "deserialize":
    tensorizer_args = TensorizerArgs.from_cli_args(args)
    model_path = args.path_to_tensors
    deserialize()
else:
    raise ValueError("Either serialize or deserialize must be specified.")
tests/tensorizer/test_tensorizer.py (new file, 302 lines)
@@ -0,0 +1,302 @@
import gc
import subprocess
from unittest.mock import MagicMock, patch

import pytest
import torch

from tests.entrypoints.test_openai_server import ServerRunner
from vllm import SamplingParams
from vllm.config import TensorizerConfig
from vllm.model_executor.tensorizer_loader import (
    EncryptionParams, TensorSerializer, is_vllm_serialized_tensorizer,
    load_with_tensorizer, open_stream)

prompts = [
    "Hello, my name is",
    "The president of the United States is",
    "The capital of France is",
    "The future of AI is",
]
# Create a sampling params object.
sampling_params = SamplingParams(temperature=0.8, top_p=0.95, seed=0)

model_ref = "facebook/opt-125m"


def is_curl_installed():
    try:
        subprocess.check_call(['curl', '--version'])
        return True
    except (subprocess.CalledProcessError, FileNotFoundError):
        return False


@pytest.fixture(autouse=True)
def tensorizer_config():
    config = TensorizerConfig(tensorizer_uri="vllm", vllm_tensorized=True)
    return config


@patch('vllm.model_executor.tensorizer_loader.TensorizerAgent')
def test_load_with_tensorizer(mock_agent, tensorizer_config):
    mock_linear_method = MagicMock()
    mock_agent_instance = mock_agent.return_value
    mock_agent_instance.deserialize.return_value = MagicMock()

    result = load_with_tensorizer(tensorizer_config,
                                  linear_method=mock_linear_method)

    mock_agent.assert_called_once_with(tensorizer_config,
                                       linear_method=mock_linear_method)
    mock_agent_instance.deserialize.assert_called_once()
    assert result == mock_agent_instance.deserialize.return_value


def test_is_vllm_model_with_vllm_in_uri(tensorizer_config):
    tensorizer_config.vllm_tensorized = True

    result = is_vllm_serialized_tensorizer(tensorizer_config)

    assert result is True


def test_is_vllm_model_without_vllm_in_uri(tensorizer_config):
    tensorizer_config.vllm_tensorized = False

    result = is_vllm_serialized_tensorizer(tensorizer_config)

    assert result is False


def test_deserialized_vllm_model_has_same_outputs(vllm_runner, tmp_path):
    vllm_model = vllm_runner(model_ref)
    model_path = tmp_path / (model_ref + ".tensors")
    outputs = vllm_model.generate(prompts, sampling_params)
    model = (vllm_model.model.llm_engine.model_executor.driver_worker.
             model_runner.model)
    with open_stream(model_path, "wb+") as stream:
        serializer = TensorSerializer(stream)
        serializer.write_module(model)
    del vllm_model, model
    gc.collect()
    torch.cuda.empty_cache()
    loaded_vllm_model = vllm_runner(model_ref,
                                    load_format="tensorizer",
                                    tensorizer_uri=model_path,
                                    num_readers=1,
                                    vllm_tensorized=True)
    deserialized_outputs = loaded_vllm_model.generate(prompts, sampling_params)

    # Assumes SamplingParams being seeded ensures the outputs are deterministic
    assert outputs == deserialized_outputs


@pytest.mark.skipif(not is_curl_installed(), reason="cURL is not installed")
def test_can_deserialize_s3(vllm_runner):
    model_ref = "EleutherAI/pythia-1.4b"
    tensorized_path = f"s3://tensorized/{model_ref}/fp16/model.tensors"

    loaded_hf_model = vllm_runner(
        model_ref,
        tensorizer_uri=tensorized_path,
        load_format="tensorizer",
        num_readers=1,
        vllm_tensorized=False,
        s3_endpoint="object.ord1.coreweave.com",
    )

    deserialized_outputs = loaded_hf_model.generate(prompts, sampling_params)

    assert deserialized_outputs


@pytest.mark.skipif(not is_curl_installed(), reason="cURL is not installed")
def test_deserialized_encrypted_vllm_model_has_same_outputs(
        vllm_runner, tmp_path):
    vllm_model = vllm_runner(model_ref)
    model_path = tmp_path / (model_ref + ".tensors")
    key_path = tmp_path / (model_ref + ".key")
    outputs = vllm_model.generate(prompts, sampling_params)
    model = (vllm_model.model.llm_engine.model_executor.driver_worker.
             model_runner.model)

    encryption_params = EncryptionParams.random()
    with open_stream(model_path, "wb+") as stream:
        serializer = TensorSerializer(stream, encryption=encryption_params)
        serializer.write_module(model)
    with open_stream(key_path, "wb+") as stream:
        stream.write(encryption_params.key)
    del vllm_model, model
    gc.collect()
    torch.cuda.empty_cache()
    loaded_vllm_model = vllm_runner(model_ref,
                                    tensorizer_uri=model_path,
                                    load_format="tensorizer",
                                    encryption_keyfile=key_path,
                                    num_readers=1,
                                    vllm_tensorized=True)

    deserialized_outputs = loaded_vllm_model.generate(prompts, sampling_params)

    # Assumes SamplingParams being seeded ensures the outputs are deterministic
    assert outputs == deserialized_outputs


def test_deserialized_hf_model_has_same_outputs(hf_runner, vllm_runner,
                                                tmp_path):
    hf_model = hf_runner(model_ref)
    model_path = tmp_path / (model_ref + ".tensors")
    max_tokens = 50
    outputs = hf_model.generate_greedy(prompts, max_tokens=max_tokens)
    with open_stream(model_path, "wb+") as stream:
        serializer = TensorSerializer(stream)
        serializer.write_module(hf_model.model)
    del hf_model
    gc.collect()
    torch.cuda.empty_cache()
    loaded_hf_model = vllm_runner(model_ref,
                                  tensorizer_uri=model_path,
                                  load_format="tensorizer",
                                  num_readers=1,
                                  vllm_tensorized=False)

    deserialized_outputs = loaded_hf_model.generate_greedy(
        prompts, max_tokens=max_tokens)

    assert outputs == deserialized_outputs


def test_vllm_model_can_load_with_lora(vllm_runner, tmp_path):
    from huggingface_hub import snapshot_download

    from examples.multilora_inference import (create_test_prompts,
                                              process_requests)

    model_ref = "meta-llama/Llama-2-7b-hf"
    lora_path = snapshot_download(repo_id="yard1/llama-2-7b-sql-lora-test")
    test_prompts = create_test_prompts(lora_path)

    # Serialize model before deserializing and binding LoRA adapters
    vllm_model = vllm_runner(model_ref, )
    model_path = tmp_path / (model_ref + ".tensors")
    model = (vllm_model.model.llm_engine.model_executor.driver_worker.
             model_runner.model)
    with open_stream(model_path, "wb+") as stream:
        serializer = TensorSerializer(stream)
        serializer.write_module(model)
    del vllm_model, model
    gc.collect()
    torch.cuda.empty_cache()
    loaded_vllm_model = vllm_runner(
        model_ref,
        tensorizer_uri=model_path,
        load_format="tensorizer",
        num_readers=1,
        vllm_tensorized=True,
        enable_lora=True,
        max_loras=1,
        max_lora_rank=8,
        max_cpu_loras=2,
        max_num_seqs=50,
        max_model_len=1000,
    )
    process_requests(loaded_vllm_model.model.llm_engine, test_prompts)

    assert loaded_vllm_model


def test_load_without_tensorizer_load_format(vllm_runner):
    with pytest.raises(ValueError):
        vllm_runner(model_ref, tensorizer_uri="test")


@pytest.mark.skipif(not is_curl_installed(), reason="cURL is not installed")
def test_tensorize_vllm_model(tmp_path):
    # Test serialize command
    serialize_args = [
        "python3", "tensorizer/tensorize_vllm_model_for_testing.py", "--model",
        model_ref, "--dtype", "float16", "serialize", "--serialized-directory",
        tmp_path, "--suffix", "tests"
    ]
    result = subprocess.run(serialize_args, capture_output=True, text=True)
    print(result.stdout)  # Print the output of the serialize command

    assert result.returncode == 0, (f"Serialize command failed with output:"
                                    f"\n{result.stdout}\n{result.stderr}")

    path_to_tensors = f"{tmp_path}/vllm/{model_ref}/tests/model.tensors"

    # Test deserialize command
    deserialize_args = [
        "python3", "tensorizer/tensorize_vllm_model_for_testing.py", "--model",
        model_ref, "--dtype", "float16", "deserialize", "--path-to-tensors",
        path_to_tensors
    ]
    result = subprocess.run(deserialize_args, capture_output=True, text=True)
    assert result.returncode == 0, (f"Deserialize command failed with output:"
                                    f"\n{result.stdout}\n{result.stderr}")


@pytest.mark.skipif(not is_curl_installed(), reason="cURL is not installed")
def test_openai_apiserver_with_tensorizer(tmp_path):
    ## Serialize model
    serialize_args = [
        "python3", "tensorizer/tensorize_vllm_model_for_testing.py", "--model",
        model_ref, "--dtype", "float16", "serialize", "--serialized-directory",
        tmp_path, "--suffix", "tests"
    ]
    result = subprocess.run(serialize_args, capture_output=True, text=True)
    print(result.stdout)  # Print the output of the serialize command

    assert result.returncode == 0, (f"Serialize command failed with output:"
                                    f"\n{result.stdout}\n{result.stderr}")

    path_to_tensors = f"{tmp_path}/vllm/{model_ref}/tests/model.tensors"

    ## Start OpenAI API server
    openai_args = [
        "--model", model_ref, "--dtype", "float16", "--load-format",
        "tensorizer", "--tensorizer-uri", path_to_tensors, "--vllm-tensorized",
        "--port", "8000"
    ]

    server = ServerRunner.remote(openai_args)

    print("Server ready.")
    assert server.ready.remote()


def test_raise_value_error_on_invalid_load_format(vllm_runner):
    with pytest.raises(ValueError):
        vllm_runner(model_ref,
                    load_format="safetensors",
                    tensorizer_uri="test")


def test_tensorizer_with_tp(vllm_runner):
    with pytest.raises(ValueError):
        model_ref = "EleutherAI/pythia-1.4b"
        tensorized_path = f"s3://tensorized/{model_ref}/fp16/model.tensors"

        vllm_runner(
            model_ref,
            tensorizer_uri=tensorized_path,
            load_format="tensorizer",
            num_readers=1,
            vllm_tensorized=False,
            s3_endpoint="object.ord1.coreweave.com",
            tensor_parallel_size=2,
        )


@pytest.mark.skipif(not is_curl_installed(), reason="cURL is not installed")
def test_tensorizer_warn_quant(tmp_path):
    model_ref = "LnL-AI/TinyLlama-1.1B-Chat-v1.0-GPTQ-4bit"
    serialize_args = [
        "python3", "tensorizer/tensorize_vllm_model_for_testing.py", "--model",
        model_ref, "--quantization", "gptq", "--tensorizer-uri", "test",
        "serialize", "--serialized-directory", tmp_path, "--suffix", "tests"
    ]
    result = subprocess.run(serialize_args, capture_output=True, text=True)
    assert 'PerformanceWarning' in result.stderr
@@ -1,6 +1,8 @@
 import enum
+import io
 import json
 import os
+import typing
 from dataclasses import dataclass, fields
 from typing import TYPE_CHECKING, ClassVar, List, Optional, Union
 
@@ -16,6 +18,8 @@ from vllm.utils import (get_cpu_memory, get_nvcc_cuda_version, is_cpu, is_hip,
 if TYPE_CHECKING:
     from ray.util.placement_group import PlacementGroup
 
+    from vllm.model_executor.tensorizer_loader import TensorizerArgs
+
 logger = init_logger(__name__)
 
 _GB = 1 << 30
@@ -139,13 +143,14 @@ class ModelConfig:
     def _verify_load_format(self) -> None:
         load_format = self.load_format.lower()
         supported_load_format = [
-            "auto", "pt", "safetensors", "npcache", "dummy"
+            "auto", "pt", "safetensors", "npcache", "dummy", "tensorizer"
         ]
         rocm_not_supported_load_format: List[str] = []
         if load_format not in supported_load_format:
             raise ValueError(
                 f"Unknown load format: {self.load_format}. Must be one of "
-                "'auto', 'pt', 'safetensors', 'npcache', or 'dummy'.")
+                "'auto', 'pt', 'safetensors', 'npcache', 'tensorizer', or "
+                "'dummy'.")
         if is_hip() and load_format in rocm_not_supported_load_format:
             rocm_supported_load_format = [
                 f for f in supported_load_format
@@ -882,6 +887,65 @@ class VisionLanguageConfig:
                 f"{[x.name for x in cls.ImageInputType]}.") from e
 
 
+@dataclass
+class TensorizerConfig:
+    tensorizer_uri: Union[io.BufferedIOBase, io.RawIOBase, typing.BinaryIO,
+                          str, bytes, os.PathLike, int]
+    vllm_tensorized: bool
+    verify_hash: Optional[bool] = False
+    num_readers: Optional[int] = 1
+    encryption_keyfile: Optional[str] = None
+    s3_access_key_id: Optional[str] = None
+    s3_secret_access_key: Optional[str] = None
+    s3_endpoint: Optional[str] = None
+    model_class: Optional[torch.nn.Module] = None
+    hf_config: Optional[PretrainedConfig] = None
+    dtype: Union[str, torch.dtype] = None
+
+    def _construct_tensorizer_args(self) -> "TensorizerArgs":
+        from vllm.model_executor.tensorizer_loader import TensorizerArgs
+        tensorizer_args = {
+            "tensorizer_uri": self.tensorizer_uri,
+            "vllm_tensorized": self.vllm_tensorized,
+            "verify_hash": self.verify_hash,
+            "num_readers": self.num_readers,
+            "encryption_keyfile": self.encryption_keyfile,
+            "s3_access_key_id": self.s3_access_key_id,
+            "s3_secret_access_key": self.s3_secret_access_key,
+            "s3_endpoint": self.s3_endpoint,
+        }
+        return TensorizerArgs(**tensorizer_args)
+
+    def verify_with_parallel_config(
+        self,
+        parallel_config: "ParallelConfig",
+    ) -> None:
+        if (parallel_config.tensor_parallel_size > 1
+                and self.tensorizer_uri is not None):
+            raise ValueError(
+                "Loading to multiple GPUs is not currently supported with "
+                "vLLM-serialized models. Please set tensor_parallel_size=1."
+                " or use a non-vLLM-serialized model, such as a "
+                "serialized Hugging Face `PretrainedModel`.")
+
+    def verify_with_model_config(self, model_config) -> None:
+        if (model_config.quantization is not None
+                and self.tensorizer_uri is not None):
+            from vllm.model_executor.tensorizer_loader import (
+                tensorizer_warning)
+            tensorizer_warning(
+                "Loading a model using Tensorizer with quantization on vLLM"
+                " is unstable and may lead to errors.")
+
+        if (model_config.load_format != "tensorizer"
+                and self.tensorizer_uri is not None):
+            raise ValueError(
+                "A tensorizer uri was passed for tensorizer loading, but the "
+                f"load format was set to {model_config.load_format}. "
+                "Please set the load format to 'tensorizer' to use "
+                f"tensorizer args.")
+
+
 _STR_DTYPE_TO_TORCH_DTYPE = {
     "half": torch.float16,
     "float16": torch.float16,
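As an aside (not part of the diff), TensorizerConfig is essentially a carrier object: the engine validates it and then converts it into TensorizerArgs for the model loader. A minimal sketch of that flow, with placeholder values:

from vllm.config import TensorizerConfig

# Illustrative sketch only; the URI is a placeholder.
tensorizer_config = TensorizerConfig(
    tensorizer_uri="s3://my-bucket/vllm/facebook/opt-125m/v1/model.tensors",
    vllm_tensorized=True,
    num_readers=1,
)
# EngineConfig.__post_init__ (next hunk) runs these checks before any weights are read:
#   tensorizer_config.verify_with_parallel_config(parallel_config)
#   tensorizer_config.verify_with_model_config(model_config)
tensorizer_args = tensorizer_config._construct_tensorizer_args()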
@@ -1029,6 +1093,7 @@ class EngineConfig:
     lora_config: Optional[LoRAConfig]
     vision_language_config: Optional[VisionLanguageConfig]
     speculative_config: Optional[SpeculativeConfig]
+    tensorizer_config: Optional[TensorizerConfig]
 
     def __post_init__(self):
         """Verify configs are valid & consistent with each other.
@@ -1036,6 +1101,11 @@ class EngineConfig:
         self.model_config.verify_with_parallel_config(self.parallel_config)
         self.cache_config.verify_with_parallel_config(self.parallel_config)
 
+        if self.tensorizer_config:
+            self.tensorizer_config.verify_with_parallel_config(
+                self.parallel_config)
+            self.tensorizer_config.verify_with_model_config(self.model_config)
+
         if self.lora_config:
             self.lora_config.verify_with_model_config(self.model_config)
             self.lora_config.verify_with_scheduler_config(
@@ -1,12 +1,15 @@
 import argparse
 import dataclasses
+import io
+import os
 from dataclasses import dataclass
-from typing import Optional
+from typing import BinaryIO, Optional, Union
 
 from vllm.config import (CacheConfig, DeviceConfig, EngineConfig, LoRAConfig,
                          ModelConfig, ParallelConfig, SchedulerConfig,
-                         SpeculativeConfig, TokenizerPoolConfig,
-                         VisionLanguageConfig)
+                         SpeculativeConfig, TensorizerConfig,
+                         TokenizerPoolConfig, VisionLanguageConfig)
+from vllm.model_executor.tensorizer_loader import TensorizerArgs
 from vllm.utils import str_to_int_tuple
 
 
@@ -58,12 +61,22 @@ class EngineArgs:
     num_gpu_blocks_override: Optional[int] = None
     num_lookahead_slots: int = 0
 
+    # Tensorizer configuration parameters
+    tensorizer_uri: Union[io.BufferedIOBase, io.RawIOBase, BinaryIO, str,
+                          bytes, os.PathLike, int] = None
+    vllm_tensorized: bool = False
+    verify_hash: Optional[bool] = False
+    num_readers: Optional[int] = 1
+    encryption_keyfile: Optional[str] = None
+    s3_access_key_id: Optional[str] = None
+    s3_secret_access_key: Optional[str] = None
+    s3_endpoint: Optional[str] = None
+
     # Related to Vision-language models such as llava
     image_input_type: Optional[str] = None
     image_token_id: Optional[int] = None
     image_input_shape: Optional[str] = None
     image_feature_size: Optional[int] = None
 
     scheduler_delay_factor: float = 0.0
     enable_chunked_prefill: bool = False
 
@@ -135,7 +148,9 @@ class EngineArgs:
             '--load-format',
             type=str,
             default=EngineArgs.load_format,
-            choices=['auto', 'pt', 'safetensors', 'npcache', 'dummy'],
+            choices=[
+                'auto', 'pt', 'safetensors', 'npcache', 'dummy', 'tensorizer'
+            ],
             help='The format of the model weights to load. '
             '"auto" will try to load the weights in the safetensors format '
             'and fall back to the pytorch bin format if safetensors format '
@@ -145,7 +160,10 @@ class EngineArgs:
             '"npcache" will load the weights in pytorch format and store '
             'a numpy cache to speed up the loading. '
             '"dummy" will initialize the weights with random values, '
-            'which is mainly for profiling.')
+            'which is mainly for profiling. '
+            '"tensorizer" will load the weights using tensorizer from '
+            'CoreWeave, which assumes tensorizer_uri is set to the location '
+            'of the serialized weights.')
         parser.add_argument(
             '--dtype',
             type=str,
@@ -403,6 +421,7 @@ class EngineArgs:
             default=None,
             help='The number of speculative tokens to sample from '
             'the draft model in speculative decoding')
+        parser = TensorizerArgs.add_cli_args(parser)
         return parser
 
     @classmethod
@@ -465,6 +484,17 @@ class EngineArgs:
             max_cpu_loras=self.max_cpu_loras if self.max_cpu_loras
             and self.max_cpu_loras > 0 else None) if self.enable_lora else None
 
+        tensorizer_config = TensorizerConfig(
+            tensorizer_uri=self.tensorizer_uri,
+            vllm_tensorized=self.vllm_tensorized,
+            verify_hash=self.verify_hash,
+            num_readers=self.num_readers,
+            encryption_keyfile=self.encryption_keyfile,
+            s3_access_key_id=self.s3_access_key_id,
+            s3_secret_access_key=self.s3_secret_access_key,
+            s3_endpoint=self.s3_endpoint,
+        )
+
         if self.image_input_type:
             if (not self.image_token_id or not self.image_input_shape
                     or not self.image_feature_size):
@@ -488,7 +518,8 @@ class EngineArgs:
             device_config=device_config,
             lora_config=lora_config,
             vision_language_config=vision_language_config,
-            speculative_config=speculative_config)
+            speculative_config=speculative_config,
+            tensorizer_config=tensorizer_config)
 
 
 @dataclass
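Because TensorizerArgs.add_cli_args is wired into the engine argument parser above, the same options surface on the command line. As a reference sketch (not part of the diff), the OpenAI-compatible server can be started with an argument list like the one used by the server test earlier in this commit; the tensors path is a placeholder:

# Illustrative sketch only: CLI flags for the OpenAI-compatible API server.
openai_args = [
    "--model", "facebook/opt-125m",
    "--dtype", "float16",
    "--load-format", "tensorizer",
    "--tensorizer-uri", "/path/to/model.tensors",  # placeholder path
    "--vllm-tensorized",
    "--port", "8000",
]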
@@ -6,7 +6,7 @@ from transformers import PreTrainedTokenizer
 import vllm
 from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig,
                          ParallelConfig, SchedulerConfig, SpeculativeConfig,
-                         VisionLanguageConfig)
+                         TensorizerConfig, VisionLanguageConfig)
 from vllm.core.scheduler import Scheduler, SchedulerOutputs
 from vllm.engine.arg_utils import EngineArgs
 from vllm.engine.metrics import StatLogger, Stats
@@ -74,6 +74,7 @@ class LLMEngine:
         lora_config: Optional[LoRAConfig],
         vision_language_config: Optional[VisionLanguageConfig],
         speculative_config: Optional[SpeculativeConfig],
+        tensorizer_config: Optional[TensorizerConfig],
         executor_class: Type[ExecutorBase],
         log_stats: bool,
         usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
@@ -110,6 +111,7 @@ class LLMEngine:
         self.scheduler_config = scheduler_config
         self.device_config = device_config
         self.speculative_config = speculative_config
+        self.tensorizer_config = tensorizer_config
         self.log_stats = log_stats
 
         self._init_tokenizer()
@@ -125,6 +127,7 @@ class LLMEngine:
             lora_config=lora_config,
             vision_language_config=vision_language_config,
             speculative_config=speculative_config,
+            tensorizer_config=tensorizer_config,
         )
 
         self._initialize_kv_caches()
@@ -264,6 +267,9 @@ class LLMEngine:
     def _verify_args(self) -> None:
         self.model_config.verify_with_parallel_config(self.parallel_config)
         self.cache_config.verify_with_parallel_config(self.parallel_config)
+        if self.tensorizer_config:
+            self.tensorizer_config.verify_with_parallel_config(
+                self.parallel_config)
         if self.lora_config:
             self.lora_config.verify_with_model_config(self.model_config)
             self.lora_config.verify_with_scheduler_config(
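For completeness, the plumbing above means a tensorizer-backed engine can also be built directly from EngineArgs, as the example script earlier in this commit does. A rough sketch with placeholder values (the tensorizer fields flow from EngineArgs into TensorizerConfig and then into LLMEngine):

from vllm.engine.arg_utils import EngineArgs
from vllm.engine.llm_engine import LLMEngine

# Illustrative sketch only; the path is a placeholder.
engine_args = EngineArgs(
    model="facebook/opt-125m",
    load_format="tensorizer",
    tensorizer_uri="/path/to/model.tensors",
    vllm_tensorized=True,
)
engine = LLMEngine.from_engine_args(engine_args)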
@@ -2,7 +2,7 @@ from typing import Dict, List, Optional, Tuple
 
 from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig,
                          ParallelConfig, SchedulerConfig, SpeculativeConfig,
-                         VisionLanguageConfig)
+                         TensorizerConfig, VisionLanguageConfig)
 from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase
 from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
@@ -15,17 +15,14 @@ logger = init_logger(__name__)
 
 class GPUExecutor(ExecutorBase):
 
-    def __init__(
-        self,
-        model_config: ModelConfig,
-        cache_config: CacheConfig,
-        parallel_config: ParallelConfig,
-        scheduler_config: SchedulerConfig,
-        device_config: DeviceConfig,
-        lora_config: Optional[LoRAConfig],
-        vision_language_config: Optional[VisionLanguageConfig],
-        speculative_config: Optional[SpeculativeConfig],
-    ) -> None:
+    def __init__(self, model_config: ModelConfig, cache_config: CacheConfig,
+                 parallel_config: ParallelConfig,
+                 scheduler_config: SchedulerConfig,
+                 device_config: DeviceConfig,
+                 lora_config: Optional[LoRAConfig],
+                 vision_language_config: Optional[VisionLanguageConfig],
+                 speculative_config: Optional[SpeculativeConfig],
+                 tensorizer_config: Optional[TensorizerConfig]) -> None:
         self.model_config = model_config
         self.cache_config = cache_config
         self.lora_config = lora_config
@@ -33,6 +30,7 @@ class GPUExecutor(ExecutorBase):
         self.scheduler_config = scheduler_config
         self.device_config = device_config
         self.vision_language_config = vision_language_config
+        self.tensorizer_config = tensorizer_config
 
         assert (not speculative_config
                 ), "Speculative decoding not yet supported for GPU backend"
@@ -61,6 +59,7 @@ class GPUExecutor(ExecutorBase):
             distributed_init_method=distributed_init_method,
             lora_config=self.lora_config,
             vision_language_config=self.vision_language_config,
+            tensorizer_config=self.tensorizer_config,
             is_driver_worker=True,
         )
         self.driver_worker.init_device()
@@ -7,7 +7,7 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
 
 from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig,
                          ParallelConfig, SchedulerConfig, SpeculativeConfig,
-                         VisionLanguageConfig)
+                         TensorizerConfig, VisionLanguageConfig)
 from vllm.engine.ray_utils import RayWorkerVllm, ray
 from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase
 from vllm.logger import init_logger
@@ -42,6 +42,7 @@ class RayGPUExecutor(ExecutorBase):
         lora_config: Optional[LoRAConfig],
         vision_language_config: Optional[VisionLanguageConfig],
         speculative_config: Optional[SpeculativeConfig],
+        tensorizer_config: Optional[TensorizerConfig],
     ) -> None:
         self.model_config = model_config
         self.cache_config = cache_config
@@ -50,6 +51,7 @@ class RayGPUExecutor(ExecutorBase):
         self.scheduler_config = scheduler_config
         self.device_config = device_config
         self.vision_language_config = vision_language_config
+        self.tensorizer_config = tensorizer_config
         assert (not speculative_config
                 ), "Speculative decoding not yet supported for RayGPU backend."
 
@@ -171,6 +173,7 @@ class RayGPUExecutor(ExecutorBase):
                 distributed_init_method=distributed_init_method,
                 lora_config=lora_config,
                 vision_language_config=vision_language_config,
+                tensorizer_config=self.tensorizer_config,
             ))
 
         # Initialize the driver worker with the Worker class.
@@ -187,6 +190,7 @@ class RayGPUExecutor(ExecutorBase):
             distributed_init_method=distributed_init_method,
             lora_config=self.lora_config,
             vision_language_config=self.vision_language_config,
+            tensorizer_config=self.tensorizer_config,
             is_driver_worker=True,
         )
 
@@ -3,11 +3,14 @@ import contextlib
 from typing import Tuple, Type
 
 import torch
-import torch.nn as nn
+from torch import nn
 
 from vllm.config import DeviceConfig, ModelConfig
 from vllm.model_executor.models import ModelRegistry
 from vllm.model_executor.models.llava import LlavaForConditionalGeneration
+from vllm.model_executor.tensorizer_loader import (
+    ParameterizedLoadFormat, is_vllm_serialized_tensorizer,
+    load_with_tensorizer)
 from vllm.model_executor.weight_utils import (get_quant_config,
                                               initialize_dummy_weights)
 
@@ -51,6 +54,7 @@ def get_model(model_config: ModelConfig, device_config: DeviceConfig,
               **kwargs) -> nn.Module:
     lora_config = kwargs.get("lora_config", None)
     vision_language_config = kwargs.get("vision_language_config", None)
+    tensorizer_config = kwargs.get("tensorizer_config", None)
     model_class = _get_model_architecture(model_config)[0]
 
     # Get the (maybe quantized) linear method.
@@ -71,33 +75,54 @@ def get_model(model_config: ModelConfig, device_config: DeviceConfig,
                     f"{model_config.dtype} is not supported for quantization "
                     f"method {model_config.quantization}. Supported dtypes: "
                     f"{supported_dtypes}")
 
             linear_method = quant_config.get_linear_method()
 
     with _set_default_torch_dtype(model_config.dtype):
         # Create a model instance.
         # The weights will be initialized as empty tensors.
+        extra_kwargs = {}
+        if hasattr(model_class, "supported_lora_modules"):
+            extra_kwargs["lora_config"] = lora_config
+        elif lora_config:
+            raise ValueError(
+                f"Model {model_class.__name__} does not support LoRA, "
+                "but LoRA is enabled. Support for this model may "
+                "be added in the future. If this is important to you, "
+                "please open an issue on github.")
+        elif model_class in _VISION_MODEL_CLASSES:
+            extra_kwargs["vision_language_config"] = vision_language_config
+
         with torch.device(device_config.device):
-            if hasattr(model_class, "supported_lora_modules"):
-                model = model_class(model_config.hf_config, linear_method,
-                                    lora_config)
-            elif lora_config:
-                raise ValueError(
-                    f"Model {model_class.__name__} does not support LoRA, "
-                    "but LoRA is enabled. Support for this model may "
-                    "be added in the future. If this is important to you, "
-                    "please open an issue on github.")
-            else:
-                if model_class not in _VISION_MODEL_CLASSES:
-                    model = model_class(model_config.hf_config, linear_method)
-                else:
-                    model = model_class(model_config.hf_config,
-                                        vision_language_config, linear_method)
+            if (model_config.load_format == "tensorizer"
+                    and is_vllm_serialized_tensorizer(tensorizer_config)):
+                extra_kwargs["linear_method"] = linear_method
+                tensorizer_config.model_class = model_class
+                tensorizer_config.hf_config = model_config.hf_config
+                tensorizer_config.dtype = model_config.dtype
+                model = load_with_tensorizer(tensorizer_config, **extra_kwargs)
+                return model.eval()
+            model = model_class(config=model_config.hf_config,
+                                linear_method=linear_method,
+                                **extra_kwargs)
         if model_config.load_format == "dummy":
             # NOTE(woosuk): For accurate performance evaluation, we assign
             # random values to the weights.
             initialize_dummy_weights(model)
         else:
             # Load the weights from the cached or downloaded files.
-            model.load_weights(model_config.model, model_config.download_dir,
-                               model_config.load_format, model_config.revision)
+            if model_config.load_format == "tensorizer":
+                # Provide a dynamic load format for `model.load_weights`
+                # to retain tensorizer args from CLI.
+                model_config.load_format = ParameterizedLoadFormat(
+                    model_config.load_format)
+                model_config.load_format.params = (
+                    tensorizer_config._construct_tensorizer_args())
+
+            model.load_weights(
+                model_config.model,
+                model_config.download_dir,
+                model_config.load_format,
+                model_config.revision,
+            )
     return model.eval()
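Note on the `ParameterizedLoadFormat(model_config.load_format)` trick above: the class (defined in the new tensorizer_loader.py further down) is a `str` subclass with a single `params` slot, so it still compares equal to "tensorizer" everywhere the load format string is checked, while also carrying the CLI-derived tensorizer arguments into `hf_model_weights_iterator`. A minimal, self-contained sketch of that behavior; the `params` payload here is a hypothetical stand-in for the real `TensorizerArgs`:

class ParameterizedLoadFormat(str):
    """A str that can also carry extra settings through string-typed APIs."""
    __slots__ = "params"


fmt = ParameterizedLoadFormat("tensorizer")
fmt.params = {"num_readers": 4}        # stand-in for TensorizerArgs
assert fmt == "tensorizer"             # still behaves like the plain string
assert fmt.params["num_readers"] == 4  # ...while smuggling extra state along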
319
vllm/model_executor/tensorizer_loader.py
Normal file
@@ -0,0 +1,319 @@
import argparse
import dataclasses
import io
import os
import time
import typing
import warnings
from dataclasses import dataclass
from typing import Optional, Union

import torch
from torch import nn

from vllm.config import TensorizerConfig
from vllm.logger import init_logger
from vllm.model_executor.layers.linear import LinearMethodBase
from vllm.model_executor.layers.vocab_parallel_embedding import (
    VocabParallelEmbedding)

tensorizer_load_fail = False

try:
    from tensorizer import (DecryptionParams, EncryptionParams,
                            TensorDeserializer, TensorSerializer)
    from tensorizer.stream_io import open_stream
    from tensorizer.utils import (convert_bytes, get_mem_usage,
                                  no_init_or_tensor)
except ImportError:
    tensorizer_load_fail = True

__all__ = [
    'EncryptionParams', 'DecryptionParams', 'TensorDeserializer',
    'TensorSerializer', 'open_stream', 'convert_bytes', 'get_mem_usage',
    'no_init_or_tensor'
]

logger = init_logger(__name__)


def load_with_tensorizer(tensorizer_config: TensorizerConfig,
                         **extra_kwargs) -> nn.Module:
    tensorizer = TensorizerAgent(tensorizer_config, **extra_kwargs)
    return tensorizer.deserialize()


def tensorizer_warning(message: str):
    return warnings.warn(message, category=PerformanceWarning, stacklevel=2)


def is_vllm_serialized_tensorizer(tensorizer_config: TensorizerConfig) -> bool:
    if tensorizer_config is None:
        return False
    return tensorizer_config.vllm_tensorized


class ParameterizedLoadFormat(str):
    __slots__ = "params"


class PerformanceWarning(UserWarning):

    def __str__(self):
        return (f"{super().__str__()}"
                " (set the VLLM_SILENCE_PERFORMANCE_WARNINGS"
                " environment variable to hide this)")


if (os.getenv("VLLM_SILENCE_PERFORMANCE_WARNINGS", "").lower()
        not in ("", "0", "n", "no", "off", "disable")):
    warnings.simplefilter("ignore", category=PerformanceWarning)


@dataclass
class TensorizerArgs:
    tensorizer_uri: Union[io.BufferedIOBase, io.RawIOBase, typing.BinaryIO,
                          str, bytes, os.PathLike, int]
    vllm_tensorized: bool
    verify_hash: Optional[bool] = False
    num_readers: Optional[int] = 1
    encryption_keyfile: Optional[str] = None
    s3_access_key_id: Optional[str] = None
    s3_secret_access_key: Optional[str] = None
    s3_endpoint: Optional[str] = None
    """
    Args for the TensorizerAgent class. These are used to configure the
    behavior of the TensorDeserializer when loading tensors from a serialized
    model.

    Args:
      tensorizer_uri: Path to serialized model tensors. Can be a local file
          path or a S3 URI.
      vllm_tensorized: If True, indicates that the serialized model is a
          vLLM model. This is used to determine the behavior of the
          TensorDeserializer when loading tensors from a serialized model.
          It is far faster to deserialize a vLLM model as it utilizes
          tensorizer's optimized GPU loading.
      verify_hash: If True, the hashes of each tensor will be verified against
          the hashes stored in the metadata. A `HashMismatchError` will be
          raised if any of the hashes do not match.
      num_readers: Controls how many threads are allowed to read concurrently
          from the source file. Default is 1. This greatly increases
          performance.
      encryption_keyfile: File path to a binary file containing a
          binary key to use for decryption. `None` (the default) means
          no decryption. See the example script in
          examples/tensorize_vllm_model.py.
      s3_access_key_id: The access key for the S3 bucket. Can also be set via
          the S3_ACCESS_KEY_ID environment variable.
      s3_secret_access_key: The secret access key for the S3 bucket. Can also
          be set via the S3_SECRET_ACCESS_KEY environment variable.
      s3_endpoint: The endpoint for the S3 bucket. Can also be set via the
          S3_ENDPOINT_URL environment variable.
    """

    def __post_init__(self):
        self.file_obj = self.tensorizer_uri
        self.s3_access_key_id = (self.s3_access_key_id
                                 or os.environ.get("S3_ACCESS_KEY_ID")) or None
        self.s3_secret_access_key = (
            self.s3_secret_access_key
            or os.environ.get("S3_SECRET_ACCESS_KEY")) or None
        self.s3_endpoint = (self.s3_endpoint
                            or os.environ.get("S3_ENDPOINT_URL")) or None
        self.stream_params = {
            "s3_access_key_id": self.s3_access_key_id,
            "s3_secret_access_key": self.s3_secret_access_key,
            "s3_endpoint": self.s3_endpoint,
        }

        # Omitting self.dtype and self.device as this behaves weirdly
        self.deserializer_params = {
            "verify_hash": self.verify_hash,
            "encryption": self.encryption_keyfile,
            "num_readers": self.num_readers
        }
        if self.encryption_keyfile:
            with open_stream(
                    self.encryption_keyfile,
                    **self.stream_params,
            ) as stream:
                key = stream.read()
                decryption_params = DecryptionParams.from_key(key)
                self.deserializer_params['encryption'] = decryption_params

    def add_cli_args(
            parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
        """Tensorizer CLI arguments"""

        # Create the argument group
        group = parser.add_argument_group(
            'tensorizer options',
            description=('Options for configuring the behavior of the'
                         ' tensorizer deserializer when '
                         '--load-format=tensorizer'))

        group.add_argument(
            "--tensorizer-uri",
            help="Path to serialized model tensors. Can be a local file path,"
            " or an HTTP(S) or S3 URI.",
        )
        group.add_argument(
            "--verify-hash",
            action="store_true",
            help="If enabled, the hashes of each tensor will be verified"
            " against the hashes stored in the file metadata. An exception"
            " will be raised if any of the hashes do not match.",
        )
        group.add_argument(
            "--encryption-keyfile",
            default=None,
            help="The file path to a binary file containing a binary key to "
            "use for decryption. Can be a file path or S3 network URI.")
        group.add_argument(
            "--num-readers",
            default=1,
            type=int,
            help="Controls how many threads are allowed to read concurrently "
            "from the source file.")
        group.add_argument(
            "--s3-access-key-id",
            default=None,
            help="The access key for the S3 bucket. Can also be set via the "
            "S3_ACCESS_KEY_ID environment variable.",
        )
        group.add_argument(
            "--s3-secret-access-key",
            default=None,
            help="The secret access key for the S3 bucket. Can also be set via "
            "the S3_SECRET_ACCESS_KEY environment variable.",
        )
        group.add_argument(
            "--s3-endpoint",
            default=None,
            help="The endpoint for the S3 bucket. Can also be set via the "
            "S3_ENDPOINT_URL environment variable.",
        )
        group.add_argument(
            "--vllm-tensorized",
            action="store_true",
            help="If enabled, indicates that the serialized model is a vLLM "
            "model. This is used to determine the behavior of the "
            "TensorDeserializer when loading tensors from a "
            "serialized model.")

        return parser

    @classmethod
    def from_cli_args(cls, args: argparse.Namespace) -> "TensorizerArgs":
        # Get the list of attributes of this dataclass.
        attrs = [attr.name for attr in dataclasses.fields(cls)]
        # Set the attributes from the parsed arguments.
        tensorizer_args = cls(**{
            attr: getattr(args, attr)
            for attr in attrs if hasattr(args, attr)
        })
        return tensorizer_args


class TensorizerAgent:
    """
    A class for performing tensorizer deserializations specifically for
    vLLM models using plaid_mode. Uses TensorizerArgs to configure the
    behavior of the TensorDeserializer when loading tensors from a serialized
    model. For deserializations of HuggingFace models, TensorDeserializer is
    instead used as an iterator directly in the func hf_model_weights_iterator
    in vllm/model_executor/weight_utils.py
    """

    def __init__(self, tensorizer_config: TensorizerConfig,
                 linear_method: LinearMethodBase, **extra_kwargs):
        self.tensorizer_config = tensorizer_config
        self.tensorizer_args = (
            self.tensorizer_config._construct_tensorizer_args())
        self.extra_kwargs = extra_kwargs
        if extra_kwargs.get("linear_method", None) is not None:
            self.linear_method = extra_kwargs["linear_method"]
        else:
            self.linear_method = linear_method
        self.model = self._init_model()

        if tensorizer_load_fail:
            raise ImportError(
                "Tensorizer is not installed. Please install tensorizer "
                "to use this feature with `pip install vllm[tensorizer]`.")

    def _init_model(self):
        model_args = self.tensorizer_config.hf_config
        model_args.torch_dtype = self.tensorizer_config.dtype
        with no_init_or_tensor():
            return self.tensorizer_config.model_class(
                config=model_args,
                linear_method=self.linear_method,
                **self.extra_kwargs)

    def _resize_lora_embeddings(self):
        """Modify LoRA embedding layers to use bigger tensors
        to allow for adapter added tokens."""
        for child in self.model.modules():
            if (isinstance(child, VocabParallelEmbedding)
                    and child.weight.shape[0] <
                    child.num_embeddings_per_partition):
                new_weight = torch.empty(child.num_embeddings_per_partition,
                                         child.embedding_dim,
                                         dtype=child.weight.dtype,
                                         device=child.weight.device)
                new_weight[:child.weight.shape[0]].copy_(child.weight.data)
                new_weight[child.weight.shape[0]:].fill_(0)
                child.weight.data = new_weight

    def _check_tensors_on_meta_device(self):
        for tensor in self.model.state_dict().values():
            if tensor.device.type == 'meta':
                raise ValueError(
                    "The serialized model contains tensors on the meta device,"
                    " indicating that some tensors were not loaded properly."
                    " Please check that the parameters of the model being"
                    " specified match that of the serialized model, such as"
                    " its quantization.")

    def deserialize(self):
        """
        Deserialize the model using the TensorDeserializer. This method is
        specifically for vLLM models using tensorizer's plaid_mode.

        The deserializer makes use of tensorizer_args.stream_params
        to configure the behavior of the stream when loading tensors from a
        serialized model. The deserializer_params are used to configure the
        behavior of the TensorDeserializer when loading tensors themselves.
        Documentation on these params can be found in TensorizerArgs

        Returns:
            nn.Module: The deserialized model.
        """
        before_mem = get_mem_usage()
        # Lazy load the tensors from S3 into the model.
        start = time.perf_counter()
        with open_stream(
                self.tensorizer_args.tensorizer_uri,
                mode="rb",
                **self.tensorizer_args.stream_params,
        ) as stream, TensorDeserializer(
                stream,
                dtype=self.tensorizer_config.dtype,
                **self.tensorizer_args.deserializer_params) as deserializer:
            deserializer.load_into_module(self.model)
            end = time.perf_counter()

        total_bytes_str = convert_bytes(deserializer.total_tensor_bytes)
        duration = end - start
        per_second = convert_bytes(deserializer.total_tensor_bytes / duration)
        after_mem = get_mem_usage()
        deserializer.close()
        logger.info(f"Deserialized {total_bytes_str} in "
                    f"{end - start:0.2f}s, {per_second}/s")
        logger.info(f"Memory usage before: {before_mem}")
        logger.info(f"Memory usage after: {after_mem}")

        self._check_tensors_on_meta_device()
        self._resize_lora_embeddings()
        return self.model.eval()
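For a concrete sense of how `TensorizerArgs.__post_init__` above splits its fields, the following sketch (local and unencrypted; the `.tensors` path is hypothetical) builds the dataclass directly and inspects the two dicts that later feed `open_stream` and `TensorDeserializer`:

from vllm.model_executor.tensorizer_loader import TensorizerArgs

args = TensorizerArgs(tensorizer_uri="opt-125m.tensors",  # hypothetical path
                      vllm_tensorized=True,
                      num_readers=4)
# S3 credentials fall back to the S3_ACCESS_KEY_ID / S3_SECRET_ACCESS_KEY /
# S3_ENDPOINT_URL environment variables when not passed explicitly.
print(args.stream_params)
# With no encryption keyfile, the deserializer settings stay as given:
# {'verify_hash': False, 'encryption': None, 'num_readers': 4}
print(args.deserializer_params)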
@@ -5,7 +5,7 @@ import hashlib
 import json
 import os
 from collections import defaultdict
-from typing import Any, Iterable, Iterator, List, Optional, Tuple
+from typing import Any, Iterable, Iterator, List, Optional, Tuple, Union
 
 import filelock
 import huggingface_hub.constants
@@ -161,7 +161,8 @@ def prepare_hf_model_weights(
     revision: Optional[str] = None,
 ) -> Tuple[str, List[str], bool]:
     # Download model weights from huggingface.
-    is_local = os.path.isdir(model_name_or_path)
+    is_local = os.path.isdir(model_name_or_path) \
+        and load_format != "tensorizer"
     use_safetensors = False
     # Some quantized models use .pt files for storing the weights.
     if load_format == "auto":
@@ -173,13 +174,15 @@ def prepare_hf_model_weights(
         allow_patterns = ["*.pt"]
     elif load_format == "npcache":
         allow_patterns = ["*.bin"]
+    elif load_format == "tensorizer":
+        allow_patterns = ["*.tensors"]
     else:
         raise ValueError(f"Unknown load_format: {load_format}")
 
     if fall_back_to_pt:
         allow_patterns += ["*.pt"]
 
-    if not is_local:
+    if not is_local and load_format != "tensorizer":
         # Before we download we look at that is available:
         fs = HfFileSystem()
         file_list = fs.ls(model_name_or_path, detail=False, revision=revision)
@@ -224,6 +227,9 @@ def prepare_hf_model_weights(
         if not any(f.endswith(x) for x in blacklist)
     ]
 
+    if load_format == "tensorizer":
+        return hf_folder, hf_weights_files, use_safetensors
+
     if len(hf_weights_files) == 0:
         raise RuntimeError(
             f"Cannot find any model weights with `{model_name_or_path}`")
@@ -234,7 +240,7 @@ def prepare_hf_model_weights(
 def hf_model_weights_iterator(
     model_name_or_path: str,
     cache_dir: Optional[str] = None,
-    load_format: str = "auto",
+    load_format: Union[Tuple, str] = "auto",
     revision: Optional[str] = None,
     fall_back_to_pt: Optional[bool] = True,
 ) -> Iterator[Tuple[str, torch.Tensor]]:
@@ -277,6 +283,26 @@ def hf_model_weights_iterator(
             with open(param_path, "rb") as f:
                 param = np.load(f)
             yield name, torch.from_numpy(param)
+    elif load_format == "tensorizer":
+        from vllm.model_executor.tensorizer_loader import (TensorDeserializer,
+                                                           open_stream,
+                                                           tensorizer_warning)
+        tensorizer_args = load_format.params
+        tensorizer_warning(
+            "Deserializing HuggingFace models is not optimized for "
+            "loading on vLLM, as tensorizer is forced to load to CPU. "
+            "Consider deserializing a vLLM model instead for faster "
+            "load times. See the examples/tensorize_vllm_model.py example "
+            "script for serializing vLLM models.")
+
+        deserializer_args = tensorizer_args.deserializer_params
+        stream_params = tensorizer_args.stream_params
+        stream = open_stream(tensorizer_args.tensorizer_uri, **stream_params)
+        with TensorDeserializer(stream, **deserializer_args,
+                                device="cpu") as state:
+            for name, param in state.items():
+                yield name, param
+            del state
     elif use_safetensors:
         for st_file in hf_weights_files:
             with safe_open(st_file, framework="pt") as f:
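The new `elif load_format == "tensorizer"` branch above is the non-vLLM (HuggingFace-serialized) path: it streams a `.tensors` file and yields every tensor from CPU memory, which is exactly why `tensorizer_warning` recommends serializing a vLLM model instead. A standalone sketch of that flow using the tensorizer library directly; the S3 URI is hypothetical:

from tensorizer import TensorDeserializer
from tensorizer.stream_io import open_stream

stream = open_stream("s3://my-bucket/opt-125m.tensors", mode="rb")
with TensorDeserializer(stream, device="cpu") as state:
    # state behaves like a mapping of parameter name -> torch.Tensor,
    # mirroring what hf_model_weights_iterator yields above.
    for name, param in state.items():
        print(name, tuple(param.shape), param.dtype)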
@@ -10,7 +10,8 @@ import torch.nn as nn
 from vllm.attention import (AttentionMetadata, AttentionMetadataPerStage,
                             get_attn_backend)
 from vllm.config import (DeviceConfig, LoRAConfig, ModelConfig, ParallelConfig,
-                         SchedulerConfig, VisionLanguageConfig)
+                         SchedulerConfig, TensorizerConfig,
+                         VisionLanguageConfig)
 from vllm.distributed import broadcast_tensor_dict, with_pynccl_for_all_reduce
 from vllm.distributed.device_communicators import (custom_all_reduce,
                                                    pynccl_utils)
@@ -111,11 +112,13 @@ class ModelRunner:
         kv_cache_dtype: Optional[str] = "auto",
         is_driver_worker: bool = False,
         vision_language_config: Optional[VisionLanguageConfig] = None,
+        tensorizer_config: Optional[TensorizerConfig] = None,
     ):
         self.model_config = model_config
         self.parallel_config = parallel_config
         self.scheduler_config = scheduler_config
         self.lora_config = lora_config
+        self.tensorizer_config = tensorizer_config
         self.is_driver_worker = is_driver_worker
 
         # model_config can be None in tests/samplers/test_sampler.py.
@@ -158,7 +161,9 @@ class ModelRunner:
             lora_config=self.lora_config,
             vision_language_config=self.vision_language_config,
             parallel_config=self.parallel_config,
-            scheduler_config=self.scheduler_config)
+            scheduler_config=self.scheduler_config,
+            tensorizer_config=self.tensorizer_config,
+        )
 
         self.model_memory_usage = m.consumed_memory
         logger.info(f"Loading model weights took "
@@ -7,7 +7,8 @@ import torch
 import torch.distributed
 
 from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig,
-                         ParallelConfig, SchedulerConfig, VisionLanguageConfig)
+                         ParallelConfig, SchedulerConfig, TensorizerConfig,
+                         VisionLanguageConfig)
 from vllm.distributed import (broadcast_tensor_dict,
                               ensure_model_parallel_initialized,
                               init_distributed_environment)
@@ -42,6 +43,7 @@ class Worker(WorkerBase):
         distributed_init_method: str,
         lora_config: Optional[LoRAConfig] = None,
         vision_language_config: Optional[VisionLanguageConfig] = None,
+        tensorizer_config: Optional[TensorizerConfig] = None,
         is_driver_worker: bool = False,
     ) -> None:
         self.model_config = model_config
@@ -53,6 +55,7 @@ class Worker(WorkerBase):
         self.rank = rank
         self.distributed_init_method = distributed_init_method
         self.lora_config = lora_config
+        self.tensorizer_config = tensorizer_config
         self.is_driver_worker = is_driver_worker
         if self.is_driver_worker:
             assert self.rank == 0, "The driver worker must have rank 0."
@@ -70,7 +73,9 @@ class Worker(WorkerBase):
             lora_config=self.lora_config,
             kv_cache_dtype=self.cache_config.cache_dtype,
             is_driver_worker=is_driver_worker,
-            vision_language_config=vision_language_config)
+            vision_language_config=vision_language_config,
+            tensorizer_config=tensorizer_config,
+        )
         # Uninitialized cache engine. Will be initialized by
         # initialize_cache.
         self.cache_engine = None
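Pulling the pieces together: once a model has been serialized with examples/tensorize_vllm_model.py, the plumbing added across the executor, worker and model runner above lets the engine load it with the tensorizer load format. A rough Python-level sketch of that usage, mirroring the example script; the tensors path is hypothetical, and the exact keyword plumbing through `LLM`/`EngineArgs` should be read as illustrative rather than definitive:

from vllm import LLM

# Assumes "opt-125m.tensors" was produced by examples/tensorize_vllm_model.py;
# an s3:// URI works the same way when the S3_* environment variables are set.
llm = LLM(model="facebook/opt-125m",
          load_format="tensorizer",
          tensorizer_uri="opt-125m.tensors",
          num_readers=2,
          vllm_tensorized=True)
print(llm.generate(["Hello, my name is"]))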