[Frontend] [Core] feat: Add model loading using tensorizer (#3476)

Sanger Steel 2024-04-13 20:13:01 -04:00 committed by GitHub
parent 989ae2538d
commit 711a000255
20 changed files with 1351 additions and 51 deletions

View File

@ -91,6 +91,9 @@ steps:
command: pytest -v -s lora --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT
parallelism: 4
- label: Tensorizer Test
command: apt-get install curl libsodium23 && pytest -v -s tensorizer
- label: Metrics Test
command: pytest -v -s metrics

View File

@ -83,6 +83,7 @@ autodoc_mock_imports = [
"vllm._C",
"numpy",
"tqdm",
"tensorizer",
]
for mock_target in autodoc_mock_imports:

View File

@ -36,7 +36,7 @@ Below, you can find an explanation of every engine argument for vLLM:
Directory to download and load the weights, default to the default cache dir of huggingface.
.. option:: --load-format {auto,pt,safetensors,npcache,dummy}
.. option:: --load-format {auto,pt,safetensors,npcache,dummy,tensorizer}
The format of the model weights to load.
@ -45,6 +45,7 @@ Below, you can find an explanation of every engine argument for vLLM:
* "safetensors" will load the weights in the safetensors format.
* "npcache" will load the weights in pytorch format and store a numpy cache to speed up the loading.
* "dummy" will initialize the weights with random values, mainly for profiling.
* "tensorizer" will load serialized weights using `CoreWeave's Tensorizer model deserializer. <https://github.com/coreweave/tensorizer>`_. See `tensorized_vllm_model.py` in the examples folder to serialize a vLLM model, and for more information. Tensorizer support for vLLM can be installed with `pip install vllm[tensorizer]`.
.. option:: --dtype {auto,half,float16,bfloat16,float,float32}

View File

@ -0,0 +1,254 @@
import argparse
import dataclasses
import os
import time
import uuid
from functools import partial
from typing import Type
import torch
import torch.nn as nn
from tensorizer import (DecryptionParams, EncryptionParams, TensorDeserializer,
TensorSerializer, stream_io)
from tensorizer.utils import convert_bytes, get_mem_usage, no_init_or_tensor
from transformers import AutoConfig, PretrainedConfig
from vllm.distributed import initialize_model_parallel
from vllm.engine.arg_utils import EngineArgs
from vllm.engine.llm_engine import LLMEngine
from vllm.model_executor.models import ModelRegistry
from vllm.model_executor.tensorizer_loader import TensorizerArgs
# yapf conflicts with isort for this docstring
# yapf: disable
"""
tensorize_vllm_model.py is a script that can be used to serialize and
deserialize vLLM models. These models can be loaded using tensorizer directly
to the GPU extremely quickly. Tensor encryption and decryption is also
supported, although libsodium must be installed to use it. Install
vllm with tensorizer support using `pip install vllm[tensorizer]`.
To serialize a model, you can run something like this:
python tensorize_vllm_model.py \
--model EleutherAI/gpt-j-6B \
--dtype float16 \
serialize \
--serialized-directory s3://my-bucket/ \
--suffix vllm
Which downloads the model from HuggingFace, loads it into vLLM, serializes it,
and saves it to your S3 bucket. A local directory can also be used.
You can also encrypt the model weights with a randomly-generated key by
providing a `--keyfile` argument.
To deserialize a model, you can run something like this:
python tensorize_vllm_model.py \
--model EleutherAI/gpt-j-6B \
--dtype float16 \
deserialize \
--path-to-tensors s3://my-bucket/vllm/EleutherAI/gpt-j-6B/vllm/model.tensors
Which downloads the model tensors from your S3 bucket and deserializes them.
To provide S3 credentials, you can provide `--s3-access-key-id` and
`--s3-secret-access-key`, as well as `--s3-endpoint` as CLI args to this script,
the OpenAI entrypoint, as arguments for LLM(), or as environment variables
in the form of `S3_ACCESS_KEY_ID`, `S3_SECRET_ACCESS_KEY`, and `S3_ENDPOINT`.
You can also provide a `--keyfile` argument to decrypt the model weights if
they were serialized with encryption.
For more information on the available arguments, run
`python tensorize_vllm_model.py --help`.
"""
def parse_args():
parser = argparse.ArgumentParser(
description="An example script that can be used to serialize and "
"deserialize vLLM models. These models "
"can be loaded using tensorizer directly to the GPU "
"extremely quickly. Tensor encryption and decryption is "
"also supported, although libsodium must be installed to "
"use it.")
parser = EngineArgs.add_cli_args(parser)
subparsers = parser.add_subparsers(dest='command')
serialize_parser = subparsers.add_parser(
'serialize', help="Serialize a model to `--serialized-directory`")
serialize_parser.add_argument(
"--suffix",
type=str,
required=False,
help=(
"The suffix to append to the serialized model directory, which is "
"used to construct the location of the serialized model tensors, "
"e.g. if `--serialized-directory` is `s3://my-bucket/` and "
"`--suffix` is `v1`, the serialized model tensors will be "
"saved to "
"`s3://my-bucket/vllm/EleutherAI/gpt-j-6B/v1/model.tensors`. "
"If none is provided, a random UUID will be used."))
serialize_parser.add_argument(
"--serialized-directory",
type=str,
required=True,
help="The directory to serialize the model to. "
"This can be a local directory or S3 URI. The path to where the "
"tensors are saved is a combination of the supplied `dir` and model "
"reference ID. For instance, if `dir` is the serialized directory, "
"and the model HuggingFace ID is `EleutherAI/gpt-j-6B`, tensors will "
"be saved to `dir/vllm/EleutherAI/gpt-j-6B/suffix/model.tensors`, "
"where `suffix` is given by `--suffix` or a random UUID if not "
"provided.")
serialize_parser.add_argument(
"--keyfile",
type=str,
required=False,
help=("Encrypt the model weights with a randomly-generated binary key,"
" and save the key at this path"))
deserialize_parser = subparsers.add_parser(
'deserialize',
help=("Deserialize a model from `--path-to-tensors`"
" to verify it can be loaded and used."))
deserialize_parser.add_argument(
"--path-to-tensors",
type=str,
required=True,
help="The local path or S3 URI to the model tensors to deserialize. ")
deserialize_parser.add_argument(
"--keyfile",
type=str,
required=False,
help=("Path to a binary key to use to decrypt the model weights,"
" if the model was serialized with encryption"))
return parser.parse_args()
def make_model_contiguous(model):
# Ensure tensors are saved in memory contiguously
for param in model.parameters():
param.data = param.data.contiguous()
def _get_vllm_model_architecture(config: PretrainedConfig) -> Type[nn.Module]:
architectures = getattr(config, "architectures", [])
for arch in architectures:
model_cls = ModelRegistry.load_model_cls(arch)
if model_cls is not None:
return model_cls
raise ValueError(
f"Model architectures {architectures} are not supported for now. "
f"Supported architectures: {ModelRegistry.get_supported_archs()}")
def serialize():
eng_args_dict = {f.name: getattr(args, f.name) for f in
dataclasses.fields(EngineArgs)}
engine_args = EngineArgs.from_cli_args(argparse.Namespace(**eng_args_dict))
engine = LLMEngine.from_engine_args(engine_args)
model = (engine.model_executor.driver_worker.
model_runner.model)
encryption_params = EncryptionParams.random() if keyfile else None
if keyfile:
with _write_stream(keyfile) as stream:
stream.write(encryption_params.key)
with _write_stream(model_path) as stream:
serializer = TensorSerializer(stream, encryption=encryption_params)
serializer.write_module(model)
serializer.close()
print("Serialization complete. Model tensors saved to", model_path)
if keyfile:
print("Key saved to", keyfile)
def deserialize():
config = AutoConfig.from_pretrained(model_ref)
with no_init_or_tensor():
model_class = _get_vllm_model_architecture(config)
model = model_class(config)
before_mem = get_mem_usage()
start = time.time()
if keyfile:
with _read_stream(keyfile) as stream:
key = stream.read()
decryption_params = DecryptionParams.from_key(key)
tensorizer_args.deserializer_params['encryption'] = \
decryption_params
with (_read_stream(model_path)) as stream, TensorDeserializer(
stream, **tensorizer_args.deserializer_params) as deserializer:
deserializer.load_into_module(model)
end = time.time()
# Brag about how fast we are.
total_bytes_str = convert_bytes(deserializer.total_tensor_bytes)
duration = end - start
per_second = convert_bytes(deserializer.total_tensor_bytes / duration)
after_mem = get_mem_usage()
print(
f"Deserialized {total_bytes_str} in {end - start:0.2f}s, {per_second}/s"
)
print(f"Memory usage before: {before_mem}")
print(f"Memory usage after: {after_mem}")
return model
args = parse_args()
s3_access_key_id = (args.s3_access_key_id or os.environ.get("S3_ACCESS_KEY_ID")
or None)
s3_secret_access_key = (args.s3_secret_access_key
or os.environ.get("S3_SECRET_ACCESS_KEY") or None)
s3_endpoint = (args.s3_endpoint or os.environ.get("S3_ENDPOINT_URL") or None)
_read_stream, _write_stream = (partial(
stream_io.open_stream,
mode=mode,
s3_access_key_id=s3_access_key_id,
s3_secret_access_key=s3_secret_access_key,
s3_endpoint=s3_endpoint,
) for mode in ("rb", "wb+"))
model_ref = args.model
model_name = model_ref.split("/")[1]
os.environ["MASTER_ADDR"] = "127.0.0.1"
os.environ["MASTER_PORT"] = "8080"
torch.distributed.init_process_group(world_size=1, rank=0)
initialize_model_parallel()
keyfile = args.keyfile if args.keyfile else None
if args.command == "serialize":
input_dir = args.serialized_directory.rstrip('/')
suffix = args.suffix if args.suffix else uuid.uuid4().hex
base_path = f"{input_dir}/vllm/{model_ref}/{suffix}"
model_path = f"{base_path}/model.tensors"
serialize()
elif args.command == "deserialize":
tensorizer_args = TensorizerArgs.from_cli_args(args)
model_path = args.path_to_tensors
deserialize()
else:
raise ValueError("Either serialize or deserialize must be specified.")

View File

@ -14,6 +14,7 @@ types-setuptools
# testing
pytest
tensorizer==2.9.0a0
pytest-forked
pytest-asyncio
pytest-rerunfailures

View File

@ -405,6 +405,9 @@ setup(
python_requires=">=3.8",
install_requires=get_requirements(),
ext_modules=ext_modules,
extras_require={
"optional": ["tensorizer==2.9.0a1"],
},
cmdclass={"build_ext": cmake_build_ext} if not _is_neuron() else {},
package_data=package_data,
)

View File

View File

@ -0,0 +1,245 @@
import argparse
import dataclasses
import os
import time
import uuid
from functools import partial
from typing import Type
import torch
import torch.nn as nn
from tensorizer import (DecryptionParams, EncryptionParams, TensorDeserializer,
TensorSerializer, stream_io)
from tensorizer.utils import convert_bytes, get_mem_usage, no_init_or_tensor
from transformers import AutoConfig, PretrainedConfig
from vllm.distributed import initialize_model_parallel
from vllm.engine.arg_utils import EngineArgs
from vllm.engine.llm_engine import LLMEngine
from vllm.model_executor.models import ModelRegistry
from vllm.model_executor.tensorizer_loader import TensorizerArgs
# yapf conflicts with isort for this docstring
# yapf: disable
"""
tensorize_vllm_model.py is a script that can be used to serialize and
deserialize vLLM models. These models can be loaded using tensorizer directly
to the GPU extremely quickly. Tensor encryption and decryption is also
supported, although libsodium must be installed to use it. Install
vllm with tensorizer support using `pip install vllm[tensorizer]`.
To serialize a model, you can run something like this:
python tensorize_vllm_model.py \
--model EleutherAI/gpt-j-6B \
--dtype float16 \
serialize \
--serialized-directory s3://my-bucket/ \
--suffix vllm
Which downloads the model from HuggingFace, loads it into vLLM, serializes it,
and saves it to your S3 bucket. A local directory can also be used.
You can also encrypt the model weights with a randomly-generated key by
providing a `--keyfile` argument.
To deserialize a model, you can run something like this:
python tensorize_vllm_model.py \
--model EleutherAI/gpt-j-6B \
--dtype float16 \
deserialize \
--path-to-tensors s3://my-bucket/vllm/EleutherAI/gpt-j-6B/vllm/model.tensors
Which downloads the model tensors from your S3 bucket and deserializes them.
To provide S3 credentials, you can provide `--s3-access-key-id` and
`--s3-secret-access-key`, as well as `--s3-endpoint` as CLI args to this script,
the OpenAI entrypoint, as arguments for LLM(), or as environment variables
in the form of `S3_ACCESS_KEY_ID`, `S3_SECRET_ACCESS_KEY`, and `S3_ENDPOINT`.
You can also provide a `--keyfile` argument to decrypt the model weights if
they were serialized with encryption.
For more information on the available arguments, run
`python tensorize_vllm_model.py --help`.
"""
def parse_args():
parser = argparse.ArgumentParser(
description="An example script that can be used to serialize and "
"deserialize vLLM models. These models "
"can be loaded using tensorizer directly to the GPU "
"extremely quickly. Tensor encryption and decryption is "
"also supported, although libsodium must be installed to "
"use it.")
parser = EngineArgs.add_cli_args(parser)
subparsers = parser.add_subparsers(dest='command')
serialize_parser = subparsers.add_parser(
'serialize', help="Serialize a model to `--serialized-directory`")
serialize_parser.add_argument(
"--suffix",
type=str,
required=False,
help=(
"The suffix to append to the serialized model directory, which is "
"used to construct the location of the serialized model tensors, "
"e.g. if `--serialized-directory` is `s3://my-bucket/` and "
"`--suffix` is `v1`, the serialized model tensors will be "
"saved to "
"`s3://my-bucket/vllm/EleutherAI/gpt-j-6B/v1/model.tensors`. "
"If none is provided, a random UUID will be used."))
serialize_parser.add_argument(
"--serialized-directory",
type=str,
required=True)
serialize_parser.add_argument(
"--keyfile",
type=str,
required=False,
help=("Encrypt the model weights with a randomly-generated binary key,"
" and save the key at this path"))
deserialize_parser = subparsers.add_parser(
'deserialize',
help=("Deserialize a model from `--path-to-tensors`"
" to verify it can be loaded and used."))
deserialize_parser.add_argument(
"--path-to-tensors",
type=str,
required=True,
help="The local path or S3 URI to the model tensors to deserialize. ")
deserialize_parser.add_argument(
"--keyfile",
type=str,
required=False,
help=("Path to a binary key to use to decrypt the model weights,"
" if the model was serialized with encryption"))
return parser.parse_args()
def make_model_contiguous(model):
# Ensure tensors are saved in memory contiguously
for param in model.parameters():
param.data = param.data.contiguous()
def _get_vllm_model_architecture(config: PretrainedConfig) -> Type[nn.Module]:
architectures = getattr(config, "architectures", [])
for arch in architectures:
model_cls = ModelRegistry.load_model_cls(arch)
if model_cls is not None:
return model_cls
raise ValueError(
f"Model architectures {architectures} are not supported for now. "
f"Supported architectures: {ModelRegistry.get_supported_archs()}")
def serialize():
eng_args_dict = {f.name: getattr(args, f.name) for f in
dataclasses.fields(EngineArgs)}
engine_args = EngineArgs.from_cli_args(argparse.Namespace(**eng_args_dict))
engine = LLMEngine.from_engine_args(engine_args)
model = (engine.model_executor.driver_worker.
model_runner.model)
encryption_params = EncryptionParams.random() if keyfile else None
if keyfile:
with _write_stream(keyfile) as stream:
stream.write(encryption_params.key)
with _write_stream(model_path) as stream:
serializer = TensorSerializer(stream, encryption=encryption_params)
serializer.write_module(model)
serializer.close()
print("Serialization complete. Model tensors saved to", model_path)
if keyfile:
print("Key saved to", keyfile)
def deserialize():
config = AutoConfig.from_pretrained(model_ref)
with no_init_or_tensor():
model_class = _get_vllm_model_architecture(config)
model = model_class(config)
before_mem = get_mem_usage()
start = time.time()
if keyfile:
with _read_stream(keyfile) as stream:
key = stream.read()
decryption_params = DecryptionParams.from_key(key)
tensorizer_args.deserializer_params['encryption'] = \
decryption_params
with (_read_stream(model_path)) as stream, TensorDeserializer(
stream, **tensorizer_args.deserializer_params) as deserializer:
deserializer.load_into_module(model)
end = time.time()
# Brag about how fast we are.
total_bytes_str = convert_bytes(deserializer.total_tensor_bytes)
duration = end - start
per_second = convert_bytes(deserializer.total_tensor_bytes / duration)
after_mem = get_mem_usage()
print(
f"Deserialized {total_bytes_str} in {end - start:0.2f}s, {per_second}/s"
)
print(f"Memory usage before: {before_mem}")
print(f"Memory usage after: {after_mem}")
return model
args = parse_args()
s3_access_key_id = (args.s3_access_key_id or os.environ.get("S3_ACCESS_KEY_ID")
or None)
s3_secret_access_key = (args.s3_secret_access_key
or os.environ.get("S3_SECRET_ACCESS_KEY") or None)
s3_endpoint = (args.s3_endpoint or os.environ.get("S3_ENDPOINT_URL") or None)
_read_stream, _write_stream = (partial(
stream_io.open_stream,
mode=mode,
s3_access_key_id=s3_access_key_id,
s3_secret_access_key=s3_secret_access_key,
s3_endpoint=s3_endpoint,
) for mode in ("rb", "wb+"))
model_ref = args.model
model_name = model_ref.split("/")[1]
os.environ["MASTER_ADDR"] = "127.0.0.1"
os.environ["MASTER_PORT"] = "8080"
torch.distributed.init_process_group(world_size=1, rank=0)
initialize_model_parallel()
keyfile = args.keyfile if args.keyfile else None
if args.command == "serialize":
input_dir = args.serialized_directory.rstrip('/')
suffix = args.suffix if args.suffix else uuid.uuid4().hex
base_path = f"{input_dir}/vllm/{model_ref}/{suffix}"
model_path = f"{base_path}/model.tensors"
serialize()
elif args.command == "deserialize":
tensorizer_args = TensorizerArgs.from_cli_args(args)
model_path = args.path_to_tensors
deserialize()
else:
raise ValueError("Either serialize or deserialize must be specified.")

View File

@ -0,0 +1,302 @@
import gc
import subprocess
from unittest.mock import MagicMock, patch
import pytest
import torch
from tests.entrypoints.test_openai_server import ServerRunner
from vllm import SamplingParams
from vllm.config import TensorizerConfig
from vllm.model_executor.tensorizer_loader import (
EncryptionParams, TensorSerializer, is_vllm_serialized_tensorizer,
load_with_tensorizer, open_stream)
prompts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
]
# Create a sampling params object.
sampling_params = SamplingParams(temperature=0.8, top_p=0.95, seed=0)
model_ref = "facebook/opt-125m"
def is_curl_installed():
try:
subprocess.check_call(['curl', '--version'])
return True
except (subprocess.CalledProcessError, FileNotFoundError):
return False
@pytest.fixture(autouse=True)
def tensorizer_config():
config = TensorizerConfig(tensorizer_uri="vllm", vllm_tensorized=True)
return config
@patch('vllm.model_executor.tensorizer_loader.TensorizerAgent')
def test_load_with_tensorizer(mock_agent, tensorizer_config):
mock_linear_method = MagicMock()
mock_agent_instance = mock_agent.return_value
mock_agent_instance.deserialize.return_value = MagicMock()
result = load_with_tensorizer(tensorizer_config,
linear_method=mock_linear_method)
mock_agent.assert_called_once_with(tensorizer_config,
linear_method=mock_linear_method)
mock_agent_instance.deserialize.assert_called_once()
assert result == mock_agent_instance.deserialize.return_value
def test_is_vllm_model_with_vllm_in_uri(tensorizer_config):
tensorizer_config.vllm_tensorized = True
result = is_vllm_serialized_tensorizer(tensorizer_config)
assert result is True
def test_is_vllm_model_without_vllm_in_uri(tensorizer_config):
tensorizer_config.vllm_tensorized = False
result = is_vllm_serialized_tensorizer(tensorizer_config)
assert result is False
def test_deserialized_vllm_model_has_same_outputs(vllm_runner, tmp_path):
vllm_model = vllm_runner(model_ref)
model_path = tmp_path / (model_ref + ".tensors")
outputs = vllm_model.generate(prompts, sampling_params)
model = (vllm_model.model.llm_engine.model_executor.driver_worker.
model_runner.model)
with open_stream(model_path, "wb+") as stream:
serializer = TensorSerializer(stream)
serializer.write_module(model)
del vllm_model, model
gc.collect()
torch.cuda.empty_cache()
loaded_vllm_model = vllm_runner(model_ref,
load_format="tensorizer",
tensorizer_uri=model_path,
num_readers=1,
vllm_tensorized=True)
deserialized_outputs = loaded_vllm_model.generate(prompts, sampling_params)
# Assumes SamplingParams being seeded ensures the outputs are deterministic
assert outputs == deserialized_outputs
@pytest.mark.skipif(not is_curl_installed(), reason="cURL is not installed")
def test_can_deserialize_s3(vllm_runner):
model_ref = "EleutherAI/pythia-1.4b"
tensorized_path = f"s3://tensorized/{model_ref}/fp16/model.tensors"
loaded_hf_model = vllm_runner(
model_ref,
tensorizer_uri=tensorized_path,
load_format="tensorizer",
num_readers=1,
vllm_tensorized=False,
s3_endpoint="object.ord1.coreweave.com",
)
deserialized_outputs = loaded_hf_model.generate(prompts, sampling_params)
assert deserialized_outputs
@pytest.mark.skipif(not is_curl_installed(), reason="cURL is not installed")
def test_deserialized_encrypted_vllm_model_has_same_outputs(
vllm_runner, tmp_path):
vllm_model = vllm_runner(model_ref)
model_path = tmp_path / (model_ref + ".tensors")
key_path = tmp_path / (model_ref + ".key")
outputs = vllm_model.generate(prompts, sampling_params)
model = (vllm_model.model.llm_engine.model_executor.driver_worker.
model_runner.model)
encryption_params = EncryptionParams.random()
with open_stream(model_path, "wb+") as stream:
serializer = TensorSerializer(stream, encryption=encryption_params)
serializer.write_module(model)
with open_stream(key_path, "wb+") as stream:
stream.write(encryption_params.key)
del vllm_model, model
gc.collect()
torch.cuda.empty_cache()
loaded_vllm_model = vllm_runner(model_ref,
tensorizer_uri=model_path,
load_format="tensorizer",
encryption_keyfile=key_path,
num_readers=1,
vllm_tensorized=True)
deserialized_outputs = loaded_vllm_model.generate(prompts, sampling_params)
# Assumes SamplingParams being seeded ensures the outputs are deterministic
assert outputs == deserialized_outputs
def test_deserialized_hf_model_has_same_outputs(hf_runner, vllm_runner,
tmp_path):
hf_model = hf_runner(model_ref)
model_path = tmp_path / (model_ref + ".tensors")
max_tokens = 50
outputs = hf_model.generate_greedy(prompts, max_tokens=max_tokens)
with open_stream(model_path, "wb+") as stream:
serializer = TensorSerializer(stream)
serializer.write_module(hf_model.model)
del hf_model
gc.collect()
torch.cuda.empty_cache()
loaded_hf_model = vllm_runner(model_ref,
tensorizer_uri=model_path,
load_format="tensorizer",
num_readers=1,
vllm_tensorized=False)
deserialized_outputs = loaded_hf_model.generate_greedy(
prompts, max_tokens=max_tokens)
assert outputs == deserialized_outputs
def test_vllm_model_can_load_with_lora(vllm_runner, tmp_path):
from huggingface_hub import snapshot_download
from examples.multilora_inference import (create_test_prompts,
process_requests)
model_ref = "meta-llama/Llama-2-7b-hf"
lora_path = snapshot_download(repo_id="yard1/llama-2-7b-sql-lora-test")
test_prompts = create_test_prompts(lora_path)
# Serialize model before deserializing and binding LoRA adapters
vllm_model = vllm_runner(model_ref, )
model_path = tmp_path / (model_ref + ".tensors")
model = (vllm_model.model.llm_engine.model_executor.driver_worker.
model_runner.model)
with open_stream(model_path, "wb+") as stream:
serializer = TensorSerializer(stream)
serializer.write_module(model)
del vllm_model, model
gc.collect()
torch.cuda.empty_cache()
loaded_vllm_model = vllm_runner(
model_ref,
tensorizer_uri=model_path,
load_format="tensorizer",
num_readers=1,
vllm_tensorized=True,
enable_lora=True,
max_loras=1,
max_lora_rank=8,
max_cpu_loras=2,
max_num_seqs=50,
max_model_len=1000,
)
process_requests(loaded_vllm_model.model.llm_engine, test_prompts)
assert loaded_vllm_model
def test_load_without_tensorizer_load_format(vllm_runner):
with pytest.raises(ValueError):
vllm_runner(model_ref, tensorizer_uri="test")
@pytest.mark.skipif(not is_curl_installed(), reason="cURL is not installed")
def test_tensorize_vllm_model(tmp_path):
# Test serialize command
serialize_args = [
"python3", "tensorizer/tensorize_vllm_model_for_testing.py", "--model",
model_ref, "--dtype", "float16", "serialize", "--serialized-directory",
tmp_path, "--suffix", "tests"
]
result = subprocess.run(serialize_args, capture_output=True, text=True)
print(result.stdout) # Print the output of the serialize command
assert result.returncode == 0, (f"Serialize command failed with output:"
f"\n{result.stdout}\n{result.stderr}")
path_to_tensors = f"{tmp_path}/vllm/{model_ref}/tests/model.tensors"
# Test deserialize command
deserialize_args = [
"python3", "tensorizer/tensorize_vllm_model_for_testing.py", "--model",
model_ref, "--dtype", "float16", "deserialize", "--path-to-tensors",
path_to_tensors
]
result = subprocess.run(deserialize_args, capture_output=True, text=True)
assert result.returncode == 0, (f"Deserialize command failed with output:"
f"\n{result.stdout}\n{result.stderr}")
@pytest.mark.skipif(not is_curl_installed(), reason="cURL is not installed")
def test_openai_apiserver_with_tensorizer(tmp_path):
## Serialize model
serialize_args = [
"python3", "tensorizer/tensorize_vllm_model_for_testing.py", "--model",
model_ref, "--dtype", "float16", "serialize", "--serialized-directory",
tmp_path, "--suffix", "tests"
]
result = subprocess.run(serialize_args, capture_output=True, text=True)
print(result.stdout) # Print the output of the serialize command
assert result.returncode == 0, (f"Serialize command failed with output:"
f"\n{result.stdout}\n{result.stderr}")
path_to_tensors = f"{tmp_path}/vllm/{model_ref}/tests/model.tensors"
## Start OpenAI API server
openai_args = [
"--model", model_ref, "--dtype", "float16", "--load-format",
"tensorizer", "--tensorizer-uri", path_to_tensors, "--vllm-tensorized",
"--port", "8000"
]
server = ServerRunner.remote(openai_args)
print("Server ready.")
assert server.ready.remote()
def test_raise_value_error_on_invalid_load_format(vllm_runner):
with pytest.raises(ValueError):
vllm_runner(model_ref,
load_format="safetensors",
tensorizer_uri="test")
def test_tensorizer_with_tp(vllm_runner):
with pytest.raises(ValueError):
model_ref = "EleutherAI/pythia-1.4b"
tensorized_path = f"s3://tensorized/{model_ref}/fp16/model.tensors"
vllm_runner(
model_ref,
tensorizer_uri=tensorized_path,
load_format="tensorizer",
num_readers=1,
vllm_tensorized=False,
s3_endpoint="object.ord1.coreweave.com",
tensor_parallel_size=2,
)
@pytest.mark.skipif(not is_curl_installed(), reason="cURL is not installed")
def test_tensorizer_warn_quant(tmp_path):
model_ref = "LnL-AI/TinyLlama-1.1B-Chat-v1.0-GPTQ-4bit"
serialize_args = [
"python3", "tensorizer/tensorize_vllm_model_for_testing.py", "--model",
model_ref, "--quantization", "gptq", "--tensorizer-uri", "test",
"serialize", "--serialized-directory", tmp_path, "--suffix", "tests"
]
result = subprocess.run(serialize_args, capture_output=True, text=True)
assert 'PerformanceWarning' in result.stderr

View File

@ -1,6 +1,8 @@
import enum
import io
import json
import os
import typing
from dataclasses import dataclass, fields
from typing import TYPE_CHECKING, ClassVar, List, Optional, Union
@ -16,6 +18,8 @@ from vllm.utils import (get_cpu_memory, get_nvcc_cuda_version, is_cpu, is_hip,
if TYPE_CHECKING:
from ray.util.placement_group import PlacementGroup
from vllm.model_executor.tensorizer_loader import TensorizerArgs
logger = init_logger(__name__)
_GB = 1 << 30
@ -139,13 +143,14 @@ class ModelConfig:
def _verify_load_format(self) -> None:
load_format = self.load_format.lower()
supported_load_format = [
"auto", "pt", "safetensors", "npcache", "dummy"
"auto", "pt", "safetensors", "npcache", "dummy", "tensorizer"
]
rocm_not_supported_load_format: List[str] = []
if load_format not in supported_load_format:
raise ValueError(
f"Unknown load format: {self.load_format}. Must be one of "
"'auto', 'pt', 'safetensors', 'npcache', or 'dummy'.")
"'auto', 'pt', 'safetensors', 'npcache', 'tensorizer', or "
"'dummy'.")
if is_hip() and load_format in rocm_not_supported_load_format:
rocm_supported_load_format = [
f for f in supported_load_format
@ -882,6 +887,65 @@ class VisionLanguageConfig:
f"{[x.name for x in cls.ImageInputType]}.") from e
@dataclass
class TensorizerConfig:
tensorizer_uri: Union[io.BufferedIOBase, io.RawIOBase, typing.BinaryIO,
str, bytes, os.PathLike, int]
vllm_tensorized: bool
verify_hash: Optional[bool] = False
num_readers: Optional[int] = 1
encryption_keyfile: Optional[str] = None
s3_access_key_id: Optional[str] = None
s3_secret_access_key: Optional[str] = None
s3_endpoint: Optional[str] = None
model_class: Optional[torch.nn.Module] = None
hf_config: Optional[PretrainedConfig] = None
dtype: Union[str, torch.dtype] = None
def _construct_tensorizer_args(self) -> "TensorizerArgs":
from vllm.model_executor.tensorizer_loader import TensorizerArgs
tensorizer_args = {
"tensorizer_uri": self.tensorizer_uri,
"vllm_tensorized": self.vllm_tensorized,
"verify_hash": self.verify_hash,
"num_readers": self.num_readers,
"encryption_keyfile": self.encryption_keyfile,
"s3_access_key_id": self.s3_access_key_id,
"s3_secret_access_key": self.s3_secret_access_key,
"s3_endpoint": self.s3_endpoint,
}
return TensorizerArgs(**tensorizer_args)
def verify_with_parallel_config(
self,
parallel_config: "ParallelConfig",
) -> None:
if (parallel_config.tensor_parallel_size > 1
and self.tensorizer_uri is not None):
raise ValueError(
"Loading to multiple GPUs is not currently supported with "
"vLLM-serialized models. Please set tensor_parallel_size=1."
" or use a non-vLLM-serialized model, such as a "
"serialized Hugging Face `PretrainedModel`.")
def verify_with_model_config(self, model_config) -> None:
if (model_config.quantization is not None
and self.tensorizer_uri is not None):
from vllm.model_executor.tensorizer_loader import (
tensorizer_warning)
tensorizer_warning(
"Loading a model using Tensorizer with quantization on vLLM"
" is unstable and may lead to errors.")
if (model_config.load_format != "tensorizer"
and self.tensorizer_uri is not None):
raise ValueError(
"A tensorizer uri was passed for tensorizer loading, but the "
f"load format was set to {model_config.load_format}. "
"Please set the load format to 'tensorizer' to use "
f"tensorizer args.")
_STR_DTYPE_TO_TORCH_DTYPE = {
"half": torch.float16,
"float16": torch.float16,
@ -1029,6 +1093,7 @@ class EngineConfig:
lora_config: Optional[LoRAConfig]
vision_language_config: Optional[VisionLanguageConfig]
speculative_config: Optional[SpeculativeConfig]
tensorizer_config: Optional[TensorizerConfig]
def __post_init__(self):
"""Verify configs are valid & consistent with each other.
@ -1036,6 +1101,11 @@ class EngineConfig:
self.model_config.verify_with_parallel_config(self.parallel_config)
self.cache_config.verify_with_parallel_config(self.parallel_config)
if self.tensorizer_config:
self.tensorizer_config.verify_with_parallel_config(
self.parallel_config)
self.tensorizer_config.verify_with_model_config(self.model_config)
if self.lora_config:
self.lora_config.verify_with_model_config(self.model_config)
self.lora_config.verify_with_scheduler_config(

View File

@ -1,12 +1,15 @@
import argparse
import dataclasses
import io
import os
from dataclasses import dataclass
from typing import Optional
from typing import BinaryIO, Optional, Union
from vllm.config import (CacheConfig, DeviceConfig, EngineConfig, LoRAConfig,
ModelConfig, ParallelConfig, SchedulerConfig,
SpeculativeConfig, TokenizerPoolConfig,
VisionLanguageConfig)
SpeculativeConfig, TensorizerConfig,
TokenizerPoolConfig, VisionLanguageConfig)
from vllm.model_executor.tensorizer_loader import TensorizerArgs
from vllm.utils import str_to_int_tuple
@ -58,12 +61,22 @@ class EngineArgs:
num_gpu_blocks_override: Optional[int] = None
num_lookahead_slots: int = 0
# Tensorizer configuration parameters
tensorizer_uri: Union[io.BufferedIOBase, io.RawIOBase, BinaryIO, str,
bytes, os.PathLike, int] = None
vllm_tensorized: bool = False
verify_hash: Optional[bool] = False
num_readers: Optional[int] = 1
encryption_keyfile: Optional[str] = None
s3_access_key_id: Optional[str] = None
s3_secret_access_key: Optional[str] = None
s3_endpoint: Optional[str] = None
# Related to Vision-language models such as llava
image_input_type: Optional[str] = None
image_token_id: Optional[int] = None
image_input_shape: Optional[str] = None
image_feature_size: Optional[int] = None
scheduler_delay_factor: float = 0.0
enable_chunked_prefill: bool = False
@ -135,7 +148,9 @@ class EngineArgs:
'--load-format',
type=str,
default=EngineArgs.load_format,
choices=['auto', 'pt', 'safetensors', 'npcache', 'dummy'],
choices=[
'auto', 'pt', 'safetensors', 'npcache', 'dummy', 'tensorizer'
],
help='The format of the model weights to load. '
'"auto" will try to load the weights in the safetensors format '
'and fall back to the pytorch bin format if safetensors format '
@ -145,7 +160,10 @@ class EngineArgs:
'"npcache" will load the weights in pytorch format and store '
'a numpy cache to speed up the loading. '
'"dummy" will initialize the weights with random values, '
'which is mainly for profiling.')
'which is mainly for profiling. '
'"tensorizer" will load the weights using tensorizer from '
'CoreWeave, which assumes tensorizer_uri is set to the '
'location of the serialized weights.')
parser.add_argument(
'--dtype',
type=str,
@ -403,6 +421,7 @@ class EngineArgs:
default=None,
help='The number of speculative tokens to sample from '
'the draft model in speculative decoding')
parser = TensorizerArgs.add_cli_args(parser)
return parser
@classmethod
@ -465,6 +484,17 @@ class EngineArgs:
max_cpu_loras=self.max_cpu_loras if self.max_cpu_loras
and self.max_cpu_loras > 0 else None) if self.enable_lora else None
tensorizer_config = TensorizerConfig(
tensorizer_uri=self.tensorizer_uri,
vllm_tensorized=self.vllm_tensorized,
verify_hash=self.verify_hash,
num_readers=self.num_readers,
encryption_keyfile=self.encryption_keyfile,
s3_access_key_id=self.s3_access_key_id,
s3_secret_access_key=self.s3_secret_access_key,
s3_endpoint=self.s3_endpoint,
)
if self.image_input_type:
if (not self.image_token_id or not self.image_input_shape
or not self.image_feature_size):
@ -488,7 +518,8 @@ class EngineArgs:
device_config=device_config,
lora_config=lora_config,
vision_language_config=vision_language_config,
speculative_config=speculative_config)
speculative_config=speculative_config,
tensorizer_config=tensorizer_config)
@dataclass

View File

@ -6,7 +6,7 @@ from transformers import PreTrainedTokenizer
import vllm
from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig,
ParallelConfig, SchedulerConfig, SpeculativeConfig,
VisionLanguageConfig)
TensorizerConfig, VisionLanguageConfig)
from vllm.core.scheduler import Scheduler, SchedulerOutputs
from vllm.engine.arg_utils import EngineArgs
from vllm.engine.metrics import StatLogger, Stats
@ -74,6 +74,7 @@ class LLMEngine:
lora_config: Optional[LoRAConfig],
vision_language_config: Optional[VisionLanguageConfig],
speculative_config: Optional[SpeculativeConfig],
tensorizer_config: Optional[TensorizerConfig],
executor_class: Type[ExecutorBase],
log_stats: bool,
usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
@ -110,6 +111,7 @@ class LLMEngine:
self.scheduler_config = scheduler_config
self.device_config = device_config
self.speculative_config = speculative_config
self.tensorizer_config = tensorizer_config
self.log_stats = log_stats
self._init_tokenizer()
@ -125,6 +127,7 @@ class LLMEngine:
lora_config=lora_config,
vision_language_config=vision_language_config,
speculative_config=speculative_config,
tensorizer_config=tensorizer_config,
)
self._initialize_kv_caches()
@ -264,6 +267,9 @@ class LLMEngine:
def _verify_args(self) -> None:
self.model_config.verify_with_parallel_config(self.parallel_config)
self.cache_config.verify_with_parallel_config(self.parallel_config)
if self.tensorizer_config:
self.tensorizer_config.verify_with_parallel_config(
self.parallel_config)
if self.lora_config:
self.lora_config.verify_with_model_config(self.model_config)
self.lora_config.verify_with_scheduler_config(

View File

@ -2,7 +2,7 @@ from typing import Dict, List, Optional, Tuple
from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig,
ParallelConfig, SchedulerConfig, SpeculativeConfig,
VisionLanguageConfig)
TensorizerConfig, VisionLanguageConfig)
from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
@ -15,17 +15,14 @@ logger = init_logger(__name__)
class GPUExecutor(ExecutorBase):
def __init__(
self,
model_config: ModelConfig,
cache_config: CacheConfig,
def __init__(self, model_config: ModelConfig, cache_config: CacheConfig,
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig,
device_config: DeviceConfig,
lora_config: Optional[LoRAConfig],
vision_language_config: Optional[VisionLanguageConfig],
speculative_config: Optional[SpeculativeConfig],
) -> None:
tensorizer_config: Optional[TensorizerConfig]) -> None:
self.model_config = model_config
self.cache_config = cache_config
self.lora_config = lora_config
@ -33,6 +30,7 @@ class GPUExecutor(ExecutorBase):
self.scheduler_config = scheduler_config
self.device_config = device_config
self.vision_language_config = vision_language_config
self.tensorizer_config = tensorizer_config
assert (not speculative_config
), "Speculative decoding not yet supported for GPU backend"
@ -61,6 +59,7 @@ class GPUExecutor(ExecutorBase):
distributed_init_method=distributed_init_method,
lora_config=self.lora_config,
vision_language_config=self.vision_language_config,
tensorizer_config=self.tensorizer_config,
is_driver_worker=True,
)
self.driver_worker.init_device()

View File

@ -7,7 +7,7 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig,
ParallelConfig, SchedulerConfig, SpeculativeConfig,
VisionLanguageConfig)
TensorizerConfig, VisionLanguageConfig)
from vllm.engine.ray_utils import RayWorkerVllm, ray
from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase
from vllm.logger import init_logger
@ -42,6 +42,7 @@ class RayGPUExecutor(ExecutorBase):
lora_config: Optional[LoRAConfig],
vision_language_config: Optional[VisionLanguageConfig],
speculative_config: Optional[SpeculativeConfig],
tensorizer_config: Optional[TensorizerConfig],
) -> None:
self.model_config = model_config
self.cache_config = cache_config
@ -50,6 +51,7 @@ class RayGPUExecutor(ExecutorBase):
self.scheduler_config = scheduler_config
self.device_config = device_config
self.vision_language_config = vision_language_config
self.tensorizer_config = tensorizer_config
assert (not speculative_config
), "Speculative decoding not yet supported for RayGPU backend."
@ -171,6 +173,7 @@ class RayGPUExecutor(ExecutorBase):
distributed_init_method=distributed_init_method,
lora_config=lora_config,
vision_language_config=vision_language_config,
tensorizer_config=self.tensorizer_config,
))
# Initialize the driver worker with the Worker class.
@ -187,6 +190,7 @@ class RayGPUExecutor(ExecutorBase):
distributed_init_method=distributed_init_method,
lora_config=self.lora_config,
vision_language_config=self.vision_language_config,
tensorizer_config=self.tensorizer_config,
is_driver_worker=True,
)

View File

@ -3,11 +3,14 @@ import contextlib
from typing import Tuple, Type
import torch
import torch.nn as nn
from torch import nn
from vllm.config import DeviceConfig, ModelConfig
from vllm.model_executor.models import ModelRegistry
from vllm.model_executor.models.llava import LlavaForConditionalGeneration
from vllm.model_executor.tensorizer_loader import (
ParameterizedLoadFormat, is_vllm_serialized_tensorizer,
load_with_tensorizer)
from vllm.model_executor.weight_utils import (get_quant_config,
initialize_dummy_weights)
@ -51,6 +54,7 @@ def get_model(model_config: ModelConfig, device_config: DeviceConfig,
**kwargs) -> nn.Module:
lora_config = kwargs.get("lora_config", None)
vision_language_config = kwargs.get("vision_language_config", None)
tensorizer_config = kwargs.get("tensorizer_config", None)
model_class = _get_model_architecture(model_config)[0]
# Get the (maybe quantized) linear method.
@ -71,33 +75,54 @@ def get_model(model_config: ModelConfig, device_config: DeviceConfig,
f"{model_config.dtype} is not supported for quantization "
f"method {model_config.quantization}. Supported dtypes: "
f"{supported_dtypes}")
linear_method = quant_config.get_linear_method()
with _set_default_torch_dtype(model_config.dtype):
# Create a model instance.
# The weights will be initialized as empty tensors.
with torch.device(device_config.device):
extra_kwargs = {}
if hasattr(model_class, "supported_lora_modules"):
model = model_class(model_config.hf_config, linear_method,
lora_config)
extra_kwargs["lora_config"] = lora_config
elif lora_config:
raise ValueError(
f"Model {model_class.__name__} does not support LoRA, "
"but LoRA is enabled. Support for this model may "
"be added in the future. If this is important to you, "
"please open an issue on github.")
else:
if model_class not in _VISION_MODEL_CLASSES:
model = model_class(model_config.hf_config, linear_method)
else:
model = model_class(model_config.hf_config,
vision_language_config, linear_method)
elif model_class in _VISION_MODEL_CLASSES:
extra_kwargs["vision_language_config"] = vision_language_config
with torch.device(device_config.device):
if (model_config.load_format == "tensorizer"
and is_vllm_serialized_tensorizer(tensorizer_config)):
extra_kwargs["linear_method"] = linear_method
tensorizer_config.model_class = model_class
tensorizer_config.hf_config = model_config.hf_config
tensorizer_config.dtype = model_config.dtype
model = load_with_tensorizer(tensorizer_config, **extra_kwargs)
return model.eval()
model = model_class(config=model_config.hf_config,
linear_method=linear_method,
**extra_kwargs)
if model_config.load_format == "dummy":
# NOTE(woosuk): For accurate performance evaluation, we assign
# random values to the weights.
initialize_dummy_weights(model)
else:
# Load the weights from the cached or downloaded files.
model.load_weights(model_config.model, model_config.download_dir,
model_config.load_format, model_config.revision)
if model_config.load_format == "tensorizer":
# Provide a dynamic load format for `model.load_weights`
# to retain tensorizer args from CLI.
model_config.load_format = ParameterizedLoadFormat(
model_config.load_format)
model_config.load_format.params = (
tensorizer_config._construct_tensorizer_args())
model.load_weights(
model_config.model,
model_config.download_dir,
model_config.load_format,
model_config.revision,
)
return model.eval()

View File

@ -0,0 +1,319 @@
import argparse
import dataclasses
import io
import os
import time
import typing
import warnings
from dataclasses import dataclass
from typing import Optional, Union
import torch
from torch import nn
from vllm.config import TensorizerConfig
from vllm.logger import init_logger
from vllm.model_executor.layers.linear import LinearMethodBase
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding)
tensorizer_load_fail = False
try:
from tensorizer import (DecryptionParams, EncryptionParams,
TensorDeserializer, TensorSerializer)
from tensorizer.stream_io import open_stream
from tensorizer.utils import (convert_bytes, get_mem_usage,
no_init_or_tensor)
except ImportError:
tensorizer_load_fail = True
__all__ = [
'EncryptionParams', 'DecryptionParams', 'TensorDeserializer',
'TensorSerializer', 'open_stream', 'convert_bytes', 'get_mem_usage',
'no_init_or_tensor'
]
logger = init_logger(__name__)
def load_with_tensorizer(tensorizer_config: TensorizerConfig,
**extra_kwargs) -> nn.Module:
tensorizer = TensorizerAgent(tensorizer_config, **extra_kwargs)
return tensorizer.deserialize()
def tensorizer_warning(message: str):
return warnings.warn(message, category=PerformanceWarning, stacklevel=2)
def is_vllm_serialized_tensorizer(tensorizer_config: TensorizerConfig) -> bool:
if tensorizer_config is None:
return False
return tensorizer_config.vllm_tensorized
class ParameterizedLoadFormat(str):
__slots__ = "params"
class PerformanceWarning(UserWarning):
def __str__(self):
return (f"{super().__str__()}"
" (set the VLLM_SILENCE_PERFORMANCE_WARNINGS"
" environment variable to hide this)")
if (os.getenv("VLLM_SILENCE_PERFORMANCE_WARNINGS", "").lower()
not in ("", "0", "n", "no", "off", "disable")):
warnings.simplefilter("ignore", category=PerformanceWarning)
@dataclass
class TensorizerArgs:
tensorizer_uri: Union[io.BufferedIOBase, io.RawIOBase, typing.BinaryIO,
str, bytes, os.PathLike, int]
vllm_tensorized: bool
verify_hash: Optional[bool] = False
num_readers: Optional[int] = 1
encryption_keyfile: Optional[str] = None
s3_access_key_id: Optional[str] = None
s3_secret_access_key: Optional[str] = None
s3_endpoint: Optional[str] = None
"""
Args for the TensorizerAgent class. These are used to configure the behavior
of the TensorDeserializer when loading tensors from a serialized model.
Args:
tensorizer_uri: Path to serialized model tensors. Can be a local file
path or a S3 URI.
vllm_tensorized: If True, indicates that the serialized model is a
vLLM model. This is used to determine the behavior of the
TensorDeserializer when loading tensors from a serialized model.
It is far faster to deserialize a vLLM model as it utilizes
tensorizer's optimized GPU loading.
verify_hash: If True, the hashes of each tensor will be verified against
the hashes stored in the metadata. A `HashMismatchError` will be
raised if any of the hashes do not match.
num_readers: Controls how many threads are allowed to read concurrently
from the source file. Default is 1. This greatly increases
performance.
encryption_keyfile: File path to a binary file containing a
binary key to use for decryption. `None` (the default) means
no decryption. See the example script in
examples/tensorize_vllm_model.py.
s3_access_key_id: The access key for the S3 bucket. Can also be set via
the S3_ACCESS_KEY_ID environment variable.
s3_secret_access_key: The secret access key for the S3 bucket. Can also
be set via the S3_SECRET_ACCESS_KEY environment variable.
s3_endpoint: The endpoint for the S3 bucket. Can also be set via the
S3_ENDPOINT_URL environment variable.
"""
def __post_init__(self):
self.file_obj = self.tensorizer_uri
self.s3_access_key_id = (self.s3_access_key_id
or os.environ.get("S3_ACCESS_KEY_ID")) or None
self.s3_secret_access_key = (
self.s3_secret_access_key
or os.environ.get("S3_SECRET_ACCESS_KEY")) or None
self.s3_endpoint = (self.s3_endpoint
or os.environ.get("S3_ENDPOINT_URL")) or None
self.stream_params = {
"s3_access_key_id": self.s3_access_key_id,
"s3_secret_access_key": self.s3_secret_access_key,
"s3_endpoint": self.s3_endpoint,
}
# Omitting self.dtype and self.device as this behaves weirdly
self.deserializer_params = {
"verify_hash": self.verify_hash,
"encryption": self.encryption_keyfile,
"num_readers": self.num_readers
}
if self.encryption_keyfile:
with open_stream(
self.encryption_keyfile,
**self.stream_params,
) as stream:
key = stream.read()
decryption_params = DecryptionParams.from_key(key)
self.deserializer_params['encryption'] = decryption_params
def add_cli_args(
parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
"""Tensorizer CLI arguments"""
# Create the argument group
group = parser.add_argument_group(
'tensorizer options',
description=('Options for configuring the behavior of the'
' tensorizer deserializer when '
'--load-format=tensorizer'))
group.add_argument(
"--tensorizer-uri",
help="Path to serialized model tensors. Can be a local file path,"
" or an HTTP(S) or S3 URI.",
)
group.add_argument(
"--verify-hash",
action="store_true",
help="If enabled, the hashes of each tensor will be verified"
" against the hashes stored in the file metadata. An exception"
" will be raised if any of the hashes do not match.",
)
group.add_argument(
"--encryption-keyfile",
default=None,
help="The file path to a binary file containing a binary key to "
"use for decryption. Can be a file path or S3 network URI.")
group.add_argument(
"--num-readers",
default=1,
type=int,
help="Controls how many threads are allowed to read concurrently "
"from the source file.")
group.add_argument(
"--s3-access-key-id",
default=None,
help="The access key for the S3 bucket. Can also be set via the "
"S3_ACCESS_KEY_ID environment variable.",
)
group.add_argument(
"--s3-secret-access-key",
default=None,
help="The secret access key for the S3 bucket. Can also be set via "
"the S3_SECRET_ACCESS_KEY environment variable.",
)
group.add_argument(
"--s3-endpoint",
default=None,
help="The endpoint for the S3 bucket. Can also be set via the "
"S3_ENDPOINT_URL environment variable.",
)
group.add_argument(
"--vllm-tensorized",
action="store_true",
help="If enabled, indicates that the serialized model is a vLLM "
"model. This is used to determine the behavior of the "
"TensorDeserializer when loading tensors from a "
"serialized model.")
return parser
@classmethod
def from_cli_args(cls, args: argparse.Namespace) -> "TensorizerArgs":
# Get the list of attributes of this dataclass.
attrs = [attr.name for attr in dataclasses.fields(cls)]
# Set the attributes from the parsed arguments.
tensorizer_args = cls(**{
attr: getattr(args, attr)
for attr in attrs if hasattr(args, attr)
})
return tensorizer_args
class TensorizerAgent:
"""
A class for performing tensorizer deserializations specifically for
vLLM models using plaid_mode. Uses TensorizerArgs to configure the
behavior of the TensorDeserializer when loading tensors from a serialized
model. For deserializations of HuggingFace models, TensorDeserializer is
instead used as an iterator directly in the func hf_model_weights_iterator
in vllm/model_executor/weight_utils.py
"""
def __init__(self, tensorizer_config: TensorizerConfig,
linear_method: LinearMethodBase, **extra_kwargs):
self.tensorizer_config = tensorizer_config
self.tensorizer_args = (
self.tensorizer_config._construct_tensorizer_args())
self.extra_kwargs = extra_kwargs
if extra_kwargs.get("linear_method", None) is not None:
self.linear_method = extra_kwargs["linear_method"]
else:
self.linear_method = linear_method
if tensorizer_load_fail:
raise ImportError(
"Tensorizer is not installed. Please install tensorizer "
"to use this feature with `pip install vllm[tensorizer]`.")
self.model = self._init_model()
def _init_model(self):
model_args = self.tensorizer_config.hf_config
model_args.torch_dtype = self.tensorizer_config.dtype
with no_init_or_tensor():
return self.tensorizer_config.model_class(
config=model_args,
linear_method=self.linear_method,
**self.extra_kwargs)
def _resize_lora_embeddings(self):
"""Modify LoRA embedding layers to use bigger tensors
to allow for adapter added tokens."""
for child in self.model.modules():
if (isinstance(child, VocabParallelEmbedding)
and child.weight.shape[0] <
child.num_embeddings_per_partition):
new_weight = torch.empty(child.num_embeddings_per_partition,
child.embedding_dim,
dtype=child.weight.dtype,
device=child.weight.device)
new_weight[:child.weight.shape[0]].copy_(child.weight.data)
new_weight[child.weight.shape[0]:].fill_(0)
child.weight.data = new_weight
def _check_tensors_on_meta_device(self):
for tensor in self.model.state_dict().values():
if tensor.device.type == 'meta':
raise ValueError(
"The serialized model contains tensors on the meta device,"
" indicating that some tensors were not loaded properly."
" Please check that the parameters of the model being"
" specified match that of the serialized model, such as"
" its quantization.")
def deserialize(self):
"""
Deserialize the model using the TensorDeserializer. This method is
specifically for vLLM models using tensorizer's plaid_mode.
The deserializer makes use of tensorizer_args.stream_params
to configure the behavior of the stream when loading tensors from a
serialized model. The deserializer_params are used to configure the
behavior of the TensorDeserializer when loading tensors themselves.
Documentation on these params can be found in TensorizerArgs
Returns:
nn.Module: The deserialized model.
"""
before_mem = get_mem_usage()
# Lazy load the tensors from S3 into the model.
start = time.perf_counter()
with open_stream(
self.tensorizer_args.tensorizer_uri,
mode="rb",
**self.tensorizer_args.stream_params,
) as stream, TensorDeserializer(
stream,
dtype=self.tensorizer_config.dtype,
**self.tensorizer_args.deserializer_params) as deserializer:
deserializer.load_into_module(self.model)
end = time.perf_counter()
total_bytes_str = convert_bytes(deserializer.total_tensor_bytes)
duration = end - start
per_second = convert_bytes(deserializer.total_tensor_bytes / duration)
after_mem = get_mem_usage()
deserializer.close()
logger.info(f"Deserialized {total_bytes_str} in "
f"{end - start:0.2f}s, {per_second}/s")
logger.info(f"Memory usage before: {before_mem}")
logger.info(f"Memory usage after: {after_mem}")
self._check_tensors_on_meta_device()
self._resize_lora_embeddings()
return self.model.eval()
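To make the relationship between `TensorizerConfig`, `TensorizerArgs`, and the tensorizer streams concrete, here is a minimal sketch, assuming tensorizer is installed and the (hypothetical) URI points at an existing serialized model:

from vllm.config import TensorizerConfig
from vllm.model_executor.tensorizer_loader import TensorDeserializer, open_stream

config = TensorizerConfig(
    tensorizer_uri="s3://my-bucket/vllm/EleutherAI/gpt-j-6B/v1/model.tensors",
    vllm_tensorized=True,
    num_readers=1,
)
args = config._construct_tensorizer_args()

# stream_params carries the S3 credentials and endpoint; deserializer_params
# carries verify_hash, num_readers, and (when encryption_keyfile is set) the
# DecryptionParams built in TensorizerArgs.__post_init__.
with open_stream(args.tensorizer_uri, mode="rb", **args.stream_params) as stream, \
        TensorDeserializer(stream, **args.deserializer_params) as deserializer:
    # deserializer.load_into_module(model) would populate a meta-initialized
    # module, exactly as TensorizerAgent.deserialize() does above.
    pass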

View File

@ -5,7 +5,7 @@ import hashlib
import json
import os
from collections import defaultdict
from typing import Any, Iterable, Iterator, List, Optional, Tuple
from typing import Any, Iterable, Iterator, List, Optional, Tuple, Union
import filelock
import huggingface_hub.constants
@ -161,7 +161,8 @@ def prepare_hf_model_weights(
revision: Optional[str] = None,
) -> Tuple[str, List[str], bool]:
# Download model weights from huggingface.
is_local = os.path.isdir(model_name_or_path)
is_local = os.path.isdir(model_name_or_path) \
and load_format != "tensorizer"
use_safetensors = False
# Some quantized models use .pt files for storing the weights.
if load_format == "auto":
@ -173,13 +174,15 @@ def prepare_hf_model_weights(
allow_patterns = ["*.pt"]
elif load_format == "npcache":
allow_patterns = ["*.bin"]
elif load_format == "tensorizer":
allow_patterns = ["*.tensors"]
else:
raise ValueError(f"Unknown load_format: {load_format}")
if fall_back_to_pt:
allow_patterns += ["*.pt"]
if not is_local:
if not is_local and load_format != "tensorizer":
# Before we download we look at that is available:
fs = HfFileSystem()
file_list = fs.ls(model_name_or_path, detail=False, revision=revision)
@ -224,6 +227,9 @@ def prepare_hf_model_weights(
if not any(f.endswith(x) for x in blacklist)
]
if load_format == "tensorizer":
return hf_folder, hf_weights_files, use_safetensors
if len(hf_weights_files) == 0:
raise RuntimeError(
f"Cannot find any model weights with `{model_name_or_path}`")
@ -234,7 +240,7 @@ def prepare_hf_model_weights(
def hf_model_weights_iterator(
model_name_or_path: str,
cache_dir: Optional[str] = None,
load_format: str = "auto",
load_format: Union[Tuple, str] = "auto",
revision: Optional[str] = None,
fall_back_to_pt: Optional[bool] = True,
) -> Iterator[Tuple[str, torch.Tensor]]:
@ -277,6 +283,26 @@ def hf_model_weights_iterator(
with open(param_path, "rb") as f:
param = np.load(f)
yield name, torch.from_numpy(param)
elif load_format == "tensorizer":
from vllm.model_executor.tensorizer_loader import (TensorDeserializer,
open_stream,
tensorizer_warning)
tensorizer_args = load_format.params
tensorizer_warning(
"Deserializing HuggingFace models is not optimized for "
"loading on vLLM, as tensorizer is forced to load to CPU. "
"Consider deserializing a vLLM model instead for faster "
"load times. See the examples/tensorize_vllm_model.py example "
"script for serializing vLLM models.")
deserializer_args = tensorizer_args.deserializer_params
stream_params = tensorizer_args.stream_params
stream = open_stream(tensorizer_args.tensorizer_uri, **stream_params)
with TensorDeserializer(stream, **deserializer_args,
device="cpu") as state:
for name, param in state.items():
yield name, param
del state
elif use_safetensors:
for st_file in hf_weights_files:
with safe_open(st_file, framework="pt") as f:

View File

@ -10,7 +10,8 @@ import torch.nn as nn
from vllm.attention import (AttentionMetadata, AttentionMetadataPerStage,
get_attn_backend)
from vllm.config import (DeviceConfig, LoRAConfig, ModelConfig, ParallelConfig,
SchedulerConfig, VisionLanguageConfig)
SchedulerConfig, TensorizerConfig,
VisionLanguageConfig)
from vllm.distributed import broadcast_tensor_dict, with_pynccl_for_all_reduce
from vllm.distributed.device_communicators import (custom_all_reduce,
pynccl_utils)
@ -111,11 +112,13 @@ class ModelRunner:
kv_cache_dtype: Optional[str] = "auto",
is_driver_worker: bool = False,
vision_language_config: Optional[VisionLanguageConfig] = None,
tensorizer_config: Optional[TensorizerConfig] = None,
):
self.model_config = model_config
self.parallel_config = parallel_config
self.scheduler_config = scheduler_config
self.lora_config = lora_config
self.tensorizer_config = tensorizer_config
self.is_driver_worker = is_driver_worker
# model_config can be None in tests/samplers/test_sampler.py.
@ -158,7 +161,9 @@ class ModelRunner:
lora_config=self.lora_config,
vision_language_config=self.vision_language_config,
parallel_config=self.parallel_config,
scheduler_config=self.scheduler_config)
scheduler_config=self.scheduler_config,
tensorizer_config=self.tensorizer_config,
)
self.model_memory_usage = m.consumed_memory
logger.info(f"Loading model weights took "

View File

@ -7,7 +7,8 @@ import torch
import torch.distributed
from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig,
ParallelConfig, SchedulerConfig, VisionLanguageConfig)
ParallelConfig, SchedulerConfig, TensorizerConfig,
VisionLanguageConfig)
from vllm.distributed import (broadcast_tensor_dict,
ensure_model_parallel_initialized,
init_distributed_environment)
@ -42,6 +43,7 @@ class Worker(WorkerBase):
distributed_init_method: str,
lora_config: Optional[LoRAConfig] = None,
vision_language_config: Optional[VisionLanguageConfig] = None,
tensorizer_config: Optional[TensorizerConfig] = None,
is_driver_worker: bool = False,
) -> None:
self.model_config = model_config
@ -53,6 +55,7 @@ class Worker(WorkerBase):
self.rank = rank
self.distributed_init_method = distributed_init_method
self.lora_config = lora_config
self.tensorizer_config = tensorizer_config
self.is_driver_worker = is_driver_worker
if self.is_driver_worker:
assert self.rank == 0, "The driver worker must have rank 0."
@ -70,7 +73,9 @@ class Worker(WorkerBase):
lora_config=self.lora_config,
kv_cache_dtype=self.cache_config.cache_dtype,
is_driver_worker=is_driver_worker,
vision_language_config=vision_language_config)
vision_language_config=vision_language_config,
tensorizer_config=tensorizer_config,
)
# Uninitialized cache engine. Will be initialized by
# initialize_cache.
self.cache_engine = None