[Core] Refactor model loading code (#4097)
This commit is contained in:
parent
05434764cd
commit
69e1d2fb69
@ -92,7 +92,7 @@ steps:
|
||||
parallelism: 4
|
||||
|
||||
- label: Tensorizer Test
|
||||
command: apt-get install curl libsodium23 && pytest -v -s tensorizer
|
||||
command: apt-get install curl libsodium23 && pytest -v -s tensorizer_loader
|
||||
|
||||
- label: Metrics Test
|
||||
command: pytest -v -s metrics
|
||||
|
@ -11,7 +11,7 @@ from safetensors.torch import safe_open
|
||||
from vllm.model_executor.layers.quantization.schema import QuantParamSchema
|
||||
|
||||
|
||||
# Adapted from vllm/model_executor/weight_utils.py
|
||||
# Adapted from vllm/model_executor/model_loader/weight_utils.py
|
||||
# The main differences are that we add the NPZ format and simplify
|
||||
# its functionality drastically for our purposes (e.g. we assume that
|
||||
# the quantized model exists locally and there is no need to download it)
|
||||
@ -71,7 +71,7 @@ def _prepare_hf_weights(
|
||||
return hf_weights_files, use_safetensors
|
||||
|
||||
|
||||
# Adapted from vllm/model_executor/weight_utils.py
|
||||
# Adapted from vllm/model_executor/model_loader/weight_utils.py
|
||||
def _hf_tensorfile_iterator(filename: str, load_format: str,
|
||||
use_safetensors: bool):
|
||||
if load_format == "npz":
|
||||
|
@ -16,8 +16,8 @@ from transformers import AutoConfig, PretrainedConfig
|
||||
from vllm.distributed import initialize_model_parallel
|
||||
from vllm.engine.arg_utils import EngineArgs
|
||||
from vllm.engine.llm_engine import LLMEngine
|
||||
from vllm.model_executor.model_loader.tensorizer import TensorizerArgs
|
||||
from vllm.model_executor.models import ModelRegistry
|
||||
from vllm.model_executor.tensorizer_loader import TensorizerArgs
|
||||
|
||||
# yapf conflicts with isort for this docstring
|
||||
# yapf: disable
|
||||
|
@ -153,11 +153,11 @@ def llama_2_7b_engine_extra_embeddings() -> nn.Module:
|
||||
cleanup()
|
||||
get_model_old = get_model
|
||||
|
||||
def get_model_patched(model_config, device_config, **kwargs):
|
||||
return get_model_old(model_config,
|
||||
device_config,
|
||||
lora_config=LoRAConfig(max_loras=4,
|
||||
max_lora_rank=8))
|
||||
def get_model_patched(*, model_config, device_config, **kwargs):
|
||||
kwargs["lora_config"] = LoRAConfig(max_loras=4, max_lora_rank=8)
|
||||
return get_model_old(model_config=model_config,
|
||||
device_config=device_config,
|
||||
**kwargs)
|
||||
|
||||
with patch("vllm.worker.model_runner.get_model", get_model_patched):
|
||||
engine = vllm.LLM("meta-llama/Llama-2-7b-hf", enable_lora=False)
|
||||
|
@ -3,8 +3,8 @@ import random
|
||||
import tempfile
|
||||
from unittest.mock import patch
|
||||
|
||||
from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig,
|
||||
ParallelConfig, SchedulerConfig)
|
||||
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
|
||||
ModelConfig, ParallelConfig, SchedulerConfig)
|
||||
from vllm.lora.models import LoRAMapping
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.worker.worker import Worker
|
||||
@ -18,12 +18,14 @@ def test_worker_apply_lora(sql_lora_files):
|
||||
"meta-llama/Llama-2-7b-hf",
|
||||
tokenizer_mode="auto",
|
||||
trust_remote_code=False,
|
||||
download_dir=None,
|
||||
load_format="dummy",
|
||||
seed=0,
|
||||
dtype="float16",
|
||||
revision=None,
|
||||
),
|
||||
load_config=LoadConfig(
|
||||
download_dir=None,
|
||||
load_format="dummy",
|
||||
),
|
||||
parallel_config=ParallelConfig(1, 1, False),
|
||||
scheduler_config=SchedulerConfig(32, 32, 32),
|
||||
device_config=DeviceConfig("cuda"),
|
||||
|
@ -3,7 +3,7 @@ import os
|
||||
import huggingface_hub.constants
|
||||
import pytest
|
||||
|
||||
from vllm.model_executor.weight_utils import enable_hf_transfer
|
||||
from vllm.model_executor.model_loader.weight_utils import enable_hf_transfer
|
||||
|
||||
|
||||
def test_hf_transfer_auto_activation():
|
||||
|
@ -36,8 +36,6 @@ def test_auto_gptq(model_quant_type: str, ) -> None:
|
||||
model_path,
|
||||
tokenizer_mode="auto",
|
||||
trust_remote_code=False,
|
||||
download_dir=None,
|
||||
load_format="dummy",
|
||||
seed=0,
|
||||
dtype="float16",
|
||||
revision=None,
|
||||
@ -49,8 +47,6 @@ def test_auto_gptq(model_quant_type: str, ) -> None:
|
||||
model_path,
|
||||
tokenizer_mode="auto",
|
||||
trust_remote_code=False,
|
||||
download_dir=None,
|
||||
load_format="dummy",
|
||||
seed=0,
|
||||
dtype="float16",
|
||||
revision=None,
|
||||
|
@ -32,7 +32,12 @@ def _prepare_test(
|
||||
1e-2,
|
||||
dtype=input_tensor.dtype)
|
||||
sampler = MockLogitsSampler(fake_logits)
|
||||
model_runner = ModelRunner(None, None, None, None, None)
|
||||
model_runner = ModelRunner(model_config=None,
|
||||
parallel_config=None,
|
||||
scheduler_config=None,
|
||||
device_config=None,
|
||||
load_config=None,
|
||||
lora_config=None)
|
||||
return input_tensor, fake_logits, sampler, model_runner
|
||||
|
||||
|
||||
@ -591,7 +596,12 @@ def test_sampler_top_k_top_p(seed: int, device: str):
|
||||
device=input_tensor.device,
|
||||
dtype=input_tensor.dtype)
|
||||
sampler = MockLogitsSampler(fake_logits)
|
||||
model_runner = ModelRunner(None, None, None, None, None)
|
||||
model_runner = ModelRunner(model_config=None,
|
||||
parallel_config=None,
|
||||
scheduler_config=None,
|
||||
device_config=None,
|
||||
load_config=None,
|
||||
lora_config=None)
|
||||
|
||||
generation_model = GenerationMixin()
|
||||
generation_config = GenerationConfig(top_k=top_k,
|
||||
|
@ -118,6 +118,7 @@ def create_worker(cls: type,
|
||||
scheduler_config=engine_config.scheduler_config,
|
||||
device_config=engine_config.device_config,
|
||||
cache_config=engine_config.cache_config,
|
||||
load_config=engine_config.load_config,
|
||||
local_rank=0,
|
||||
rank=0,
|
||||
distributed_init_method=distributed_init_method,
|
||||
|
@ -16,8 +16,8 @@ from transformers import AutoConfig, PretrainedConfig
|
||||
from vllm.distributed import initialize_model_parallel
|
||||
from vllm.engine.arg_utils import EngineArgs
|
||||
from vllm.engine.llm_engine import LLMEngine
|
||||
from vllm.model_executor.model_loader.tensorizer import TensorizerArgs
|
||||
from vllm.model_executor.models import ModelRegistry
|
||||
from vllm.model_executor.tensorizer_loader import TensorizerArgs
|
||||
|
||||
# yapf conflicts with isort for this docstring
|
||||
# yapf: disable
|
||||
@ -74,7 +74,7 @@ def parse_args():
|
||||
"extremely quickly. Tensor encryption and decryption is "
|
||||
"also supported, although libsodium must be installed to "
|
||||
"use it.")
|
||||
parser = EngineArgs.add_cli_args(parser)
|
||||
parser = TensorizerArgs.add_cli_args(EngineArgs.add_cli_args(parser))
|
||||
subparsers = parser.add_subparsers(dest='command')
|
||||
|
||||
serialize_parser = subparsers.add_parser(
|
@ -1,16 +1,19 @@
|
||||
import gc
|
||||
import json
|
||||
import os
|
||||
import subprocess
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import openai
|
||||
import pytest
|
||||
import ray
|
||||
import torch
|
||||
|
||||
from tests.entrypoints.test_openai_server import ServerRunner
|
||||
from vllm import SamplingParams
|
||||
from vllm.config import TensorizerConfig
|
||||
from vllm.model_executor.tensorizer_loader import (
|
||||
EncryptionParams, TensorSerializer, is_vllm_serialized_tensorizer,
|
||||
load_with_tensorizer, open_stream)
|
||||
from vllm.model_executor.model_loader.tensorizer import (
|
||||
EncryptionParams, TensorizerConfig, TensorSerializer,
|
||||
is_vllm_serialized_tensorizer, load_with_tensorizer, open_stream)
|
||||
|
||||
prompts = [
|
||||
"Hello, my name is",
|
||||
@ -22,6 +25,8 @@ prompts = [
|
||||
sampling_params = SamplingParams(temperature=0.8, top_p=0.95, seed=0)
|
||||
|
||||
model_ref = "facebook/opt-125m"
|
||||
tensorize_model_for_testing_script = os.path.join(
|
||||
os.path.dirname(__file__), "tensorize_vllm_model_for_testing.py")
|
||||
|
||||
|
||||
def is_curl_installed():
|
||||
@ -38,7 +43,7 @@ def tensorizer_config():
|
||||
return config
|
||||
|
||||
|
||||
@patch('vllm.model_executor.tensorizer_loader.TensorizerAgent')
|
||||
@patch('vllm.model_executor.model_loader.tensorizer.TensorizerAgent')
|
||||
def test_load_with_tensorizer(mock_agent, tensorizer_config):
|
||||
mock_linear_method = MagicMock()
|
||||
mock_agent_instance = mock_agent.return_value
|
||||
@ -81,11 +86,13 @@ def test_deserialized_vllm_model_has_same_outputs(vllm_runner, tmp_path):
|
||||
del vllm_model, model
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
loaded_vllm_model = vllm_runner(model_ref,
|
||||
load_format="tensorizer",
|
||||
tensorizer_uri=model_path,
|
||||
num_readers=1,
|
||||
vllm_tensorized=True)
|
||||
loaded_vllm_model = vllm_runner(
|
||||
model_ref,
|
||||
load_format="tensorizer",
|
||||
model_loader_extra_config=TensorizerConfig(tensorizer_uri=model_path,
|
||||
num_readers=1,
|
||||
vllm_tensorized=True),
|
||||
)
|
||||
deserialized_outputs = loaded_vllm_model.generate(prompts, sampling_params)
|
||||
|
||||
# Assumes SamplingParams being seeded ensures the outputs are deterministic
|
||||
@ -97,14 +104,14 @@ def test_can_deserialize_s3(vllm_runner):
|
||||
model_ref = "EleutherAI/pythia-1.4b"
|
||||
tensorized_path = f"s3://tensorized/{model_ref}/fp16/model.tensors"
|
||||
|
||||
loaded_hf_model = vllm_runner(
|
||||
model_ref,
|
||||
tensorizer_uri=tensorized_path,
|
||||
load_format="tensorizer",
|
||||
num_readers=1,
|
||||
vllm_tensorized=False,
|
||||
s3_endpoint="object.ord1.coreweave.com",
|
||||
)
|
||||
loaded_hf_model = vllm_runner(model_ref,
|
||||
load_format="tensorizer",
|
||||
model_loader_extra_config=TensorizerConfig(
|
||||
tensorizer_uri=tensorized_path,
|
||||
num_readers=1,
|
||||
vllm_tensorized=False,
|
||||
s3_endpoint="object.ord1.coreweave.com",
|
||||
))
|
||||
|
||||
deserialized_outputs = loaded_hf_model.generate(prompts, sampling_params)
|
||||
|
||||
@ -131,11 +138,12 @@ def test_deserialized_encrypted_vllm_model_has_same_outputs(
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
loaded_vllm_model = vllm_runner(model_ref,
|
||||
tensorizer_uri=model_path,
|
||||
load_format="tensorizer",
|
||||
encryption_keyfile=key_path,
|
||||
num_readers=1,
|
||||
vllm_tensorized=True)
|
||||
model_loader_extra_config=TensorizerConfig(
|
||||
tensorizer_uri=model_path,
|
||||
encryption_keyfile=key_path,
|
||||
num_readers=1,
|
||||
vllm_tensorized=True))
|
||||
|
||||
deserialized_outputs = loaded_vllm_model.generate(prompts, sampling_params)
|
||||
|
||||
@ -156,10 +164,11 @@ def test_deserialized_hf_model_has_same_outputs(hf_runner, vllm_runner,
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
loaded_hf_model = vllm_runner(model_ref,
|
||||
tensorizer_uri=model_path,
|
||||
load_format="tensorizer",
|
||||
num_readers=1,
|
||||
vllm_tensorized=False)
|
||||
model_loader_extra_config=TensorizerConfig(
|
||||
tensorizer_uri=model_path,
|
||||
num_readers=1,
|
||||
vllm_tensorized=False))
|
||||
|
||||
deserialized_outputs = loaded_hf_model.generate_greedy(
|
||||
prompts, max_tokens=max_tokens)
|
||||
@ -190,10 +199,12 @@ def test_vllm_model_can_load_with_lora(vllm_runner, tmp_path):
|
||||
torch.cuda.empty_cache()
|
||||
loaded_vllm_model = vllm_runner(
|
||||
model_ref,
|
||||
tensorizer_uri=model_path,
|
||||
load_format="tensorizer",
|
||||
num_readers=1,
|
||||
vllm_tensorized=True,
|
||||
model_loader_extra_config=TensorizerConfig(
|
||||
tensorizer_uri=model_path,
|
||||
num_readers=1,
|
||||
vllm_tensorized=True,
|
||||
),
|
||||
enable_lora=True,
|
||||
max_loras=1,
|
||||
max_lora_rank=8,
|
||||
@ -208,16 +219,18 @@ def test_vllm_model_can_load_with_lora(vllm_runner, tmp_path):
|
||||
|
||||
def test_load_without_tensorizer_load_format(vllm_runner):
|
||||
with pytest.raises(ValueError):
|
||||
vllm_runner(model_ref, tensorizer_uri="test")
|
||||
vllm_runner(model_ref,
|
||||
model_loader_extra_config=TensorizerConfig(
|
||||
tensorizer_uri="test", vllm_tensorized=False))
|
||||
|
||||
|
||||
@pytest.mark.skipif(not is_curl_installed(), reason="cURL is not installed")
|
||||
def test_tensorize_vllm_model(tmp_path):
|
||||
# Test serialize command
|
||||
serialize_args = [
|
||||
"python3", "tensorizer/tensorize_vllm_model_for_testing.py", "--model",
|
||||
model_ref, "--dtype", "float16", "serialize", "--serialized-directory",
|
||||
tmp_path, "--suffix", "tests"
|
||||
"python3", tensorize_model_for_testing_script, "--model", model_ref,
|
||||
"--dtype", "float16", "serialize", "--serialized-directory", tmp_path,
|
||||
"--suffix", "tests"
|
||||
]
|
||||
result = subprocess.run(serialize_args, capture_output=True, text=True)
|
||||
print(result.stdout) # Print the output of the serialize command
|
||||
@ -229,8 +242,8 @@ def test_tensorize_vllm_model(tmp_path):
|
||||
|
||||
# Test deserialize command
|
||||
deserialize_args = [
|
||||
"python3", "tensorizer/tensorize_vllm_model_for_testing.py", "--model",
|
||||
model_ref, "--dtype", "float16", "deserialize", "--path-to-tensors",
|
||||
"python3", tensorize_model_for_testing_script, "--model", model_ref,
|
||||
"--dtype", "float16", "deserialize", "--path-to-tensors",
|
||||
path_to_tensors
|
||||
]
|
||||
result = subprocess.run(deserialize_args, capture_output=True, text=True)
|
||||
@ -242,9 +255,9 @@ def test_tensorize_vllm_model(tmp_path):
|
||||
def test_openai_apiserver_with_tensorizer(tmp_path):
|
||||
## Serialize model
|
||||
serialize_args = [
|
||||
"python3", "tensorizer/tensorize_vllm_model_for_testing.py", "--model",
|
||||
model_ref, "--dtype", "float16", "serialize", "--serialized-directory",
|
||||
tmp_path, "--suffix", "tests"
|
||||
"python3", tensorize_model_for_testing_script, "--model", model_ref,
|
||||
"--dtype", "float16", "serialize", "--serialized-directory", tmp_path,
|
||||
"--suffix", "tests"
|
||||
]
|
||||
result = subprocess.run(serialize_args, capture_output=True, text=True)
|
||||
print(result.stdout) # Print the output of the serialize command
|
||||
@ -253,25 +266,47 @@ def test_openai_apiserver_with_tensorizer(tmp_path):
|
||||
f"\n{result.stdout}\n{result.stderr}")
|
||||
|
||||
path_to_tensors = f"{tmp_path}/vllm/{model_ref}/tests/model.tensors"
|
||||
model_loader_extra_config = {
|
||||
"tensorizer_uri": path_to_tensors,
|
||||
"vllm_tensorized": True
|
||||
}
|
||||
|
||||
## Start OpenAI API server
|
||||
openai_args = [
|
||||
"--model", model_ref, "--dtype", "float16", "--load-format",
|
||||
"tensorizer", "--tensorizer-uri", path_to_tensors, "--vllm-tensorized",
|
||||
"--port", "8000"
|
||||
"tensorizer", "--model-loader-extra-config",
|
||||
json.dumps(model_loader_extra_config), "--port", "8000"
|
||||
]
|
||||
|
||||
server = ServerRunner.remote(openai_args)
|
||||
|
||||
assert ray.get(server.ready.remote())
|
||||
print("Server ready.")
|
||||
assert server.ready.remote()
|
||||
|
||||
client = openai.OpenAI(
|
||||
base_url="http://localhost:8000/v1",
|
||||
api_key="token-abc123",
|
||||
)
|
||||
completion = client.completions.create(model=model_ref,
|
||||
prompt="Hello, my name is",
|
||||
max_tokens=5,
|
||||
temperature=0.0)
|
||||
|
||||
assert completion.id is not None
|
||||
assert completion.choices is not None and len(completion.choices) == 1
|
||||
assert completion.choices[0].text is not None and len(
|
||||
completion.choices[0].text) >= 5
|
||||
assert completion.choices[0].finish_reason == "length"
|
||||
assert completion.usage == openai.types.CompletionUsage(
|
||||
completion_tokens=5, prompt_tokens=6, total_tokens=11)
|
||||
|
||||
|
||||
def test_raise_value_error_on_invalid_load_format(vllm_runner):
|
||||
with pytest.raises(ValueError):
|
||||
vllm_runner(model_ref,
|
||||
load_format="safetensors",
|
||||
tensorizer_uri="test")
|
||||
model_loader_extra_config=TensorizerConfig(
|
||||
tensorizer_uri="test", vllm_tensorized=False))
|
||||
|
||||
|
||||
def test_tensorizer_with_tp(vllm_runner):
|
||||
@ -281,22 +316,12 @@ def test_tensorizer_with_tp(vllm_runner):
|
||||
|
||||
vllm_runner(
|
||||
model_ref,
|
||||
tensorizer_uri=tensorized_path,
|
||||
load_format="tensorizer",
|
||||
num_readers=1,
|
||||
vllm_tensorized=False,
|
||||
s3_endpoint="object.ord1.coreweave.com",
|
||||
model_loader_extra_config=TensorizerConfig(
|
||||
tensorizer_uri=tensorized_path,
|
||||
num_readers=1,
|
||||
vllm_tensorized=False,
|
||||
s3_endpoint="object.ord1.coreweave.com",
|
||||
),
|
||||
tensor_parallel_size=2,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.skipif(not is_curl_installed(), reason="cURL is not installed")
|
||||
def test_tensorizer_warn_quant(tmp_path):
|
||||
model_ref = "LnL-AI/TinyLlama-1.1B-Chat-v1.0-GPTQ-4bit"
|
||||
serialize_args = [
|
||||
"python3", "tensorizer/tensorize_vllm_model_for_testing.py", "--model",
|
||||
model_ref, "--quantization", "gptq", "--tensorizer-uri", "test",
|
||||
"serialize", "--serialized-directory", tmp_path, "--suffix", "tests"
|
||||
]
|
||||
result = subprocess.run(serialize_args, capture_output=True, text=True)
|
||||
assert 'PerformanceWarning' in result.stderr
|
@ -11,8 +11,6 @@ def test_get_sliding_window():
|
||||
"Qwen/Qwen1.5-7B",
|
||||
tokenizer_mode="auto",
|
||||
trust_remote_code=False,
|
||||
download_dir=None,
|
||||
load_format="dummy",
|
||||
seed=0,
|
||||
dtype="float16",
|
||||
revision=None,
|
||||
@ -30,8 +28,6 @@ def test_get_sliding_window():
|
||||
"mistralai/Mistral-7B-v0.1",
|
||||
tokenizer_mode="auto",
|
||||
trust_remote_code=False,
|
||||
download_dir=None,
|
||||
load_format="dummy",
|
||||
seed=0,
|
||||
dtype="float16",
|
||||
revision=None,
|
||||
|
@ -37,7 +37,12 @@ def _prepare_test(
|
||||
1e-2,
|
||||
dtype=input_tensor.dtype)
|
||||
logits_processor = MockLogitsProcessor(32000, 0.5, fake_logits)
|
||||
model_runner = ModelRunner(None, None, None, None, None)
|
||||
model_runner = ModelRunner(model_config=None,
|
||||
parallel_config=None,
|
||||
scheduler_config=None,
|
||||
device_config=None,
|
||||
load_config=None,
|
||||
lora_config=None)
|
||||
return input_tensor, fake_logits, logits_processor, model_runner
|
||||
|
||||
|
||||
|
@ -12,7 +12,12 @@ def test_prepare_prompt(batch_size):
|
||||
100000,
|
||||
100000,
|
||||
enable_chunked_prefill=False)
|
||||
model_runner = ModelRunner(None, None, scheduler_config, None, None)
|
||||
model_runner = ModelRunner(model_config=None,
|
||||
parallel_config=None,
|
||||
scheduler_config=scheduler_config,
|
||||
device_config=None,
|
||||
load_config=None,
|
||||
lora_config=None)
|
||||
model_runner.set_block_size(16)
|
||||
|
||||
prompt_lens = []
|
||||
@ -118,8 +123,6 @@ def test_prepare_decode_cuda_graph(batch_size):
|
||||
"facebook/opt-125m",
|
||||
tokenizer_mode="auto",
|
||||
trust_remote_code=False,
|
||||
download_dir=None,
|
||||
load_format="dummy",
|
||||
seed=0,
|
||||
dtype="float16",
|
||||
revision=None,
|
||||
@ -129,8 +132,12 @@ def test_prepare_decode_cuda_graph(batch_size):
|
||||
100000,
|
||||
100000,
|
||||
enable_chunked_prefill=False)
|
||||
model_runner = ModelRunner(model_config, None, scheduler_config, None,
|
||||
None)
|
||||
model_runner = ModelRunner(model_config=model_config,
|
||||
parallel_config=None,
|
||||
scheduler_config=scheduler_config,
|
||||
device_config=None,
|
||||
load_config=None,
|
||||
lora_config=None)
|
||||
model_runner.set_block_size(16)
|
||||
|
||||
prompt_lens = []
|
||||
@ -205,14 +212,17 @@ def test_empty_seq_group():
|
||||
"facebook/opt-125m",
|
||||
tokenizer_mode="auto",
|
||||
trust_remote_code=False,
|
||||
download_dir=None,
|
||||
load_format="dummy",
|
||||
seed=0,
|
||||
dtype="float16",
|
||||
revision=None,
|
||||
enforce_eager=False,
|
||||
)
|
||||
model_runner = ModelRunner(model_config, None, None, None, None)
|
||||
model_runner = ModelRunner(model_config=model_config,
|
||||
parallel_config=None,
|
||||
scheduler_config=None,
|
||||
device_config=None,
|
||||
load_config=None,
|
||||
lora_config=None)
|
||||
model_runner.set_block_size(16)
|
||||
seq_group_metadata_list = []
|
||||
input_tokens, input_positions, attn_metadata, _, _, _, slot_mapping = (
|
||||
@ -251,8 +261,6 @@ def test_hybrid_batches(batch_size, enforce_eager, monkeypatch):
|
||||
"facebook/opt-125m",
|
||||
tokenizer_mode="auto",
|
||||
trust_remote_code=False,
|
||||
download_dir=None,
|
||||
load_format="dummy",
|
||||
seed=0,
|
||||
dtype="float16",
|
||||
revision=None,
|
||||
@ -262,11 +270,12 @@ def test_hybrid_batches(batch_size, enforce_eager, monkeypatch):
|
||||
100000,
|
||||
100000,
|
||||
enable_chunked_prefill=True)
|
||||
model_runner = ModelRunner(model_config,
|
||||
None,
|
||||
scheduler_config,
|
||||
None,
|
||||
None,
|
||||
model_runner = ModelRunner(model_config=model_config,
|
||||
parallel_config=None,
|
||||
scheduler_config=scheduler_config,
|
||||
device_config=None,
|
||||
load_config=None,
|
||||
lora_config=None,
|
||||
is_driver_worker=True)
|
||||
model_runner.set_block_size(16)
|
||||
|
||||
|
@ -23,6 +23,7 @@ def test_swap() -> None:
|
||||
scheduler_config=engine_config.scheduler_config,
|
||||
device_config=engine_config.device_config,
|
||||
cache_config=engine_config.cache_config,
|
||||
load_config=engine_config.load_config,
|
||||
local_rank=0,
|
||||
rank=0,
|
||||
distributed_init_method=distributed_init_method,
|
||||
|
201
vllm/config.py
201
vllm/config.py
@ -1,9 +1,7 @@
|
||||
import enum
|
||||
import io
|
||||
import json
|
||||
import os
|
||||
import typing
|
||||
from dataclasses import dataclass, fields
|
||||
from dataclasses import dataclass, field, fields
|
||||
from typing import TYPE_CHECKING, ClassVar, List, Optional, Union
|
||||
|
||||
import torch
|
||||
@ -18,10 +16,14 @@ from vllm.utils import (get_cpu_memory, get_nvcc_cuda_version, is_cpu, is_hip,
|
||||
if TYPE_CHECKING:
|
||||
from ray.util.placement_group import PlacementGroup
|
||||
|
||||
from vllm.model_executor.tensorizer_loader import TensorizerArgs
|
||||
from vllm.model_executor.model_loader.loader import BaseModelLoader
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
# If true, will load models from ModelScope instead of Hugging Face Hub.
|
||||
VLLM_USE_MODELSCOPE = os.environ.get("VLLM_USE_MODELSCOPE",
|
||||
"False").lower() == "true"
|
||||
|
||||
_GB = 1 << 30
|
||||
|
||||
|
||||
@ -35,18 +37,6 @@ class ModelConfig:
|
||||
available, and "slow" will always use the slow tokenizer.
|
||||
trust_remote_code: Trust remote code (e.g., from HuggingFace) when
|
||||
downloading the model and tokenizer.
|
||||
download_dir: Directory to download and load the weights, default to the
|
||||
default cache directory of huggingface.
|
||||
load_format: The format of the model weights to load:
|
||||
"auto" will try to load the weights in the safetensors format and
|
||||
fall back to the pytorch bin format if safetensors format is
|
||||
not available.
|
||||
"pt" will load the weights in the pytorch bin format.
|
||||
"safetensors" will load the weights in the safetensors format.
|
||||
"npcache" will load the weights in pytorch format and store
|
||||
a numpy cache to speed up the loading.
|
||||
"dummy" will initialize the weights with random values, which is
|
||||
mainly for profiling.
|
||||
dtype: Data type for model weights and activations. The "auto" option
|
||||
will use FP16 precision for FP32 and FP16 models, and BF16 precision
|
||||
for BF16 models.
|
||||
@ -83,8 +73,6 @@ class ModelConfig:
|
||||
tokenizer: str,
|
||||
tokenizer_mode: str,
|
||||
trust_remote_code: bool,
|
||||
download_dir: Optional[str],
|
||||
load_format: str,
|
||||
dtype: Union[str, torch.dtype],
|
||||
seed: int,
|
||||
revision: Optional[str] = None,
|
||||
@ -101,8 +89,6 @@ class ModelConfig:
|
||||
self.tokenizer = tokenizer
|
||||
self.tokenizer_mode = tokenizer_mode
|
||||
self.trust_remote_code = trust_remote_code
|
||||
self.download_dir = download_dir
|
||||
self.load_format = load_format
|
||||
self.seed = seed
|
||||
self.revision = revision
|
||||
self.code_revision = code_revision
|
||||
@ -113,64 +99,16 @@ class ModelConfig:
|
||||
self.max_context_len_to_capture = max_context_len_to_capture
|
||||
self.max_logprobs = max_logprobs
|
||||
|
||||
if os.environ.get("VLLM_USE_MODELSCOPE", "False").lower() == "true":
|
||||
# download model from ModelScope hub,
|
||||
# lazy import so that modelscope is not required for normal use.
|
||||
# pylint: disable=C.
|
||||
from modelscope.hub.snapshot_download import snapshot_download
|
||||
|
||||
if not os.path.exists(model):
|
||||
model_path = snapshot_download(model_id=model,
|
||||
cache_dir=download_dir,
|
||||
revision=revision)
|
||||
else:
|
||||
model_path = model
|
||||
self.model = model_path
|
||||
self.download_dir = model_path
|
||||
self.tokenizer = model_path
|
||||
|
||||
self.hf_config = get_config(self.model, trust_remote_code, revision,
|
||||
code_revision)
|
||||
self.hf_text_config = get_hf_text_config(self.hf_config)
|
||||
self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype)
|
||||
self.max_model_len = _get_and_verify_max_len(self.hf_text_config,
|
||||
max_model_len)
|
||||
self._verify_load_format()
|
||||
self._verify_tokenizer_mode()
|
||||
self._verify_quantization()
|
||||
self._verify_cuda_graph()
|
||||
|
||||
def _verify_load_format(self) -> None:
|
||||
load_format = self.load_format.lower()
|
||||
supported_load_format = [
|
||||
"auto", "pt", "safetensors", "npcache", "dummy", "tensorizer"
|
||||
]
|
||||
rocm_not_supported_load_format: List[str] = []
|
||||
if load_format not in supported_load_format:
|
||||
raise ValueError(
|
||||
f"Unknown load format: {self.load_format}. Must be one of "
|
||||
"'auto', 'pt', 'safetensors', 'npcache', 'tensorizer', or "
|
||||
"'dummy'.")
|
||||
if is_hip() and load_format in rocm_not_supported_load_format:
|
||||
rocm_supported_load_format = [
|
||||
f for f in supported_load_format
|
||||
if (f not in rocm_not_supported_load_format)
|
||||
]
|
||||
raise ValueError(
|
||||
f"load format '{load_format}' is not supported in ROCm. "
|
||||
f"Supported load format are "
|
||||
f"{rocm_supported_load_format}")
|
||||
|
||||
# TODO: Remove this check once HF updates the pt weights of Mixtral.
|
||||
architectures = getattr(self.hf_config, "architectures", [])
|
||||
# architectures can be None instead of []
|
||||
if architectures and "MixtralForCausalLM" in architectures \
|
||||
and load_format == "pt":
|
||||
raise ValueError(
|
||||
"Currently, the 'pt' format is not supported for Mixtral. "
|
||||
"Please use the 'safetensors' format instead. ")
|
||||
self.load_format = load_format
|
||||
|
||||
def _verify_tokenizer_mode(self) -> None:
|
||||
tokenizer_mode = self.tokenizer_mode.lower()
|
||||
if tokenizer_mode not in ["auto", "slow"]:
|
||||
@ -471,6 +409,65 @@ class TokenizerPoolConfig:
|
||||
return tokenizer_pool_config
|
||||
|
||||
|
||||
class LoadFormat(str, enum.Enum):
|
||||
AUTO = "auto"
|
||||
PT = "pt"
|
||||
SAFETENSORS = "safetensors"
|
||||
NPCACHE = "npcache"
|
||||
DUMMY = "dummy"
|
||||
TENSORIZER = "tensorizer"
|
||||
|
||||
|
||||
@dataclass
|
||||
class LoadConfig:
|
||||
"""
|
||||
download_dir: Directory to download and load the weights, default to the
|
||||
default cache directory of huggingface.
|
||||
load_format: The format of the model weights to load:
|
||||
"auto" will try to load the weights in the safetensors format and
|
||||
fall back to the pytorch bin format if safetensors format is
|
||||
not available.
|
||||
"pt" will load the weights in the pytorch bin format.
|
||||
"safetensors" will load the weights in the safetensors format.
|
||||
"npcache" will load the weights in pytorch format and store
|
||||
a numpy cache to speed up the loading.
|
||||
"dummy" will initialize the weights with random values, which is
|
||||
mainly for profiling.
|
||||
"tensorizer" will use CoreWeave's tensorizer library for
|
||||
fast weight loading.
|
||||
"""
|
||||
|
||||
load_format: Union[str, LoadFormat, "BaseModelLoader"] = LoadFormat.AUTO
|
||||
download_dir: Optional[str] = None
|
||||
model_loader_extra_config: Optional[Union[str, dict]] = field(
|
||||
default_factory=dict)
|
||||
|
||||
def __post_init__(self):
|
||||
model_loader_extra_config = self.model_loader_extra_config or {}
|
||||
if isinstance(model_loader_extra_config, str):
|
||||
self.model_loader_extra_config = json.loads(
|
||||
model_loader_extra_config)
|
||||
self._verify_load_format()
|
||||
|
||||
def _verify_load_format(self) -> None:
|
||||
if not isinstance(self.load_format, str):
|
||||
return
|
||||
|
||||
load_format = self.load_format.lower()
|
||||
self.load_format = LoadFormat(load_format)
|
||||
|
||||
rocm_not_supported_load_format: List[str] = []
|
||||
if is_hip() and load_format in rocm_not_supported_load_format:
|
||||
rocm_supported_load_format = [
|
||||
f for f in LoadFormat.__members__
|
||||
if (f not in rocm_not_supported_load_format)
|
||||
]
|
||||
raise ValueError(
|
||||
f"load format '{load_format}' is not supported in ROCm. "
|
||||
f"Supported load formats are "
|
||||
f"{rocm_supported_load_format}")
|
||||
|
||||
|
||||
class ParallelConfig:
|
||||
"""Configuration for the distributed execution.
|
||||
|
||||
@ -699,8 +696,6 @@ class SpeculativeConfig:
|
||||
tokenizer=target_model_config.tokenizer,
|
||||
tokenizer_mode=target_model_config.tokenizer_mode,
|
||||
trust_remote_code=target_model_config.trust_remote_code,
|
||||
download_dir=target_model_config.download_dir,
|
||||
load_format=target_model_config.load_format,
|
||||
dtype=target_model_config.dtype,
|
||||
seed=target_model_config.seed,
|
||||
revision=draft_revision,
|
||||
@ -887,65 +882,6 @@ class VisionLanguageConfig:
|
||||
f"{[x.name for x in cls.ImageInputType]}.") from e
|
||||
|
||||
|
||||
@dataclass
|
||||
class TensorizerConfig:
|
||||
tensorizer_uri: Union[io.BufferedIOBase, io.RawIOBase, typing.BinaryIO,
|
||||
str, bytes, os.PathLike, int]
|
||||
vllm_tensorized: bool
|
||||
verify_hash: Optional[bool] = False
|
||||
num_readers: Optional[int] = 1
|
||||
encryption_keyfile: Optional[str] = None
|
||||
s3_access_key_id: Optional[str] = None
|
||||
s3_secret_access_key: Optional[str] = None
|
||||
s3_endpoint: Optional[str] = None
|
||||
model_class: Optional[torch.nn.Module] = None
|
||||
hf_config: Optional[PretrainedConfig] = None
|
||||
dtype: Union[str, torch.dtype] = None
|
||||
|
||||
def _construct_tensorizer_args(self) -> "TensorizerArgs":
|
||||
from vllm.model_executor.tensorizer_loader import TensorizerArgs
|
||||
tensorizer_args = {
|
||||
"tensorizer_uri": self.tensorizer_uri,
|
||||
"vllm_tensorized": self.vllm_tensorized,
|
||||
"verify_hash": self.verify_hash,
|
||||
"num_readers": self.num_readers,
|
||||
"encryption_keyfile": self.encryption_keyfile,
|
||||
"s3_access_key_id": self.s3_access_key_id,
|
||||
"s3_secret_access_key": self.s3_secret_access_key,
|
||||
"s3_endpoint": self.s3_endpoint,
|
||||
}
|
||||
return TensorizerArgs(**tensorizer_args)
|
||||
|
||||
def verify_with_parallel_config(
|
||||
self,
|
||||
parallel_config: "ParallelConfig",
|
||||
) -> None:
|
||||
if (parallel_config.tensor_parallel_size > 1
|
||||
and self.tensorizer_uri is not None):
|
||||
raise ValueError(
|
||||
"Loading to multiple GPUs is not currently supported with "
|
||||
"vLLM-serialized models. Please set tensor_parallel_size=1."
|
||||
" or use a non-vLLM-serialized model, such as a "
|
||||
"serialized Hugging Face `PretrainedModel`.")
|
||||
|
||||
def verify_with_model_config(self, model_config) -> None:
|
||||
if (model_config.quantization is not None
|
||||
and self.tensorizer_uri is not None):
|
||||
from vllm.model_executor.tensorizer_loader import (
|
||||
tensorizer_warning)
|
||||
tensorizer_warning(
|
||||
"Loading a model using Tensorizer with quantization on vLLM"
|
||||
" is unstable and may lead to errors.")
|
||||
|
||||
if (model_config.load_format != "tensorizer"
|
||||
and self.tensorizer_uri is not None):
|
||||
raise ValueError(
|
||||
"A tensorizer uri was passed for tensorizer loading, but the "
|
||||
f"load format was set to {model_config.load_format}. "
|
||||
"Please set the load format to 'tensorizer' to use "
|
||||
f"tensorizer args.")
|
||||
|
||||
|
||||
_STR_DTYPE_TO_TORCH_DTYPE = {
|
||||
"half": torch.float16,
|
||||
"float16": torch.float16,
|
||||
@ -1105,11 +1041,11 @@ class EngineConfig:
|
||||
parallel_config: ParallelConfig
|
||||
scheduler_config: SchedulerConfig
|
||||
device_config: DeviceConfig
|
||||
load_config: LoadConfig
|
||||
lora_config: Optional[LoRAConfig]
|
||||
vision_language_config: Optional[VisionLanguageConfig]
|
||||
speculative_config: Optional[SpeculativeConfig]
|
||||
decoding_config: Optional[DecodingConfig]
|
||||
tensorizer_config: Optional[TensorizerConfig]
|
||||
|
||||
def __post_init__(self):
|
||||
"""Verify configs are valid & consistent with each other.
|
||||
@ -1117,11 +1053,6 @@ class EngineConfig:
|
||||
self.model_config.verify_with_parallel_config(self.parallel_config)
|
||||
self.cache_config.verify_with_parallel_config(self.parallel_config)
|
||||
|
||||
if self.tensorizer_config:
|
||||
self.tensorizer_config.verify_with_parallel_config(
|
||||
self.parallel_config)
|
||||
self.tensorizer_config.verify_with_model_config(self.model_config)
|
||||
|
||||
if self.lora_config:
|
||||
self.lora_config.verify_with_model_config(self.model_config)
|
||||
self.lora_config.verify_with_scheduler_config(
|
||||
|
@ -1,15 +1,12 @@
|
||||
import argparse
|
||||
import dataclasses
|
||||
import io
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from typing import BinaryIO, Optional, Union
|
||||
from typing import Optional
|
||||
|
||||
from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig,
|
||||
EngineConfig, LoRAConfig, ModelConfig, ParallelConfig,
|
||||
SchedulerConfig, SpeculativeConfig, TensorizerConfig,
|
||||
EngineConfig, LoadConfig, LoRAConfig, ModelConfig,
|
||||
ParallelConfig, SchedulerConfig, SpeculativeConfig,
|
||||
TokenizerPoolConfig, VisionLanguageConfig)
|
||||
from vllm.model_executor.tensorizer_loader import TensorizerArgs
|
||||
from vllm.utils import str_to_int_tuple
|
||||
|
||||
|
||||
@ -60,17 +57,7 @@ class EngineArgs:
|
||||
ray_workers_use_nsight: bool = False
|
||||
num_gpu_blocks_override: Optional[int] = None
|
||||
num_lookahead_slots: int = 0
|
||||
|
||||
# Tensorizer configuration parameters
|
||||
tensorizer_uri: Union[io.BufferedIOBase, io.RawIOBase, BinaryIO, str,
|
||||
bytes, os.PathLike, int] = None
|
||||
vllm_tensorized: bool = False
|
||||
verify_hash: Optional[bool] = False
|
||||
num_readers: Optional[int] = 1
|
||||
encryption_keyfile: Optional[str] = None
|
||||
s3_access_key_id: Optional[str] = None
|
||||
s3_secret_access_key: Optional[str] = None
|
||||
s3_endpoint: Optional[str] = None
|
||||
model_loader_extra_config: Optional[dict] = None
|
||||
|
||||
# Related to Vision-language models such as llava
|
||||
image_input_type: Optional[str] = None
|
||||
@ -429,7 +416,16 @@ class EngineArgs:
|
||||
default=None,
|
||||
help='The number of speculative tokens to sample from '
|
||||
'the draft model in speculative decoding')
|
||||
parser = TensorizerArgs.add_cli_args(parser)
|
||||
|
||||
parser.add_argument('--model-loader-extra-config',
|
||||
type=str,
|
||||
default=EngineArgs.model_loader_extra_config,
|
||||
help='Extra config for model loader. '
|
||||
'This will be passed to the model loader '
|
||||
'corresponding to the chosen load_format. '
|
||||
'This should be a JSON string that will be '
|
||||
'parsed into a dictionary.')
|
||||
|
||||
return parser
|
||||
|
||||
@classmethod
|
||||
@ -444,11 +440,11 @@ class EngineArgs:
|
||||
device_config = DeviceConfig(self.device)
|
||||
model_config = ModelConfig(
|
||||
self.model, self.tokenizer, self.tokenizer_mode,
|
||||
self.trust_remote_code, self.download_dir, self.load_format,
|
||||
self.dtype, self.seed, self.revision, self.code_revision,
|
||||
self.tokenizer_revision, self.max_model_len, self.quantization,
|
||||
self.quantization_param_path, self.enforce_eager,
|
||||
self.max_context_len_to_capture, self.max_logprobs)
|
||||
self.trust_remote_code, self.dtype, self.seed, self.revision,
|
||||
self.code_revision, self.tokenizer_revision, self.max_model_len,
|
||||
self.quantization, self.quantization_param_path,
|
||||
self.enforce_eager, self.max_context_len_to_capture,
|
||||
self.max_logprobs)
|
||||
cache_config = CacheConfig(self.block_size,
|
||||
self.gpu_memory_utilization,
|
||||
self.swap_space, self.kv_cache_dtype,
|
||||
@ -492,15 +488,10 @@ class EngineArgs:
|
||||
max_cpu_loras=self.max_cpu_loras if self.max_cpu_loras
|
||||
and self.max_cpu_loras > 0 else None) if self.enable_lora else None
|
||||
|
||||
tensorizer_config = TensorizerConfig(
|
||||
tensorizer_uri=self.tensorizer_uri,
|
||||
vllm_tensorized=self.vllm_tensorized,
|
||||
verify_hash=self.verify_hash,
|
||||
num_readers=self.num_readers,
|
||||
encryption_keyfile=self.encryption_keyfile,
|
||||
s3_access_key_id=self.s3_access_key_id,
|
||||
s3_secret_access_key=self.s3_secret_access_key,
|
||||
s3_endpoint=self.s3_endpoint,
|
||||
load_config = LoadConfig(
|
||||
load_format=self.load_format,
|
||||
download_dir=self.download_dir,
|
||||
model_loader_extra_config=self.model_loader_extra_config,
|
||||
)
|
||||
|
||||
if self.image_input_type:
|
||||
@ -530,8 +521,8 @@ class EngineArgs:
|
||||
lora_config=lora_config,
|
||||
vision_language_config=vision_language_config,
|
||||
speculative_config=speculative_config,
|
||||
decoding_config=decoding_config,
|
||||
tensorizer_config=tensorizer_config)
|
||||
load_config=load_config,
|
||||
decoding_config=decoding_config)
|
||||
|
||||
|
||||
@dataclass
|
||||
|
@ -4,9 +4,9 @@ from typing import Iterable, List, Optional, Tuple, Type, Union
|
||||
from transformers import PreTrainedTokenizer
|
||||
|
||||
import vllm
|
||||
from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig, LoRAConfig,
|
||||
ModelConfig, ParallelConfig, SchedulerConfig,
|
||||
SpeculativeConfig, TensorizerConfig,
|
||||
from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig, LoadConfig,
|
||||
LoRAConfig, ModelConfig, ParallelConfig,
|
||||
SchedulerConfig, SpeculativeConfig,
|
||||
VisionLanguageConfig)
|
||||
from vllm.core.scheduler import Scheduler, SchedulerOutputs
|
||||
from vllm.engine.arg_utils import EngineArgs
|
||||
@ -72,11 +72,11 @@ class LLMEngine:
|
||||
parallel_config: ParallelConfig,
|
||||
scheduler_config: SchedulerConfig,
|
||||
device_config: DeviceConfig,
|
||||
load_config: LoadConfig,
|
||||
lora_config: Optional[LoRAConfig],
|
||||
vision_language_config: Optional[VisionLanguageConfig],
|
||||
speculative_config: Optional[SpeculativeConfig],
|
||||
decoding_config: Optional[DecodingConfig],
|
||||
tensorizer_config: Optional[TensorizerConfig],
|
||||
executor_class: Type[ExecutorBase],
|
||||
log_stats: bool,
|
||||
usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
|
||||
@ -92,8 +92,8 @@ class LLMEngine:
|
||||
f"trust_remote_code={model_config.trust_remote_code}, "
|
||||
f"dtype={model_config.dtype}, "
|
||||
f"max_seq_len={model_config.max_model_len}, "
|
||||
f"download_dir={model_config.download_dir!r}, "
|
||||
f"load_format={model_config.load_format}, "
|
||||
f"download_dir={load_config.download_dir!r}, "
|
||||
f"load_format={load_config.load_format}, "
|
||||
f"tensor_parallel_size={parallel_config.tensor_parallel_size}, "
|
||||
f"disable_custom_all_reduce="
|
||||
f"{parallel_config.disable_custom_all_reduce}, "
|
||||
@ -114,8 +114,8 @@ class LLMEngine:
|
||||
self.scheduler_config = scheduler_config
|
||||
self.device_config = device_config
|
||||
self.speculative_config = speculative_config
|
||||
self.load_config = load_config
|
||||
self.decoding_config = decoding_config or DecodingConfig()
|
||||
self.tensorizer_config = tensorizer_config
|
||||
self.log_stats = log_stats
|
||||
|
||||
self._init_tokenizer()
|
||||
@ -131,7 +131,7 @@ class LLMEngine:
|
||||
lora_config=lora_config,
|
||||
vision_language_config=vision_language_config,
|
||||
speculative_config=speculative_config,
|
||||
tensorizer_config=tensorizer_config,
|
||||
load_config=load_config,
|
||||
)
|
||||
|
||||
self._initialize_kv_caches()
|
||||
@ -271,9 +271,6 @@ class LLMEngine:
|
||||
def _verify_args(self) -> None:
|
||||
self.model_config.verify_with_parallel_config(self.parallel_config)
|
||||
self.cache_config.verify_with_parallel_config(self.parallel_config)
|
||||
if self.tensorizer_config:
|
||||
self.tensorizer_config.verify_with_parallel_config(
|
||||
self.parallel_config)
|
||||
if self.lora_config:
|
||||
self.lora_config.verify_with_model_config(self.model_config)
|
||||
self.lora_config.verify_with_scheduler_config(
|
||||
|
@ -40,6 +40,7 @@ class CPUExecutor(ExecutorBase):
|
||||
scheduler_config=self.scheduler_config,
|
||||
device_config=self.device_config,
|
||||
cache_config=self.cache_config,
|
||||
load_config=self.load_config,
|
||||
local_rank=0,
|
||||
rank=0,
|
||||
distributed_init_method=distributed_init_method,
|
||||
|
@ -1,9 +1,9 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Dict, List, Optional, Set, Tuple
|
||||
|
||||
from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig,
|
||||
ParallelConfig, SchedulerConfig, SpeculativeConfig,
|
||||
TensorizerConfig, VisionLanguageConfig)
|
||||
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
|
||||
ModelConfig, ParallelConfig, SchedulerConfig,
|
||||
SpeculativeConfig, VisionLanguageConfig)
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.sequence import SamplerOutput, SequenceGroupMetadata
|
||||
|
||||
@ -23,20 +23,20 @@ class ExecutorBase(ABC):
|
||||
parallel_config: ParallelConfig,
|
||||
scheduler_config: SchedulerConfig,
|
||||
device_config: DeviceConfig,
|
||||
load_config: LoadConfig,
|
||||
lora_config: Optional[LoRAConfig],
|
||||
vision_language_config: Optional[VisionLanguageConfig],
|
||||
speculative_config: Optional[SpeculativeConfig],
|
||||
tensorizer_config: Optional[TensorizerConfig],
|
||||
) -> None:
|
||||
self.model_config = model_config
|
||||
self.cache_config = cache_config
|
||||
self.lora_config = lora_config
|
||||
self.load_config = load_config
|
||||
self.parallel_config = parallel_config
|
||||
self.scheduler_config = scheduler_config
|
||||
self.device_config = device_config
|
||||
self.vision_language_config = vision_language_config
|
||||
self.speculative_config = speculative_config
|
||||
self.tensorizer_config = tensorizer_config
|
||||
|
||||
self._init_executor()
|
||||
|
||||
|
@ -35,12 +35,12 @@ class GPUExecutor(ExecutorBase):
|
||||
scheduler_config=self.scheduler_config,
|
||||
device_config=self.device_config,
|
||||
cache_config=self.cache_config,
|
||||
load_config=self.load_config,
|
||||
local_rank=0,
|
||||
rank=0,
|
||||
distributed_init_method=distributed_init_method,
|
||||
lora_config=self.lora_config,
|
||||
vision_language_config=self.vision_language_config,
|
||||
tensorizer_config=self.tensorizer_config,
|
||||
is_driver_worker=True,
|
||||
)
|
||||
self.driver_worker.init_device()
|
||||
|
@ -147,6 +147,7 @@ class RayGPUExecutor(ExecutorBase):
|
||||
model_config = copy.deepcopy(self.model_config)
|
||||
parallel_config = copy.deepcopy(self.parallel_config)
|
||||
scheduler_config = copy.deepcopy(self.scheduler_config)
|
||||
load_config = copy.deepcopy(self.load_config)
|
||||
device_config = copy.deepcopy(self.device_config)
|
||||
lora_config = copy.deepcopy(self.lora_config)
|
||||
cache_config = copy.deepcopy(self.cache_config)
|
||||
@ -165,12 +166,12 @@ class RayGPUExecutor(ExecutorBase):
|
||||
scheduler_config=scheduler_config,
|
||||
device_config=device_config,
|
||||
cache_config=cache_config,
|
||||
load_config=load_config,
|
||||
local_rank=local_rank,
|
||||
rank=rank,
|
||||
distributed_init_method=distributed_init_method,
|
||||
lora_config=lora_config,
|
||||
vision_language_config=vision_language_config,
|
||||
tensorizer_config=self.tensorizer_config,
|
||||
))
|
||||
|
||||
# Initialize the driver worker with the Worker class.
|
||||
@ -187,7 +188,7 @@ class RayGPUExecutor(ExecutorBase):
|
||||
distributed_init_method=distributed_init_method,
|
||||
lora_config=self.lora_config,
|
||||
vision_language_config=self.vision_language_config,
|
||||
tensorizer_config=self.tensorizer_config,
|
||||
load_config=self.load_config,
|
||||
is_driver_worker=True,
|
||||
)
|
||||
|
||||
|
@ -1,128 +0,0 @@
|
||||
"""Utilities for selecting and loading models."""
|
||||
import contextlib
|
||||
from typing import Tuple, Type
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from vllm.config import DeviceConfig, ModelConfig
|
||||
from vllm.model_executor.models import ModelRegistry
|
||||
from vllm.model_executor.models.llava import LlavaForConditionalGeneration
|
||||
from vllm.model_executor.tensorizer_loader import (
|
||||
ParameterizedLoadFormat, is_vllm_serialized_tensorizer,
|
||||
load_with_tensorizer)
|
||||
from vllm.model_executor.weight_utils import (get_quant_config,
|
||||
initialize_dummy_weights)
|
||||
|
||||
_VISION_MODEL_CLASSES = [
|
||||
LlavaForConditionalGeneration,
|
||||
]
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def _set_default_torch_dtype(dtype: torch.dtype):
|
||||
"""Sets the default torch dtype to the given dtype."""
|
||||
old_dtype = torch.get_default_dtype()
|
||||
torch.set_default_dtype(dtype)
|
||||
yield
|
||||
torch.set_default_dtype(old_dtype)
|
||||
|
||||
|
||||
def _get_model_architecture(
|
||||
model_config: ModelConfig) -> Tuple[Type[nn.Module], str]:
|
||||
architectures = getattr(model_config.hf_config, "architectures", [])
|
||||
# Special handling for quantized Mixtral.
|
||||
# FIXME(woosuk): This is a temporary hack.
|
||||
if (model_config.quantization is not None
|
||||
and "MixtralForCausalLM" in architectures):
|
||||
architectures = ["QuantMixtralForCausalLM"]
|
||||
|
||||
for arch in architectures:
|
||||
model_cls = ModelRegistry.load_model_cls(arch)
|
||||
if model_cls is not None:
|
||||
return (model_cls, arch)
|
||||
raise ValueError(
|
||||
f"Model architectures {architectures} are not supported for now. "
|
||||
f"Supported architectures: {ModelRegistry.get_supported_archs()}")
|
||||
|
||||
|
||||
def get_architecture_class_name(model_config: ModelConfig) -> str:
|
||||
return _get_model_architecture(model_config)[1]
|
||||
|
||||
|
||||
def get_model(model_config: ModelConfig, device_config: DeviceConfig,
|
||||
**kwargs) -> nn.Module:
|
||||
lora_config = kwargs.get("lora_config", None)
|
||||
vision_language_config = kwargs.get("vision_language_config", None)
|
||||
tensorizer_config = kwargs.get("tensorizer_config", None)
|
||||
model_class = _get_model_architecture(model_config)[0]
|
||||
|
||||
# Get the (maybe quantized) linear method.
|
||||
linear_method = None
|
||||
if model_config.quantization is not None:
|
||||
quant_config = get_quant_config(model_config)
|
||||
capability = torch.cuda.get_device_capability()
|
||||
capability = capability[0] * 10 + capability[1]
|
||||
if capability < quant_config.get_min_capability():
|
||||
raise ValueError(
|
||||
f"The quantization method {model_config.quantization} is not "
|
||||
"supported for the current GPU. "
|
||||
f"Minimum capability: {quant_config.get_min_capability()}. "
|
||||
f"Current capability: {capability}.")
|
||||
supported_dtypes = quant_config.get_supported_act_dtypes()
|
||||
if model_config.dtype not in supported_dtypes:
|
||||
raise ValueError(
|
||||
f"{model_config.dtype} is not supported for quantization "
|
||||
f"method {model_config.quantization}. Supported dtypes: "
|
||||
f"{supported_dtypes}")
|
||||
|
||||
linear_method = quant_config.get_linear_method()
|
||||
|
||||
with _set_default_torch_dtype(model_config.dtype):
|
||||
# Create a model instance.
|
||||
# The weights will be initialized as empty tensors.
|
||||
extra_kwargs = {}
|
||||
if hasattr(model_class, "supported_lora_modules"):
|
||||
extra_kwargs["lora_config"] = lora_config
|
||||
elif lora_config:
|
||||
raise ValueError(
|
||||
f"Model {model_class.__name__} does not support LoRA, "
|
||||
"but LoRA is enabled. Support for this model may "
|
||||
"be added in the future. If this is important to you, "
|
||||
"please open an issue on github.")
|
||||
elif model_class in _VISION_MODEL_CLASSES:
|
||||
extra_kwargs["vision_language_config"] = vision_language_config
|
||||
|
||||
with torch.device(device_config.device):
|
||||
if (model_config.load_format == "tensorizer"
|
||||
and is_vllm_serialized_tensorizer(tensorizer_config)):
|
||||
extra_kwargs["linear_method"] = linear_method
|
||||
tensorizer_config.model_class = model_class
|
||||
tensorizer_config.hf_config = model_config.hf_config
|
||||
tensorizer_config.dtype = model_config.dtype
|
||||
model = load_with_tensorizer(tensorizer_config, **extra_kwargs)
|
||||
return model.eval()
|
||||
model = model_class(config=model_config.hf_config,
|
||||
linear_method=linear_method,
|
||||
**extra_kwargs)
|
||||
if model_config.load_format == "dummy":
|
||||
# NOTE(woosuk): For accurate performance evaluation, we assign
|
||||
# random values to the weights.
|
||||
initialize_dummy_weights(model)
|
||||
else:
|
||||
# Load the weights from the cached or downloaded files.
|
||||
if model_config.load_format == "tensorizer":
|
||||
# Provide a dynamic load format for `model.load_weights`
|
||||
# to retain tensorizer args from CLI.
|
||||
model_config.load_format = ParameterizedLoadFormat(
|
||||
model_config.load_format)
|
||||
model_config.load_format.params = (
|
||||
tensorizer_config._construct_tensorizer_args())
|
||||
|
||||
model.load_weights(
|
||||
model_config.model,
|
||||
model_config.download_dir,
|
||||
model_config.load_format,
|
||||
model_config.revision,
|
||||
)
|
||||
return model.eval()
|
30
vllm/model_executor/model_loader/__init__.py
Normal file
30
vllm/model_executor/model_loader/__init__.py
Normal file
@ -0,0 +1,30 @@
|
||||
from typing import Optional
|
||||
|
||||
from torch import nn
|
||||
|
||||
from vllm.config import (DeviceConfig, LoadConfig, LoRAConfig, ModelConfig,
|
||||
ParallelConfig, SchedulerConfig, VisionLanguageConfig)
|
||||
from vllm.model_executor.model_loader.loader import (BaseModelLoader,
|
||||
get_model_loader)
|
||||
from vllm.model_executor.model_loader.utils import (
|
||||
get_architecture_class_name, get_model_architecture)
|
||||
|
||||
|
||||
def get_model(
|
||||
*, model_config: ModelConfig, load_config: LoadConfig,
|
||||
device_config: DeviceConfig, parallel_config: ParallelConfig,
|
||||
scheduler_config: SchedulerConfig, lora_config: Optional[LoRAConfig],
|
||||
vision_language_config: Optional[VisionLanguageConfig]) -> nn.Module:
|
||||
loader = get_model_loader(load_config)
|
||||
return loader.load_model(model_config=model_config,
|
||||
device_config=device_config,
|
||||
lora_config=lora_config,
|
||||
vision_language_config=vision_language_config,
|
||||
parallel_config=parallel_config,
|
||||
scheduler_config=scheduler_config)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"get_model", "get_model_loader", "BaseModelLoader",
|
||||
"get_architecture_class_name", "get_model_architecture"
|
||||
]
|
354
vllm/model_executor/model_loader/loader.py
Normal file
354
vllm/model_executor/model_loader/loader.py
Normal file
@ -0,0 +1,354 @@
|
||||
# ruff: noqa: SIM117
|
||||
import copy
|
||||
import glob
|
||||
import os
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import (TYPE_CHECKING, Any, Dict, Generator, List, Optional, Tuple,
|
||||
Type)
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from vllm.config import (VLLM_USE_MODELSCOPE, DeviceConfig, LoadConfig,
|
||||
LoadFormat, LoRAConfig, ModelConfig, ParallelConfig,
|
||||
SchedulerConfig, VisionLanguageConfig)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.model_loader.tensorizer import (
|
||||
TensorizerConfig, is_vllm_serialized_tensorizer, load_with_tensorizer,
|
||||
tensorizer_weights_iterator)
|
||||
from vllm.model_executor.model_loader.utils import (get_model_architecture,
|
||||
set_default_torch_dtype)
|
||||
from vllm.model_executor.model_loader.weight_utils import (
|
||||
download_weights_from_hf, filter_files_not_needed_for_inference,
|
||||
get_quant_config, initialize_dummy_weights, np_cache_weights_iterator,
|
||||
pt_weights_iterator, safetensors_weights_iterator)
|
||||
from vllm.model_executor.models.llava import LlavaForConditionalGeneration
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.model_executor.layers.linear import LinearMethodBase
|
||||
|
||||
_VISION_MODEL_CLASSES = [
|
||||
LlavaForConditionalGeneration,
|
||||
]
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
def _get_linear_method(
|
||||
model_config: ModelConfig,
|
||||
load_config: LoadConfig) -> Optional["LinearMethodBase"]:
|
||||
"""Get the (maybe quantized) linear method."""
|
||||
linear_method = None
|
||||
if model_config.quantization is not None:
|
||||
quant_config = get_quant_config(model_config, load_config)
|
||||
capability = torch.cuda.get_device_capability()
|
||||
capability = capability[0] * 10 + capability[1]
|
||||
if capability < quant_config.get_min_capability():
|
||||
raise ValueError(
|
||||
f"The quantization method {model_config.quantization} is not "
|
||||
"supported for the current GPU. "
|
||||
f"Minimum capability: {quant_config.get_min_capability()}. "
|
||||
f"Current capability: {capability}.")
|
||||
supported_dtypes = quant_config.get_supported_act_dtypes()
|
||||
if model_config.dtype not in supported_dtypes:
|
||||
raise ValueError(
|
||||
f"{model_config.dtype} is not supported for quantization "
|
||||
f"method {model_config.quantization}. Supported dtypes: "
|
||||
f"{supported_dtypes}")
|
||||
|
||||
linear_method = quant_config.get_linear_method()
|
||||
return linear_method
|
||||
|
||||
|
||||
def _get_model_initialization_kwargs(
|
||||
model_class: Type[nn.Module], lora_config: Optional[LoRAConfig],
|
||||
vision_language_config: Optional[VisionLanguageConfig]
|
||||
) -> Dict[str, Any]:
|
||||
"""Get extra kwargs for model initialization."""
|
||||
extra_kwargs = {}
|
||||
if hasattr(model_class, "supported_lora_modules"):
|
||||
extra_kwargs["lora_config"] = lora_config
|
||||
elif lora_config:
|
||||
raise ValueError(
|
||||
f"Model {model_class.__name__} does not support LoRA, "
|
||||
"but LoRA is enabled. Support for this model may "
|
||||
"be added in the future. If this is important to you, "
|
||||
"please open an issue on github.")
|
||||
elif model_class in _VISION_MODEL_CLASSES:
|
||||
extra_kwargs["vision_language_config"] = vision_language_config
|
||||
return extra_kwargs
|
||||
|
||||
|
||||
def _initialize_model(
|
||||
model_config: ModelConfig, load_config: LoadConfig,
|
||||
lora_config: Optional[LoRAConfig],
|
||||
vision_language_config: Optional[VisionLanguageConfig]) -> nn.Module:
|
||||
"""Initialize a model with the given configurations."""
|
||||
model_class = get_model_architecture(model_config)[0]
|
||||
linear_method = _get_linear_method(model_config, load_config)
|
||||
|
||||
return model_class(config=model_config.hf_config,
|
||||
linear_method=linear_method,
|
||||
**_get_model_initialization_kwargs(
|
||||
model_class, lora_config, vision_language_config))
|
||||
|
||||
|
||||
class BaseModelLoader(ABC):
|
||||
"""Base class for model loaders."""
|
||||
|
||||
def __init__(self, load_config: LoadConfig):
|
||||
self.load_config = load_config
|
||||
|
||||
@abstractmethod
|
||||
def load_model(self, *, model_config: ModelConfig,
|
||||
device_config: DeviceConfig,
|
||||
lora_config: Optional[LoRAConfig],
|
||||
vision_language_config: Optional[VisionLanguageConfig],
|
||||
parallel_config: ParallelConfig,
|
||||
scheduler_config: SchedulerConfig) -> nn.Module:
|
||||
"""Load a model with the given configurations."""
|
||||
...
|
||||
|
||||
|
||||
class DefaultModelLoader(BaseModelLoader):
|
||||
"""Model loader that can load different file types from disk."""
|
||||
|
||||
def __init__(self, load_config: LoadConfig):
|
||||
super().__init__(load_config)
|
||||
if load_config.model_loader_extra_config:
|
||||
raise ValueError(f"Model loader extra config is not supported for "
|
||||
f"load format {load_config.load_format}")
|
||||
|
||||
def _maybe_download_from_modelscope(
|
||||
self, model: str, revision: Optional[str]) -> Optional[str]:
|
||||
"""Download model from ModelScope hub if VLLM_USE_MODELSCOPE is True.
|
||||
|
||||
Returns the path to the downloaded model, or None if the model is not
|
||||
downloaded from ModelScope."""
|
||||
if VLLM_USE_MODELSCOPE:
|
||||
# download model from ModelScope hub,
|
||||
# lazy import so that modelscope is not required for normal use.
|
||||
# pylint: disable=C.
|
||||
from modelscope.hub.snapshot_download import snapshot_download
|
||||
|
||||
if not os.path.exists(model):
|
||||
model_path = snapshot_download(
|
||||
model_id=model,
|
||||
cache_dir=self.load_config.download_dir,
|
||||
revision=revision)
|
||||
else:
|
||||
model_path = model
|
||||
return model_path
|
||||
return None
|
||||
|
||||
def _prepare_weights(self, model_name_or_path: str,
|
||||
revision: Optional[str],
|
||||
fall_back_to_pt: bool) -> Tuple[str, List[str], bool]:
|
||||
"""Prepare weights for the model.
|
||||
|
||||
If the model is not local, it will be downloaded."""
|
||||
model_name_or_path = self._maybe_download_from_modelscope(
|
||||
model_name_or_path, revision) or model_name_or_path
|
||||
|
||||
is_local = os.path.isdir(model_name_or_path)
|
||||
load_format = self.load_config.load_format
|
||||
use_safetensors = False
|
||||
# Some quantized models use .pt files for storing the weights.
|
||||
if load_format == LoadFormat.AUTO:
|
||||
allow_patterns = ["*.safetensors", "*.bin"]
|
||||
elif load_format == LoadFormat.SAFETENSORS:
|
||||
use_safetensors = True
|
||||
allow_patterns = ["*.safetensors"]
|
||||
elif load_format == LoadFormat.PT:
|
||||
allow_patterns = ["*.pt"]
|
||||
elif load_format == LoadFormat.NPCACHE:
|
||||
allow_patterns = ["*.bin"]
|
||||
else:
|
||||
raise ValueError(f"Unknown load_format: {load_format}")
|
||||
|
||||
if fall_back_to_pt:
|
||||
allow_patterns += ["*.pt"]
|
||||
|
||||
if not is_local:
|
||||
hf_folder = download_weights_from_hf(model_name_or_path,
|
||||
self.load_config.download_dir,
|
||||
allow_patterns)
|
||||
else:
|
||||
hf_folder = model_name_or_path
|
||||
|
||||
hf_weights_files: List[str] = []
|
||||
for pattern in allow_patterns:
|
||||
hf_weights_files += glob.glob(os.path.join(hf_folder, pattern))
|
||||
if len(hf_weights_files) > 0:
|
||||
if pattern == "*.safetensors":
|
||||
use_safetensors = True
|
||||
break
|
||||
|
||||
if not use_safetensors:
|
||||
hf_weights_files = filter_files_not_needed_for_inference(
|
||||
hf_weights_files)
|
||||
|
||||
if len(hf_weights_files) == 0:
|
||||
raise RuntimeError(
|
||||
f"Cannot find any model weights with `{model_name_or_path}`")
|
||||
|
||||
return hf_folder, hf_weights_files, use_safetensors
|
||||
|
||||
def _get_weights_iterator(
|
||||
self, model_name_or_path: str, revision: Optional[str],
|
||||
fall_back_to_pt: bool
|
||||
) -> Generator[Tuple[str, torch.Tensor], None, None]:
|
||||
"""Get an iterator for the model weights based on the load format."""
|
||||
hf_folder, hf_weights_files, use_safetensors = self._prepare_weights(
|
||||
model_name_or_path, revision, fall_back_to_pt)
|
||||
if self.load_config.load_format == LoadFormat.NPCACHE:
|
||||
# Currently np_cache only support *.bin checkpoints
|
||||
assert use_safetensors is False
|
||||
return np_cache_weights_iterator(model_name_or_path,
|
||||
self.load_config.download_dir,
|
||||
hf_folder, hf_weights_files)
|
||||
if use_safetensors:
|
||||
return safetensors_weights_iterator(hf_weights_files)
|
||||
return pt_weights_iterator(hf_weights_files)
|
||||
|
||||
def load_model(self, *, model_config: ModelConfig,
|
||||
device_config: DeviceConfig,
|
||||
lora_config: Optional[LoRAConfig],
|
||||
vision_language_config: Optional[VisionLanguageConfig],
|
||||
parallel_config: ParallelConfig,
|
||||
scheduler_config: SchedulerConfig) -> nn.Module:
|
||||
with set_default_torch_dtype(model_config.dtype):
|
||||
with torch.device(device_config.device):
|
||||
model = _initialize_model(model_config, self.load_config,
|
||||
lora_config, vision_language_config)
|
||||
model.load_weights(
|
||||
self._get_weights_iterator(model_config.model,
|
||||
model_config.revision,
|
||||
fall_back_to_pt=getattr(
|
||||
model,
|
||||
"fall_back_to_pt_during_load",
|
||||
True)), )
|
||||
return model.eval()
|
||||
|
||||
|
||||
class DummyModelLoader(BaseModelLoader):
|
||||
"""Model loader that will set model weights to random values."""
|
||||
|
||||
def __init__(self, load_config: LoadConfig):
|
||||
super().__init__(load_config)
|
||||
if load_config.model_loader_extra_config:
|
||||
raise ValueError(f"Model loader extra config is not supported for "
|
||||
f"load format {load_config.load_format}")
|
||||
|
||||
def load_model(self, *, model_config: ModelConfig,
|
||||
device_config: DeviceConfig,
|
||||
lora_config: Optional[LoRAConfig],
|
||||
vision_language_config: Optional[VisionLanguageConfig],
|
||||
parallel_config: ParallelConfig,
|
||||
scheduler_config: SchedulerConfig) -> nn.Module:
|
||||
with set_default_torch_dtype(model_config.dtype):
|
||||
with torch.device(device_config.device):
|
||||
model = _initialize_model(model_config, self.load_config,
|
||||
lora_config, vision_language_config)
|
||||
# NOTE(woosuk): For accurate performance evaluation, we assign
|
||||
# random values to the weights.
|
||||
initialize_dummy_weights(model)
|
||||
return model.eval()
|
||||
|
||||
|
||||
class TensorizerLoader(BaseModelLoader):
|
||||
"""Model loader using CoreWeave's tensorizer library."""
|
||||
|
||||
def __init__(self, load_config: LoadConfig):
|
||||
super().__init__(load_config)
|
||||
if isinstance(load_config.model_loader_extra_config, TensorizerConfig):
|
||||
self.tensorizer_config = load_config.model_loader_extra_config
|
||||
else:
|
||||
self.tensorizer_config = TensorizerConfig(
|
||||
**load_config.model_loader_extra_config)
|
||||
|
||||
def _verify_config(self, model_config: ModelConfig,
|
||||
parallel_config: ParallelConfig):
|
||||
self.tensorizer_config.verify_with_model_config(model_config)
|
||||
self.tensorizer_config.verify_with_parallel_config(parallel_config)
|
||||
|
||||
def _get_weights_iterator(
|
||||
self) -> Generator[Tuple[str, torch.Tensor], None, None]:
|
||||
tensorizer_args = self.tensorizer_config._construct_tensorizer_args()
|
||||
return tensorizer_weights_iterator(tensorizer_args)
|
||||
|
||||
def _load_model_unserialized(
|
||||
self, model_config: ModelConfig, device_config: DeviceConfig,
|
||||
lora_config: Optional[LoRAConfig],
|
||||
vision_language_config: Optional[VisionLanguageConfig]
|
||||
) -> nn.Module:
|
||||
"""Load an unserialized model with tensorizer.
|
||||
|
||||
Unserialized here means "not serialized with tensorizer". This
|
||||
should still be faster than default HuggingFace loading, but will
|
||||
be slower than loading a tensorizer-serialized model.
|
||||
"""
|
||||
with set_default_torch_dtype(model_config.dtype):
|
||||
with torch.device(device_config.device):
|
||||
model = _initialize_model(model_config, self.load_config,
|
||||
lora_config, vision_language_config)
|
||||
|
||||
model.load_weights(self._get_weights_iterator())
|
||||
return model.eval()
|
||||
|
||||
def _load_model_serialized(
|
||||
self, model_config: ModelConfig, device_config: DeviceConfig,
|
||||
lora_config: Optional[LoRAConfig],
|
||||
vision_language_config: Optional[VisionLanguageConfig]
|
||||
) -> nn.Module:
|
||||
"""Load a serialized model with tensorizer.
|
||||
|
||||
See the examples/tensorize_vllm_model.py example "
|
||||
script for serializing vLLM models."""
|
||||
with set_default_torch_dtype(model_config.dtype):
|
||||
with torch.device(device_config.device):
|
||||
model_class = get_model_architecture(model_config)[0]
|
||||
linear_method = _get_linear_method(model_config,
|
||||
self.load_config)
|
||||
extra_kwargs = _get_model_initialization_kwargs(
|
||||
model_class, lora_config, vision_language_config)
|
||||
extra_kwargs["linear_method"] = linear_method
|
||||
|
||||
tensorizer_config = copy.copy(self.tensorizer_config)
|
||||
tensorizer_config.model_class = model_class
|
||||
tensorizer_config.hf_config = model_config.hf_config
|
||||
tensorizer_config.dtype = model_config.dtype
|
||||
|
||||
model = load_with_tensorizer(tensorizer_config, **extra_kwargs)
|
||||
return model.eval()
|
||||
|
||||
def load_model(self, *, model_config: ModelConfig,
|
||||
device_config: DeviceConfig,
|
||||
lora_config: Optional[LoRAConfig],
|
||||
vision_language_config: Optional[VisionLanguageConfig],
|
||||
parallel_config: ParallelConfig,
|
||||
scheduler_config: SchedulerConfig) -> nn.Module:
|
||||
self._verify_config(model_config, parallel_config)
|
||||
|
||||
if is_vllm_serialized_tensorizer(self.tensorizer_config):
|
||||
return self._load_model_serialized(model_config, device_config,
|
||||
lora_config,
|
||||
vision_language_config)
|
||||
return self._load_model_unserialized(model_config, device_config,
|
||||
lora_config,
|
||||
vision_language_config)
|
||||
|
||||
|
||||
def get_model_loader(load_config: LoadConfig) -> BaseModelLoader:
|
||||
"""Get a model loader based on the load format."""
|
||||
|
||||
if isinstance(load_config.load_format, type):
|
||||
return load_config.load_format(load_config)
|
||||
|
||||
if load_config.load_format == LoadFormat.DUMMY:
|
||||
return DummyModelLoader(load_config)
|
||||
|
||||
if load_config.load_format == LoadFormat.TENSORIZER:
|
||||
return TensorizerLoader(load_config)
|
||||
|
||||
return DefaultModelLoader(load_config)
|
@ -4,20 +4,20 @@ import io
|
||||
import os
|
||||
import time
|
||||
import typing
|
||||
import warnings
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, Union
|
||||
from typing import Generator, Optional, Tuple, Type, Union
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from transformers import PretrainedConfig
|
||||
|
||||
from vllm.config import TensorizerConfig
|
||||
from vllm.config import ModelConfig, ParallelConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.linear import LinearMethodBase
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
VocabParallelEmbedding)
|
||||
|
||||
tensorizer_load_fail = False
|
||||
tensorizer_load_fail = None
|
||||
|
||||
try:
|
||||
from tensorizer import (DecryptionParams, EncryptionParams,
|
||||
@ -25,51 +25,78 @@ try:
|
||||
from tensorizer.stream_io import open_stream
|
||||
from tensorizer.utils import (convert_bytes, get_mem_usage,
|
||||
no_init_or_tensor)
|
||||
except ImportError:
|
||||
tensorizer_load_fail = True
|
||||
except ImportError as e:
|
||||
tensorizer_load_fail = e
|
||||
|
||||
__all__ = [
|
||||
'EncryptionParams', 'DecryptionParams', 'TensorDeserializer',
|
||||
'TensorSerializer', 'open_stream', 'convert_bytes', 'get_mem_usage',
|
||||
'no_init_or_tensor'
|
||||
'no_init_or_tensor', 'TensorizerConfig'
|
||||
]
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class TensorizerConfig:
|
||||
tensorizer_uri: Union[io.BufferedIOBase, io.RawIOBase, typing.BinaryIO,
|
||||
str, bytes, os.PathLike, int]
|
||||
vllm_tensorized: bool
|
||||
verify_hash: Optional[bool] = False
|
||||
num_readers: Optional[int] = 1
|
||||
encryption_keyfile: Optional[str] = None
|
||||
s3_access_key_id: Optional[str] = None
|
||||
s3_secret_access_key: Optional[str] = None
|
||||
s3_endpoint: Optional[str] = None
|
||||
model_class: Optional[Type[torch.nn.Module]] = None
|
||||
hf_config: Optional[PretrainedConfig] = None
|
||||
dtype: Optional[Union[str, torch.dtype]] = None
|
||||
|
||||
def _construct_tensorizer_args(self) -> "TensorizerArgs":
|
||||
tensorizer_args = {
|
||||
"tensorizer_uri": self.tensorizer_uri,
|
||||
"vllm_tensorized": self.vllm_tensorized,
|
||||
"verify_hash": self.verify_hash,
|
||||
"num_readers": self.num_readers,
|
||||
"encryption_keyfile": self.encryption_keyfile,
|
||||
"s3_access_key_id": self.s3_access_key_id,
|
||||
"s3_secret_access_key": self.s3_secret_access_key,
|
||||
"s3_endpoint": self.s3_endpoint,
|
||||
}
|
||||
return TensorizerArgs(**tensorizer_args)
|
||||
|
||||
def verify_with_parallel_config(
|
||||
self,
|
||||
parallel_config: "ParallelConfig",
|
||||
) -> None:
|
||||
if (parallel_config.tensor_parallel_size > 1
|
||||
and self.tensorizer_uri is not None):
|
||||
raise ValueError(
|
||||
"Loading to multiple GPUs is not currently supported with "
|
||||
"vLLM-serialized models. Please set tensor_parallel_size=1."
|
||||
" or use a non-vLLM-serialized model, such as a "
|
||||
"serialized Hugging Face `PretrainedModel`.")
|
||||
|
||||
def verify_with_model_config(self, model_config: "ModelConfig") -> None:
|
||||
if (model_config.quantization is not None
|
||||
and self.tensorizer_uri is not None):
|
||||
logger.warning(
|
||||
"Loading a model using Tensorizer with quantization on vLLM"
|
||||
" is unstable and may lead to errors.")
|
||||
|
||||
|
||||
def load_with_tensorizer(tensorizer_config: TensorizerConfig,
|
||||
**extra_kwargs) -> nn.Module:
|
||||
tensorizer = TensorizerAgent(tensorizer_config, **extra_kwargs)
|
||||
return tensorizer.deserialize()
|
||||
|
||||
|
||||
def tensorizer_warning(message: str):
|
||||
return warnings.warn(message, category=PerformanceWarning, stacklevel=2)
|
||||
|
||||
|
||||
def is_vllm_serialized_tensorizer(tensorizer_config: TensorizerConfig) -> bool:
|
||||
if tensorizer_config is None:
|
||||
return False
|
||||
return tensorizer_config.vllm_tensorized
|
||||
|
||||
|
||||
class ParameterizedLoadFormat(str):
|
||||
__slots__ = "params"
|
||||
|
||||
|
||||
class PerformanceWarning(UserWarning):
|
||||
|
||||
def __str__(self):
|
||||
return (f"{super().__str__()}"
|
||||
" (set the VLLM_SILENCE_PERFORMANCE_WARNINGS"
|
||||
" environment variable to hide this)")
|
||||
|
||||
|
||||
if (os.getenv("VLLM_SILENCE_PERFORMANCE_WARNINGS", "").lower()
|
||||
not in ("", "0", "n", "no", "off", "disable")):
|
||||
warnings.simplefilter("ignore", category=PerformanceWarning)
|
||||
|
||||
|
||||
@dataclass
|
||||
class TensorizerArgs:
|
||||
tensorizer_uri: Union[io.BufferedIOBase, io.RawIOBase, typing.BinaryIO,
|
||||
@ -219,11 +246,17 @@ class TensorizerAgent:
|
||||
behavior of the TensorDeserializer when loading tensors from a serialized
|
||||
model. For deserializations of HuggingFace models, TensorDeserializer is
|
||||
instead used as an iterator directly in the func hf_model_weights_iterator
|
||||
in vllm/model_executor/weight_utils.py
|
||||
in vllm/model_executor/model_loader/weight_utils.py
|
||||
"""
|
||||
|
||||
def __init__(self, tensorizer_config: TensorizerConfig,
|
||||
linear_method: LinearMethodBase, **extra_kwargs):
|
||||
if tensorizer_load_fail is not None:
|
||||
raise ImportError(
|
||||
"Tensorizer is not installed. Please install tensorizer "
|
||||
"to use this feature with `pip install vllm[tensorizer]`."
|
||||
) from tensorizer_load_fail
|
||||
|
||||
self.tensorizer_config = tensorizer_config
|
||||
self.tensorizer_args = (
|
||||
self.tensorizer_config._construct_tensorizer_args())
|
||||
@ -234,11 +267,6 @@ class TensorizerAgent:
|
||||
self.linear_method = linear_method
|
||||
self.model = self._init_model()
|
||||
|
||||
if tensorizer_load_fail:
|
||||
raise ImportError(
|
||||
"Tensorizer is not installed. Please install tensorizer "
|
||||
"to use this feature with `pip install vllm[tensorizer]`.")
|
||||
|
||||
def _init_model(self):
|
||||
model_args = self.tensorizer_config.hf_config
|
||||
model_args.torch_dtype = self.tensorizer_config.dtype
|
||||
@ -313,3 +341,23 @@ class TensorizerAgent:
|
||||
self._check_tensors_on_meta_device()
|
||||
self._resize_lora_embeddings()
|
||||
return self.model.eval()
|
||||
|
||||
|
||||
def tensorizer_weights_iterator(
|
||||
tensorizer_args: "TensorizerArgs"
|
||||
) -> Generator[Tuple[str, torch.Tensor], None, None]:
|
||||
logger.warning(
|
||||
"Deserializing HuggingFace models is not optimized for "
|
||||
"loading on vLLM, as tensorizer is forced to load to CPU. "
|
||||
"Consider deserializing a vLLM model instead for faster "
|
||||
"load times. See the examples/tensorize_vllm_model.py example "
|
||||
"script for serializing vLLM models.")
|
||||
|
||||
deserializer_args = tensorizer_args.deserializer_params
|
||||
stream_params = tensorizer_args.stream_params
|
||||
stream = open_stream(tensorizer_args.tensorizer_uri, **stream_params)
|
||||
with TensorDeserializer(stream, **deserializer_args,
|
||||
device="cpu") as state:
|
||||
for name, param in state.items():
|
||||
yield name, param
|
||||
del state
|
40
vllm/model_executor/model_loader/utils.py
Normal file
40
vllm/model_executor/model_loader/utils.py
Normal file
@ -0,0 +1,40 @@
|
||||
"""Utilities for selecting and loading models."""
|
||||
import contextlib
|
||||
from typing import Tuple, Type
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.model_executor.models import ModelRegistry
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def set_default_torch_dtype(dtype: torch.dtype):
|
||||
"""Sets the default torch dtype to the given dtype."""
|
||||
old_dtype = torch.get_default_dtype()
|
||||
torch.set_default_dtype(dtype)
|
||||
yield
|
||||
torch.set_default_dtype(old_dtype)
|
||||
|
||||
|
||||
def get_model_architecture(
|
||||
model_config: ModelConfig) -> Tuple[Type[nn.Module], str]:
|
||||
architectures = getattr(model_config.hf_config, "architectures", [])
|
||||
# Special handling for quantized Mixtral.
|
||||
# FIXME(woosuk): This is a temporary hack.
|
||||
if (model_config.quantization is not None
|
||||
and "MixtralForCausalLM" in architectures):
|
||||
architectures = ["QuantMixtralForCausalLM"]
|
||||
|
||||
for arch in architectures:
|
||||
model_cls = ModelRegistry.load_model_cls(arch)
|
||||
if model_cls is not None:
|
||||
return (model_cls, arch)
|
||||
raise ValueError(
|
||||
f"Model architectures {architectures} are not supported for now. "
|
||||
f"Supported architectures: {ModelRegistry.get_supported_archs()}")
|
||||
|
||||
|
||||
def get_architecture_class_name(model_config: ModelConfig) -> str:
|
||||
return get_model_architecture(model_config)[1]
|
@ -4,8 +4,9 @@ import glob
|
||||
import hashlib
|
||||
import json
|
||||
import os
|
||||
import tempfile
|
||||
from collections import defaultdict
|
||||
from typing import Any, Iterable, Iterator, List, Optional, Tuple, Union
|
||||
from typing import Any, Generator, Iterable, List, Optional, Tuple
|
||||
|
||||
import filelock
|
||||
import huggingface_hub.constants
|
||||
@ -15,7 +16,7 @@ from huggingface_hub import HfFileSystem, snapshot_download
|
||||
from safetensors.torch import load_file, safe_open, save_file
|
||||
from tqdm.auto import tqdm
|
||||
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.config import LoadConfig, ModelConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.quantization import (QuantizationConfig,
|
||||
get_quantization_config)
|
||||
@ -27,8 +28,7 @@ logger = init_logger(__name__)
|
||||
# can share the same lock without error.
|
||||
# lock files in the temp directory will be automatically deleted when the
|
||||
# system reboots, so users will not complain about annoying lock files
|
||||
temp_dir = os.environ.get('TMPDIR') or os.environ.get(
|
||||
'TEMP') or os.environ.get('TMP') or "/tmp/"
|
||||
temp_dir = tempfile.gettempdir()
|
||||
|
||||
|
||||
def enable_hf_transfer():
|
||||
@ -46,7 +46,7 @@ def enable_hf_transfer():
|
||||
enable_hf_transfer()
|
||||
|
||||
|
||||
class Disabledtqdm(tqdm):
|
||||
class DisabledTqdm(tqdm):
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs, disable=True)
|
||||
@ -114,7 +114,8 @@ def convert_bin_to_safetensor_file(
|
||||
|
||||
|
||||
# TODO(woosuk): Move this to other place.
|
||||
def get_quant_config(model_config: ModelConfig) -> QuantizationConfig:
|
||||
def get_quant_config(model_config: ModelConfig,
|
||||
load_config: LoadConfig) -> QuantizationConfig:
|
||||
quant_cls = get_quantization_config(model_config.quantization)
|
||||
# Read the quantization config from the HF model config, if available.
|
||||
hf_quant_config = getattr(model_config.hf_config, "quantization_config",
|
||||
@ -125,12 +126,12 @@ def get_quant_config(model_config: ModelConfig) -> QuantizationConfig:
|
||||
is_local = os.path.isdir(model_name_or_path)
|
||||
if not is_local:
|
||||
# Download the config files.
|
||||
with get_lock(model_name_or_path, model_config.download_dir):
|
||||
with get_lock(model_name_or_path, load_config.download_dir):
|
||||
hf_folder = snapshot_download(model_name_or_path,
|
||||
revision=model_config.revision,
|
||||
allow_patterns="*.json",
|
||||
cache_dir=model_config.download_dir,
|
||||
tqdm_class=Disabledtqdm)
|
||||
cache_dir=load_config.download_dir,
|
||||
tqdm_class=DisabledTqdm)
|
||||
else:
|
||||
hf_folder = model_name_or_path
|
||||
config_files = glob.glob(os.path.join(hf_folder, "*.json"))
|
||||
@ -153,169 +154,127 @@ def get_quant_config(model_config: ModelConfig) -> QuantizationConfig:
|
||||
return quant_cls.from_config(config)
|
||||
|
||||
|
||||
def prepare_hf_model_weights(
|
||||
model_name_or_path: str,
|
||||
cache_dir: Optional[str] = None,
|
||||
load_format: str = "auto",
|
||||
fall_back_to_pt: bool = True,
|
||||
revision: Optional[str] = None,
|
||||
) -> Tuple[str, List[str], bool]:
|
||||
# Download model weights from huggingface.
|
||||
is_local = os.path.isdir(model_name_or_path) \
|
||||
and load_format != "tensorizer"
|
||||
use_safetensors = False
|
||||
# Some quantized models use .pt files for storing the weights.
|
||||
if load_format == "auto":
|
||||
allow_patterns = ["*.safetensors", "*.bin"]
|
||||
elif load_format == "safetensors":
|
||||
use_safetensors = True
|
||||
allow_patterns = ["*.safetensors"]
|
||||
elif load_format == "pt":
|
||||
allow_patterns = ["*.pt"]
|
||||
elif load_format == "npcache":
|
||||
allow_patterns = ["*.bin"]
|
||||
elif load_format == "tensorizer":
|
||||
allow_patterns = ["*.tensors"]
|
||||
else:
|
||||
raise ValueError(f"Unknown load_format: {load_format}")
|
||||
def download_weights_from_hf(model_name_or_path: str,
|
||||
cache_dir: Optional[str],
|
||||
allow_patterns: List[str],
|
||||
revision: Optional[str] = None) -> str:
|
||||
"""Download model weights from Hugging Face Hub.
|
||||
|
||||
Args:
|
||||
model_name_or_path (str): The model name or path.
|
||||
cache_dir (Optional[str]): The cache directory to store the model
|
||||
weights. If None, will use HF defaults.
|
||||
allow_patterns (List[str]): The allowed patterns for the
|
||||
weight files. Files matched by any of the patterns will be
|
||||
downloaded.
|
||||
revision (Optional[str]): The revision of the model.
|
||||
|
||||
if fall_back_to_pt:
|
||||
allow_patterns += ["*.pt"]
|
||||
Returns:
|
||||
str: The path to the downloaded model weights.
|
||||
"""
|
||||
# Before we download we look at that is available:
|
||||
fs = HfFileSystem()
|
||||
file_list = fs.ls(model_name_or_path, detail=False, revision=revision)
|
||||
|
||||
if not is_local and load_format != "tensorizer":
|
||||
# Before we download we look at that is available:
|
||||
fs = HfFileSystem()
|
||||
file_list = fs.ls(model_name_or_path, detail=False, revision=revision)
|
||||
|
||||
# depending on what is available we download different things
|
||||
for pattern in allow_patterns:
|
||||
matching = fnmatch.filter(file_list, pattern)
|
||||
if len(matching) > 0:
|
||||
allow_patterns = [pattern]
|
||||
break
|
||||
|
||||
logger.info(f"Using model weights format {allow_patterns}")
|
||||
# Use file lock to prevent multiple processes from
|
||||
# downloading the same model weights at the same time.
|
||||
with get_lock(model_name_or_path, cache_dir):
|
||||
hf_folder = snapshot_download(model_name_or_path,
|
||||
allow_patterns=allow_patterns,
|
||||
cache_dir=cache_dir,
|
||||
tqdm_class=Disabledtqdm,
|
||||
revision=revision)
|
||||
else:
|
||||
hf_folder = model_name_or_path
|
||||
hf_weights_files: List[str] = []
|
||||
# depending on what is available we download different things
|
||||
for pattern in allow_patterns:
|
||||
hf_weights_files += glob.glob(os.path.join(hf_folder, pattern))
|
||||
if len(hf_weights_files) > 0:
|
||||
if pattern == "*.safetensors":
|
||||
use_safetensors = True
|
||||
matching = fnmatch.filter(file_list, pattern)
|
||||
if len(matching) > 0:
|
||||
allow_patterns = [pattern]
|
||||
break
|
||||
if not use_safetensors:
|
||||
# Exclude files that are not needed for inference.
|
||||
# https://github.com/huggingface/transformers/blob/v4.34.0/src/transformers/trainer.py#L227-L233
|
||||
blacklist = [
|
||||
"training_args.bin",
|
||||
"optimizer.bin",
|
||||
"optimizer.pt",
|
||||
"scheduler.pt",
|
||||
"scaler.pt",
|
||||
]
|
||||
hf_weights_files = [
|
||||
f for f in hf_weights_files
|
||||
if not any(f.endswith(x) for x in blacklist)
|
||||
]
|
||||
|
||||
if load_format == "tensorizer":
|
||||
return hf_folder, hf_weights_files, use_safetensors
|
||||
|
||||
if len(hf_weights_files) == 0:
|
||||
raise RuntimeError(
|
||||
f"Cannot find any model weights with `{model_name_or_path}`")
|
||||
|
||||
return hf_folder, hf_weights_files, use_safetensors
|
||||
logger.info(f"Using model weights format {allow_patterns}")
|
||||
# Use file lock to prevent multiple processes from
|
||||
# downloading the same model weights at the same time.
|
||||
with get_lock(model_name_or_path, cache_dir):
|
||||
hf_folder = snapshot_download(model_name_or_path,
|
||||
allow_patterns=allow_patterns,
|
||||
cache_dir=cache_dir,
|
||||
tqdm_class=DisabledTqdm,
|
||||
revision=revision)
|
||||
return hf_folder
|
||||
|
||||
|
||||
def hf_model_weights_iterator(
|
||||
model_name_or_path: str,
|
||||
cache_dir: Optional[str] = None,
|
||||
load_format: Union[Tuple, str] = "auto",
|
||||
revision: Optional[str] = None,
|
||||
fall_back_to_pt: Optional[bool] = True,
|
||||
) -> Iterator[Tuple[str, torch.Tensor]]:
|
||||
hf_folder, hf_weights_files, use_safetensors = prepare_hf_model_weights(
|
||||
model_name_or_path,
|
||||
cache_dir=cache_dir,
|
||||
load_format=load_format,
|
||||
fall_back_to_pt=fall_back_to_pt,
|
||||
revision=revision)
|
||||
def filter_files_not_needed_for_inference(
|
||||
hf_weights_files: List[str]) -> List[str]:
|
||||
"""
|
||||
Exclude files that are not needed for inference.
|
||||
|
||||
if load_format == "npcache":
|
||||
# Currently np_cache only support *.bin checkpoints
|
||||
assert use_safetensors is False
|
||||
See https://github.com/huggingface/transformers/blob/v4.34.0/src/transformers/trainer.py#L227-L233
|
||||
"""
|
||||
blacklist = [
|
||||
"training_args.bin",
|
||||
"optimizer.bin",
|
||||
"optimizer.pt",
|
||||
"scheduler.pt",
|
||||
"scaler.pt",
|
||||
]
|
||||
hf_weights_files = [
|
||||
f for f in hf_weights_files
|
||||
if not any(f.endswith(x) for x in blacklist)
|
||||
]
|
||||
return hf_weights_files
|
||||
|
||||
# Convert the model weights from torch tensors to numpy arrays for
|
||||
# faster loading.
|
||||
np_folder = os.path.join(hf_folder, "np")
|
||||
os.makedirs(np_folder, exist_ok=True)
|
||||
weight_names_file = os.path.join(np_folder, "weight_names.json")
|
||||
# Use file lock to prevent multiple processes from
|
||||
# dumping the same model weights to numpy at the same time.
|
||||
with get_lock(model_name_or_path, cache_dir):
|
||||
if not os.path.exists(weight_names_file):
|
||||
weight_names = []
|
||||
for bin_file in hf_weights_files:
|
||||
state = torch.load(bin_file, map_location="cpu")
|
||||
for name, param in state.items():
|
||||
param_path = os.path.join(np_folder, name)
|
||||
with open(param_path, "wb") as f:
|
||||
np.save(f, param.cpu().detach().numpy())
|
||||
weight_names.append(name)
|
||||
with open(weight_names_file, "w") as f:
|
||||
json.dump(weight_names, f)
|
||||
|
||||
with open(weight_names_file, "r") as f:
|
||||
weight_names = json.load(f)
|
||||
def np_cache_weights_iterator(
|
||||
model_name_or_path: str, cache_dir: Optional[str], hf_folder: str,
|
||||
hf_weights_files: List[str]
|
||||
) -> Generator[Tuple[str, torch.Tensor], None, None]:
|
||||
"""Iterate over the weights in the model np files.
|
||||
|
||||
for name in weight_names:
|
||||
param_path = os.path.join(np_folder, name)
|
||||
with open(param_path, "rb") as f:
|
||||
param = np.load(f)
|
||||
yield name, torch.from_numpy(param)
|
||||
elif load_format == "tensorizer":
|
||||
from vllm.model_executor.tensorizer_loader import (TensorDeserializer,
|
||||
open_stream,
|
||||
tensorizer_warning)
|
||||
tensorizer_args = load_format.params
|
||||
tensorizer_warning(
|
||||
"Deserializing HuggingFace models is not optimized for "
|
||||
"loading on vLLM, as tensorizer is forced to load to CPU. "
|
||||
"Consider deserializing a vLLM model instead for faster "
|
||||
"load times. See the examples/tensorize_vllm_model.py example "
|
||||
"script for serializing vLLM models.")
|
||||
Will dump the model weights to numpy files if they are not already dumped.
|
||||
"""
|
||||
# Convert the model weights from torch tensors to numpy arrays for
|
||||
# faster loading.
|
||||
np_folder = os.path.join(hf_folder, "np")
|
||||
os.makedirs(np_folder, exist_ok=True)
|
||||
weight_names_file = os.path.join(np_folder, "weight_names.json")
|
||||
# Use file lock to prevent multiple processes from
|
||||
# dumping the same model weights to numpy at the same time.
|
||||
with get_lock(model_name_or_path, cache_dir):
|
||||
if not os.path.exists(weight_names_file):
|
||||
weight_names = []
|
||||
for bin_file in hf_weights_files:
|
||||
state = torch.load(bin_file, map_location="cpu")
|
||||
for name, param in state.items():
|
||||
param_path = os.path.join(np_folder, name)
|
||||
with open(param_path, "wb") as f:
|
||||
np.save(f, param.cpu().detach().numpy())
|
||||
weight_names.append(name)
|
||||
with open(weight_names_file, "w") as f:
|
||||
json.dump(weight_names, f)
|
||||
|
||||
deserializer_args = tensorizer_args.deserializer_params
|
||||
stream_params = tensorizer_args.stream_params
|
||||
stream = open_stream(tensorizer_args.tensorizer_uri, **stream_params)
|
||||
with TensorDeserializer(stream, **deserializer_args,
|
||||
device="cpu") as state:
|
||||
for name, param in state.items():
|
||||
with open(weight_names_file, "r") as f:
|
||||
weight_names = json.load(f)
|
||||
|
||||
for name in weight_names:
|
||||
param_path = os.path.join(np_folder, name)
|
||||
with open(param_path, "rb") as f:
|
||||
param = np.load(f)
|
||||
yield name, torch.from_numpy(param)
|
||||
|
||||
|
||||
def safetensors_weights_iterator(
|
||||
hf_weights_files: List[str]
|
||||
) -> Generator[Tuple[str, torch.Tensor], None, None]:
|
||||
"""Iterate over the weights in the model safetensor files."""
|
||||
for st_file in hf_weights_files:
|
||||
with safe_open(st_file, framework="pt") as f:
|
||||
for name in f.keys(): # noqa: SIM118
|
||||
param = f.get_tensor(name)
|
||||
yield name, param
|
||||
|
||||
|
||||
def pt_weights_iterator(
|
||||
hf_weights_files: List[str]
|
||||
) -> Generator[Tuple[str, torch.Tensor], None, None]:
|
||||
"""Iterate over the weights in the model bin/pt files."""
|
||||
for bin_file in hf_weights_files:
|
||||
state = torch.load(bin_file, map_location="cpu")
|
||||
for name, param in state.items():
|
||||
yield name, param
|
||||
del state
|
||||
elif use_safetensors:
|
||||
for st_file in hf_weights_files:
|
||||
with safe_open(st_file, framework="pt") as f:
|
||||
for name in f.keys(): # noqa: SIM118
|
||||
param = f.get_tensor(name)
|
||||
yield name, param
|
||||
else:
|
||||
for bin_file in hf_weights_files:
|
||||
state = torch.load(bin_file, map_location="cpu")
|
||||
for name, param in state.items():
|
||||
yield name, param
|
||||
del state
|
||||
torch.cuda.empty_cache()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
|
||||
def kv_cache_scales_loader(
|
@ -19,7 +19,7 @@
|
||||
# limitations under the License.
|
||||
"""Inference-only BaiChuan model compatible with HuggingFace weights."""
|
||||
import math
|
||||
from typing import List, Optional, Tuple
|
||||
from typing import Iterable, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
@ -40,9 +40,8 @@ from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||
from vllm.model_executor.layers.sampler import Sampler
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
ParallelLMHead, VocabParallelEmbedding)
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.model_executor.weight_utils import (default_weight_loader,
|
||||
hf_model_weights_iterator)
|
||||
from vllm.sequence import SamplerOutput
|
||||
|
||||
|
||||
@ -340,19 +339,14 @@ class BaiChuanBaseForCausalLM(nn.Module):
|
||||
next_tokens = self.sampler(logits, sampling_metadata)
|
||||
return next_tokens
|
||||
|
||||
def load_weights(self,
|
||||
model_name_or_path: str,
|
||||
cache_dir: Optional[str] = None,
|
||||
load_format: str = "auto",
|
||||
revision: Optional[str] = None):
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
stacked_params_mapping = [
|
||||
# (param_name, shard_name, shard_id)
|
||||
("gate_up_proj", "gate_proj", 0),
|
||||
("gate_up_proj", "up_proj", 1),
|
||||
]
|
||||
params_dict = dict(self.named_parameters())
|
||||
for name, loaded_weight in hf_model_weights_iterator(
|
||||
model_name_or_path, cache_dir, load_format, revision):
|
||||
for name, loaded_weight in weights:
|
||||
if "rotary_emb.inv_freq" in name:
|
||||
continue
|
||||
if name == "lm_head.weight":
|
||||
|
@ -17,7 +17,7 @@
|
||||
# limitations under the License.
|
||||
"""Inference-only BLOOM model compatible with HuggingFace weights."""
|
||||
import math
|
||||
from typing import List, Optional
|
||||
from typing import Iterable, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
@ -35,9 +35,8 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
from vllm.model_executor.layers.sampler import Sampler
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
VocabParallelEmbedding)
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.model_executor.weight_utils import (default_weight_loader,
|
||||
hf_model_weights_iterator)
|
||||
from vllm.sequence import SamplerOutput
|
||||
|
||||
|
||||
@ -298,14 +297,9 @@ class BloomForCausalLM(nn.Module):
|
||||
next_tokens = self.sampler(logits, sampling_metadata)
|
||||
return next_tokens
|
||||
|
||||
def load_weights(self,
|
||||
model_name_or_path: str,
|
||||
cache_dir: Optional[str] = None,
|
||||
load_format: str = "auto",
|
||||
revision: Optional[str] = None):
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
params_dict = dict(self.named_parameters(remove_duplicate=False))
|
||||
for name, loaded_weight in hf_model_weights_iterator(
|
||||
model_name_or_path, cache_dir, load_format, revision):
|
||||
for name, loaded_weight in weights:
|
||||
if name == "lm_head.weight":
|
||||
continue
|
||||
if not name.startswith("transformer."):
|
||||
|
@ -2,7 +2,7 @@
|
||||
# Adapted from
|
||||
# https://github.com/THUDM/ChatGLM2-6B
|
||||
"""Inference-only ChatGLM model compatible with THUDM weights."""
|
||||
from typing import List, Optional
|
||||
from typing import Iterable, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
@ -22,9 +22,8 @@ from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||
from vllm.model_executor.layers.sampler import Sampler
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
ParallelLMHead, VocabParallelEmbedding)
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.model_executor.weight_utils import (default_weight_loader,
|
||||
hf_model_weights_iterator)
|
||||
from vllm.sequence import SamplerOutput
|
||||
from vllm.transformers_utils.configs import ChatGLMConfig
|
||||
|
||||
@ -370,14 +369,9 @@ class ChatGLMForCausalLM(nn.Module):
|
||||
next_tokens = self.sampler(logits, sampling_metadata)
|
||||
return next_tokens
|
||||
|
||||
def load_weights(self,
|
||||
model_name_or_path: str,
|
||||
cache_dir: Optional[str] = None,
|
||||
load_format: str = "auto",
|
||||
revision: Optional[str] = None):
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
params_dict = dict(self.named_parameters(remove_duplicate=False))
|
||||
for name, loaded_weight in hf_model_weights_iterator(
|
||||
model_name_or_path, cache_dir, load_format, revision):
|
||||
for name, loaded_weight in weights:
|
||||
if "rotary_pos_emb.inv_freq" in name:
|
||||
continue
|
||||
if "word_embeddings" in name:
|
||||
|
@ -20,7 +20,7 @@
|
||||
|
||||
# This file is based on the LLama model definition file in transformers
|
||||
"""PyTorch Cohere model."""
|
||||
from typing import List, Optional, Tuple
|
||||
from typing import Iterable, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.utils.checkpoint
|
||||
@ -41,10 +41,9 @@ from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||
from vllm.model_executor.layers.sampler import Sampler
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
VocabParallelEmbedding)
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.model_executor.utils import set_weight_attrs
|
||||
from vllm.model_executor.weight_utils import (default_weight_loader,
|
||||
hf_model_weights_iterator)
|
||||
from vllm.sequence import SamplerOutput
|
||||
|
||||
|
||||
@ -335,13 +334,7 @@ class CohereForCausalLM(nn.Module):
|
||||
next_tokens = self.sampler(logits, sampling_metadata)
|
||||
return next_tokens
|
||||
|
||||
def load_weights(
|
||||
self,
|
||||
model_name_or_path: str,
|
||||
cache_dir: Optional[str] = None,
|
||||
load_format: str = "auto",
|
||||
revision: Optional[str] = None,
|
||||
):
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
stacked_params_mapping = [
|
||||
# (param_name, shard_name, shard_id)
|
||||
("qkv_proj", "q_proj", "q"),
|
||||
@ -352,8 +345,7 @@ class CohereForCausalLM(nn.Module):
|
||||
]
|
||||
params_dict = dict(self.named_parameters())
|
||||
loaded_params = set()
|
||||
for name, loaded_weight in hf_model_weights_iterator(
|
||||
model_name_or_path, cache_dir, load_format, revision):
|
||||
for name, loaded_weight in weights:
|
||||
for param_name, shard_name, shard_id in stacked_params_mapping:
|
||||
if shard_name not in name:
|
||||
continue
|
||||
|
@ -1,5 +1,5 @@
|
||||
# coding=utf-8
|
||||
from typing import List, Optional
|
||||
from typing import Iterable, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@ -18,10 +18,9 @@ from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||
from vllm.model_executor.layers.sampler import Sampler
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.model_executor.utils import set_weight_attrs
|
||||
from vllm.model_executor.weight_utils import (default_weight_loader,
|
||||
hf_model_weights_iterator)
|
||||
from vllm.sequence import SamplerOutput
|
||||
from vllm.transformers_utils.configs.dbrx import DbrxConfig
|
||||
|
||||
@ -391,20 +390,13 @@ class DbrxForCausalLM(nn.Module):
|
||||
next_tokens = self.sampler(logits, sampling_metadata)
|
||||
return next_tokens
|
||||
|
||||
def load_weights(
|
||||
self,
|
||||
model_name_or_path: str,
|
||||
cache_dir: Optional[str] = None,
|
||||
load_format: str = "auto",
|
||||
revision: Optional[str] = None,
|
||||
):
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
expert_params_mapping = [(
|
||||
"ws" if weight_name in ["w1", "v1"] else "w2s",
|
||||
f"experts.mlp.{weight_name}",
|
||||
) for weight_name in ["w1", "v1", "w2"]]
|
||||
params_dict = dict(self.named_parameters(remove_duplicate=False))
|
||||
for name, loaded_weight in hf_model_weights_iterator(
|
||||
model_name_or_path, cache_dir, load_format, revision):
|
||||
for name, loaded_weight in weights:
|
||||
for param_name, weight_name in expert_params_mapping:
|
||||
if weight_name not in name:
|
||||
continue
|
||||
|
@ -23,16 +23,15 @@
|
||||
# limitations under the License.
|
||||
"""Inference-only DeciLM model compatible with HuggingFace weights."""
|
||||
|
||||
from typing import Optional
|
||||
from typing import Iterable, Optional, Tuple
|
||||
|
||||
import torch
|
||||
from transformers import PretrainedConfig
|
||||
|
||||
from vllm.config import LoRAConfig
|
||||
from vllm.model_executor.layers.linear import LinearMethodBase
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.model_executor.models.llama import LlamaForCausalLM
|
||||
from vllm.model_executor.weight_utils import (default_weight_loader,
|
||||
hf_model_weights_iterator)
|
||||
|
||||
|
||||
class DeciLMForCausalLM(LlamaForCausalLM):
|
||||
@ -65,11 +64,7 @@ class DeciLMForCausalLM(LlamaForCausalLM):
|
||||
linear_method=linear_method,
|
||||
lora_config=lora_config)
|
||||
|
||||
def load_weights(self,
|
||||
model_name_or_path: str,
|
||||
cache_dir: Optional[str] = None,
|
||||
load_format: str = "auto",
|
||||
revision: Optional[str] = None):
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
stacked_params_mapping = [
|
||||
# (param_name, shard_name, shard_id)
|
||||
("qkv_proj", "q_proj", "q"),
|
||||
@ -79,8 +74,7 @@ class DeciLMForCausalLM(LlamaForCausalLM):
|
||||
("gate_up_proj", "up_proj", 1),
|
||||
]
|
||||
params_dict = dict(self.named_parameters())
|
||||
for name, loaded_weight in hf_model_weights_iterator(
|
||||
model_name_or_path, cache_dir, load_format, revision):
|
||||
for name, loaded_weight in weights:
|
||||
if "rotary_emb.inv_freq" in name:
|
||||
continue
|
||||
|
||||
|
@ -21,7 +21,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Inference-only Deepseek model."""
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import Any, Dict, Iterable, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
@ -44,9 +44,8 @@ from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||
from vllm.model_executor.layers.sampler import Sampler
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
ParallelLMHead, VocabParallelEmbedding)
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.model_executor.weight_utils import (default_weight_loader,
|
||||
hf_model_weights_iterator)
|
||||
from vllm.sequence import SamplerOutput
|
||||
|
||||
|
||||
@ -316,6 +315,8 @@ class DeepseekDecoderLayer(nn.Module):
|
||||
|
||||
class DeepseekModel(nn.Module):
|
||||
|
||||
fall_back_to_pt_during_load = False
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: PretrainedConfig,
|
||||
@ -395,11 +396,7 @@ class DeepseekForCausalLM(nn.Module):
|
||||
next_tokens = self.sampler(logits, sampling_metadata)
|
||||
return next_tokens
|
||||
|
||||
def load_weights(self,
|
||||
model_name_or_path: str,
|
||||
cache_dir: Optional[str] = None,
|
||||
load_format: str = "auto",
|
||||
revision: Optional[str] = None):
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
stacked_params_mapping = [
|
||||
# (param_name, shard_name, shard_id)
|
||||
("qkv_proj", "q_proj", "q"),
|
||||
@ -410,12 +407,7 @@ class DeepseekForCausalLM(nn.Module):
|
||||
]
|
||||
|
||||
params_dict = dict(self.named_parameters())
|
||||
for name, loaded_weight in hf_model_weights_iterator(
|
||||
model_name_or_path,
|
||||
cache_dir,
|
||||
load_format,
|
||||
revision,
|
||||
fall_back_to_pt=False):
|
||||
for name, loaded_weight in weights:
|
||||
if "rotary_emb.inv_freq" in name:
|
||||
continue
|
||||
for (param_name, weight_name, shard_id) in stacked_params_mapping:
|
||||
|
@ -19,7 +19,7 @@
|
||||
"""PyTorch Falcon model."""
|
||||
|
||||
import math
|
||||
from typing import List, Optional, Union
|
||||
from typing import Iterable, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
@ -40,9 +40,8 @@ from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||
from vllm.model_executor.layers.sampler import Sampler
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
VocabParallelEmbedding)
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.model_executor.weight_utils import (default_weight_loader,
|
||||
hf_model_weights_iterator)
|
||||
from vllm.sequence import SamplerOutput
|
||||
from vllm.transformers_utils.configs import RWConfig
|
||||
|
||||
@ -399,11 +398,7 @@ class FalconForCausalLM(nn.Module):
|
||||
next_tokens = self.sampler(logits, sampling_metadata)
|
||||
return next_tokens
|
||||
|
||||
def load_weights(self,
|
||||
model_name_or_path: str,
|
||||
cache_dir: Optional[str] = None,
|
||||
load_format: str = "auto",
|
||||
revision: Optional[str] = None):
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
total_num_heads = self.config.num_attention_heads
|
||||
if self.config.new_decoder_architecture:
|
||||
total_num_kv_heads = self.config.num_kv_heads
|
||||
@ -413,8 +408,7 @@ class FalconForCausalLM(nn.Module):
|
||||
total_num_kv_heads = total_num_heads
|
||||
num_query_heads_per_kv_head = total_num_heads // total_num_kv_heads
|
||||
params_dict = dict(self.named_parameters(remove_duplicate=False))
|
||||
for name, loaded_weight in hf_model_weights_iterator(
|
||||
model_name_or_path, cache_dir, load_format, revision):
|
||||
for name, loaded_weight in weights:
|
||||
if name == "lm_head.weight":
|
||||
# Falcon uses tied embeddings.
|
||||
continue
|
||||
|
@ -15,7 +15,7 @@
|
||||
# limitations under the License.
|
||||
"""Inference-only Gemma model compatible with HuggingFace weights."""
|
||||
from functools import lru_cache
|
||||
from typing import List, Optional, Tuple
|
||||
from typing import Iterable, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
@ -36,9 +36,8 @@ from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||
from vllm.model_executor.layers.sampler import Sampler
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
VocabParallelEmbedding)
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.model_executor.weight_utils import (default_weight_loader,
|
||||
hf_model_weights_iterator)
|
||||
from vllm.sequence import SamplerOutput
|
||||
|
||||
logger = init_logger(__name__)
|
||||
@ -346,11 +345,7 @@ class GemmaForCausalLM(nn.Module):
|
||||
next_tokens = self.sampler(logits, sampling_metadata)
|
||||
return next_tokens
|
||||
|
||||
def load_weights(self,
|
||||
model_name_or_path: str,
|
||||
cache_dir: Optional[str] = None,
|
||||
load_format: str = "auto",
|
||||
revision: Optional[str] = None):
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
stacked_params_mapping = [
|
||||
# (param_name, shard_name, shard_id)
|
||||
("qkv_proj", "q_proj", "q"),
|
||||
@ -361,8 +356,7 @@ class GemmaForCausalLM(nn.Module):
|
||||
]
|
||||
params_dict = dict(self.named_parameters())
|
||||
loaded_params = set()
|
||||
for name, loaded_weight in hf_model_weights_iterator(
|
||||
model_name_or_path, cache_dir, load_format, revision):
|
||||
for name, loaded_weight in weights:
|
||||
for (param_name, shard_name, shard_id) in stacked_params_mapping:
|
||||
if shard_name not in name:
|
||||
continue
|
||||
|
@ -17,7 +17,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Inference-only GPT-2 model compatible with HuggingFace weights."""
|
||||
from typing import List, Optional
|
||||
from typing import Iterable, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
@ -34,9 +34,8 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
from vllm.model_executor.layers.sampler import Sampler
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
VocabParallelEmbedding)
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.model_executor.weight_utils import (default_weight_loader,
|
||||
hf_model_weights_iterator)
|
||||
from vllm.sequence import SamplerOutput
|
||||
|
||||
|
||||
@ -239,14 +238,9 @@ class GPT2LMHeadModel(nn.Module):
|
||||
next_tokens = self.sampler(logits, sampling_metadata)
|
||||
return next_tokens
|
||||
|
||||
def load_weights(self,
|
||||
model_name_or_path: str,
|
||||
cache_dir: Optional[str] = None,
|
||||
load_format: str = "auto",
|
||||
revision: Optional[str] = None):
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
params_dict = dict(self.named_parameters(remove_duplicate=False))
|
||||
for name, loaded_weight in hf_model_weights_iterator(
|
||||
model_name_or_path, cache_dir, load_format, revision):
|
||||
for name, loaded_weight in weights:
|
||||
if "lm_head.weight" in name:
|
||||
# GPT-2 ties the weights of the embedding layer and the final
|
||||
# linear layer.
|
||||
|
@ -18,7 +18,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Inference-only GPTBigCode model compatible with HuggingFace weights."""
|
||||
from typing import List, Optional
|
||||
from typing import Iterable, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
@ -35,9 +35,8 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
from vllm.model_executor.layers.sampler import Sampler
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
VocabParallelEmbedding)
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.model_executor.weight_utils import (default_weight_loader,
|
||||
hf_model_weights_iterator)
|
||||
from vllm.sequence import SamplerOutput
|
||||
|
||||
|
||||
@ -260,14 +259,9 @@ class GPTBigCodeForCausalLM(nn.Module):
|
||||
next_tokens = self.sampler(logits, sampling_metadata)
|
||||
return next_tokens
|
||||
|
||||
def load_weights(self,
|
||||
model_name_or_path: str,
|
||||
cache_dir: Optional[str] = None,
|
||||
load_format: str = "auto",
|
||||
revision: Optional[str] = None):
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
params_dict = dict(self.named_parameters(remove_duplicate=False))
|
||||
for name, loaded_weight in hf_model_weights_iterator(
|
||||
model_name_or_path, cache_dir, load_format, revision):
|
||||
for name, loaded_weight in weights:
|
||||
if "lm_head.weight" in name:
|
||||
continue
|
||||
if ".attn.bias" in name:
|
||||
|
@ -16,7 +16,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Inference-only GPT-J model compatible with HuggingFace weights."""
|
||||
from typing import List, Optional
|
||||
from typing import Iterable, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
@ -34,9 +34,8 @@ from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||
from vllm.model_executor.layers.sampler import Sampler
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
ParallelLMHead, VocabParallelEmbedding)
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.model_executor.weight_utils import (default_weight_loader,
|
||||
hf_model_weights_iterator)
|
||||
from vllm.sequence import SamplerOutput
|
||||
|
||||
|
||||
@ -248,11 +247,7 @@ class GPTJForCausalLM(nn.Module):
|
||||
next_tokens = self.sampler(logits, sampling_metadata)
|
||||
return next_tokens
|
||||
|
||||
def load_weights(self,
|
||||
model_name_or_path: str,
|
||||
cache_dir: Optional[str] = None,
|
||||
load_format: str = "auto",
|
||||
revision: Optional[str] = None):
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
stacked_params_mapping = [
|
||||
# (param_name, shard_name, shard_id)
|
||||
("qkv_proj", "q_proj", "q"),
|
||||
@ -262,8 +257,7 @@ class GPTJForCausalLM(nn.Module):
|
||||
("gate_up_proj", "up_proj", 1),
|
||||
]
|
||||
params_dict = dict(self.named_parameters())
|
||||
for name, loaded_weight in hf_model_weights_iterator(
|
||||
model_name_or_path, cache_dir, load_format, revision):
|
||||
for name, loaded_weight in weights:
|
||||
if "attn.bias" in name or "attn.masked_bias" in name:
|
||||
continue
|
||||
for (param_name, weight_name, shard_id) in stacked_params_mapping:
|
||||
|
@ -16,7 +16,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Inference-only GPT-NeoX model compatible with HuggingFace weights."""
|
||||
from typing import List, Optional
|
||||
from typing import Iterable, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
@ -34,9 +34,8 @@ from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||
from vllm.model_executor.layers.sampler import Sampler
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
ParallelLMHead, VocabParallelEmbedding)
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.model_executor.weight_utils import (default_weight_loader,
|
||||
hf_model_weights_iterator)
|
||||
from vllm.sequence import SamplerOutput
|
||||
|
||||
|
||||
@ -262,14 +261,9 @@ class GPTNeoXForCausalLM(nn.Module):
|
||||
next_tokens = self.sampler(logits, sampling_metadata)
|
||||
return next_tokens
|
||||
|
||||
def load_weights(self,
|
||||
model_name_or_path: str,
|
||||
cache_dir: Optional[str] = None,
|
||||
load_format: str = "auto",
|
||||
revision: Optional[str] = None):
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
params_dict = dict(self.named_parameters())
|
||||
for name, loaded_weight in hf_model_weights_iterator(
|
||||
model_name_or_path, cache_dir, load_format, revision):
|
||||
for name, loaded_weight in weights:
|
||||
if ("attention.bias" in name or "attention.masked_bias" in name
|
||||
or "rotary_emb.inv_freq" in name):
|
||||
continue
|
||||
|
@ -1,5 +1,5 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
from typing import Any, Dict, Iterable, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
@ -18,9 +18,8 @@ from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||
from vllm.model_executor.layers.sampler import Sampler
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
ParallelLMHead, VocabParallelEmbedding)
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.model_executor.weight_utils import (default_weight_loader,
|
||||
hf_model_weights_iterator)
|
||||
from vllm.sequence import SamplerOutput
|
||||
|
||||
|
||||
@ -274,19 +273,14 @@ class InternLM2ForCausalLM(nn.Module):
|
||||
next_tokens = self.sampler(logits, sampling_metadata)
|
||||
return next_tokens
|
||||
|
||||
def load_weights(self,
|
||||
model_name_or_path: str,
|
||||
cache_dir: Optional[str] = None,
|
||||
load_format: str = "auto",
|
||||
revision: Optional[str] = None):
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
stacked_params_mapping = [
|
||||
# (param_name, shard_name, shard_id)
|
||||
("gate_up_proj", "w1", 0),
|
||||
("gate_up_proj", "w3", 1),
|
||||
]
|
||||
params_dict = dict(self.named_parameters())
|
||||
for name, loaded_weight in hf_model_weights_iterator(
|
||||
model_name_or_path, cache_dir, load_format, revision):
|
||||
for name, loaded_weight in weights:
|
||||
if "rotary_emb.inv_freq" in name:
|
||||
continue
|
||||
for (param_name, weight_name, shard_id) in stacked_params_mapping:
|
||||
|
@ -20,7 +20,7 @@
|
||||
"""Inference-only Jais model compatible with HuggingFace weights."""
|
||||
|
||||
import math
|
||||
from typing import List, Optional
|
||||
from typing import Iterable, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
@ -36,9 +36,8 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
from vllm.model_executor.layers.sampler import Sampler
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
VocabParallelEmbedding)
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.model_executor.weight_utils import (default_weight_loader,
|
||||
hf_model_weights_iterator)
|
||||
from vllm.sequence import SamplerOutput
|
||||
from vllm.transformers_utils.configs import JAISConfig
|
||||
|
||||
@ -303,16 +302,9 @@ class JAISLMHeadModel(nn.Module):
|
||||
next_tokens = self.sampler(logits, sampling_metadata)
|
||||
return next_tokens
|
||||
|
||||
def load_weights(
|
||||
self,
|
||||
model_name_or_path: str,
|
||||
cache_dir: Optional[str] = None,
|
||||
load_format: str = "auto",
|
||||
revision: Optional[str] = None,
|
||||
):
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
params_dict = dict(self.named_parameters(remove_duplicate=False))
|
||||
for name, loaded_weight in hf_model_weights_iterator(
|
||||
model_name_or_path, cache_dir, load_format, revision):
|
||||
for name, loaded_weight in weights:
|
||||
if "lm_head.weight" in name:
|
||||
# GPT-2 ties the weights of the embedding layer and the final
|
||||
# linear layer.
|
||||
|
@ -21,7 +21,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Inference-only LLaMA model compatible with HuggingFace weights."""
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
from typing import Any, Dict, Iterable, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
@ -42,10 +42,9 @@ from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||
from vllm.model_executor.layers.sampler import Sampler
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
|
||||
from vllm.model_executor.model_loader.weight_utils import (
|
||||
default_weight_loader, kv_cache_scales_loader)
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.model_executor.weight_utils import (default_weight_loader,
|
||||
hf_model_weights_iterator,
|
||||
kv_cache_scales_loader)
|
||||
from vllm.sequence import SamplerOutput
|
||||
from vllm.utils import is_hip
|
||||
|
||||
@ -376,11 +375,7 @@ class LlamaForCausalLM(nn.Module):
|
||||
next_tokens = self.sampler(logits, sampling_metadata)
|
||||
return next_tokens
|
||||
|
||||
def load_weights(self,
|
||||
model_name_or_path: str,
|
||||
cache_dir: Optional[str] = None,
|
||||
load_format: str = "auto",
|
||||
revision: Optional[str] = None):
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
stacked_params_mapping = [
|
||||
# (param_name, shard_name, shard_id)
|
||||
("qkv_proj", "q_proj", "q"),
|
||||
@ -390,8 +385,7 @@ class LlamaForCausalLM(nn.Module):
|
||||
("gate_up_proj", "up_proj", 1),
|
||||
]
|
||||
params_dict = dict(self.named_parameters())
|
||||
for name, loaded_weight in hf_model_weights_iterator(
|
||||
model_name_or_path, cache_dir, load_format, revision):
|
||||
for name, loaded_weight in weights:
|
||||
if "rotary_emb.inv_freq" in name:
|
||||
continue
|
||||
if ("rotary_emb.cos_cached" in name
|
||||
|
@ -1,4 +1,4 @@
|
||||
from typing import List, Optional
|
||||
from typing import Iterable, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
@ -13,10 +13,9 @@ from vllm.model_executor.layers.linear import LinearMethodBase
|
||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
from vllm.model_executor.layers.sampler import Sampler
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.model_executor.models.llama import LlamaModel
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.model_executor.weight_utils import (default_weight_loader,
|
||||
hf_model_weights_iterator)
|
||||
from vllm.sequence import SamplerOutput
|
||||
|
||||
_KEYS_TO_MODIFY_MAPPING = {
|
||||
@ -198,11 +197,7 @@ class LlavaForConditionalGeneration(nn.Module):
|
||||
next_tokens = self.sampler(logits, sampling_metadata)
|
||||
return next_tokens
|
||||
|
||||
def load_weights(self,
|
||||
model_name_or_path: str,
|
||||
cache_dir: Optional[str] = None,
|
||||
load_format: str = "auto",
|
||||
revision: Optional[str] = None):
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
# only doing this for language model part for now.
|
||||
stacked_params_mapping = [
|
||||
# (param_name, shard_name, shard_id)
|
||||
@ -213,8 +208,7 @@ class LlavaForConditionalGeneration(nn.Module):
|
||||
("gate_up_proj", "up_proj", 1),
|
||||
]
|
||||
params_dict = dict(self.named_parameters())
|
||||
for name, loaded_weight in hf_model_weights_iterator(
|
||||
model_name_or_path, cache_dir, load_format, revision):
|
||||
for name, loaded_weight in weights:
|
||||
if "rotary_emb.inv_freq" in name:
|
||||
continue
|
||||
for key_to_modify, new_key in _KEYS_TO_MODIFY_MAPPING.items():
|
||||
|
@ -22,7 +22,7 @@
|
||||
# limitations under the License.
|
||||
"""Inference-only MiniCPM model compatible with HuggingFace weights."""
|
||||
import math
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
from typing import Any, Dict, Iterable, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
@ -45,10 +45,9 @@ from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||
from vllm.model_executor.layers.sampler import Sampler
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.model_executor.utils import set_weight_attrs
|
||||
from vllm.model_executor.weight_utils import (default_weight_loader,
|
||||
hf_model_weights_iterator)
|
||||
from vllm.sequence import SamplerOutput
|
||||
|
||||
|
||||
@ -472,11 +471,7 @@ class MiniCPMForCausalLM(nn.Module):
|
||||
next_tokens = self.sampler(logits, sampling_metadata)
|
||||
return next_tokens
|
||||
|
||||
def load_weights(self,
|
||||
model_name_or_path: str,
|
||||
cache_dir: Optional[str] = None,
|
||||
load_format: str = "auto",
|
||||
revision: Optional[str] = None):
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
stacked_params_mapping = [
|
||||
# (param_name, shard_name, shard_id)
|
||||
("qkv_proj", "q_proj", "q"),
|
||||
@ -493,8 +488,7 @@ class MiniCPMForCausalLM(nn.Module):
|
||||
for weight_name in ["w1", "w2", "w3"]
|
||||
]
|
||||
params_dict = dict(self.named_parameters())
|
||||
for name, loaded_weight in hf_model_weights_iterator(
|
||||
model_name_or_path, cache_dir, load_format, revision):
|
||||
for name, loaded_weight in weights:
|
||||
if "rotary_emb.inv_freq" in name:
|
||||
continue
|
||||
if ("rotary_emb.cos_cached" in name
|
||||
|
@ -21,7 +21,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Inference-only Mixtral model."""
|
||||
from typing import List, Optional
|
||||
from typing import Iterable, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
@ -43,10 +43,9 @@ from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||
from vllm.model_executor.layers.sampler import Sampler
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.model_executor.utils import set_weight_attrs
|
||||
from vllm.model_executor.weight_utils import (default_weight_loader,
|
||||
hf_model_weights_iterator)
|
||||
from vllm.sequence import SamplerOutput
|
||||
|
||||
|
||||
@ -319,6 +318,8 @@ class MixtralModel(nn.Module):
|
||||
|
||||
|
||||
class MixtralForCausalLM(nn.Module):
|
||||
fall_back_to_pt_during_load = False
|
||||
|
||||
packed_modules_mapping = {
|
||||
"qkv_proj": [
|
||||
"q_proj",
|
||||
@ -393,11 +394,7 @@ class MixtralForCausalLM(nn.Module):
|
||||
next_tokens = self.sampler(logits, sampling_metadata)
|
||||
return next_tokens
|
||||
|
||||
def load_weights(self,
|
||||
model_name_or_path: str,
|
||||
cache_dir: Optional[str] = None,
|
||||
load_format: str = "auto",
|
||||
revision: Optional[str] = None):
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
stacked_params_mapping = [
|
||||
# (param_name, shard_name, shard_id)
|
||||
("qkv_proj", "q_proj", "q"),
|
||||
@ -414,12 +411,7 @@ class MixtralForCausalLM(nn.Module):
|
||||
]
|
||||
|
||||
params_dict = dict(self.named_parameters())
|
||||
for name, loaded_weight in hf_model_weights_iterator(
|
||||
model_name_or_path,
|
||||
cache_dir,
|
||||
load_format,
|
||||
revision,
|
||||
fall_back_to_pt=False):
|
||||
for name, loaded_weight in weights:
|
||||
if "rotary_emb.inv_freq" in name:
|
||||
continue
|
||||
|
||||
|
@ -21,7 +21,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Inference-only Mixtral model."""
|
||||
from typing import List, Optional
|
||||
from typing import Iterable, List, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@ -43,9 +43,8 @@ from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||
from vllm.model_executor.layers.sampler import Sampler
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
ParallelLMHead, VocabParallelEmbedding)
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.model_executor.weight_utils import (default_weight_loader,
|
||||
hf_model_weights_iterator)
|
||||
from vllm.sequence import SamplerOutput
|
||||
|
||||
|
||||
@ -327,6 +326,7 @@ class MixtralModel(nn.Module):
|
||||
|
||||
|
||||
class MixtralForCausalLM(nn.Module):
|
||||
fall_back_to_pt_during_load = False
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@ -366,11 +366,7 @@ class MixtralForCausalLM(nn.Module):
|
||||
next_tokens = self.sampler(logits, sampling_metadata)
|
||||
return next_tokens
|
||||
|
||||
def load_weights(self,
|
||||
model_name_or_path: str,
|
||||
cache_dir: Optional[str] = None,
|
||||
load_format: str = "auto",
|
||||
revision: Optional[str] = None):
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
stacked_params_mapping = [
|
||||
# (param_name, shard_name, shard_id)
|
||||
("qkv_proj", "q_proj", "q"),
|
||||
@ -379,12 +375,7 @@ class MixtralForCausalLM(nn.Module):
|
||||
]
|
||||
|
||||
params_dict = dict(self.named_parameters())
|
||||
for name, loaded_weight in hf_model_weights_iterator(
|
||||
model_name_or_path,
|
||||
cache_dir,
|
||||
load_format,
|
||||
revision,
|
||||
fall_back_to_pt=False):
|
||||
for name, loaded_weight in weights:
|
||||
if "rotary_emb.inv_freq" in name:
|
||||
continue
|
||||
for (param_name, weight_name, shard_id) in stacked_params_mapping:
|
||||
|
@ -1,7 +1,7 @@
|
||||
# coding=utf-8
|
||||
# Adapted from https://huggingface.co/mosaicml/mpt-7b/tree/main
|
||||
import math
|
||||
from typing import List, Optional
|
||||
from typing import Iterable, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@ -18,9 +18,8 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
from vllm.model_executor.layers.sampler import Sampler
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
VocabParallelEmbedding)
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.model_executor.weight_utils import (default_weight_loader,
|
||||
hf_model_weights_iterator)
|
||||
from vllm.sequence import SamplerOutput
|
||||
from vllm.transformers_utils.configs.mpt import MPTConfig
|
||||
|
||||
@ -284,14 +283,9 @@ class MPTForCausalLM(nn.Module):
|
||||
next_tokens = self.sampler(logits, sampling_metadata)
|
||||
return next_tokens
|
||||
|
||||
def load_weights(self,
|
||||
model_name_or_path: str,
|
||||
cache_dir: Optional[str] = None,
|
||||
load_format: str = "auto",
|
||||
revision: Optional[str] = None):
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
params_dict = dict(self.named_parameters(remove_duplicate=False))
|
||||
for name, loaded_weight in hf_model_weights_iterator(
|
||||
model_name_or_path, cache_dir, load_format, revision):
|
||||
for name, loaded_weight in weights:
|
||||
# Skip loading extra bias for GPTQ models.
|
||||
if name.endswith(".bias") and name not in params_dict:
|
||||
continue
|
||||
|
@ -36,7 +36,7 @@
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
"""Inference-only OLMo model compatible with HuggingFace weights."""
|
||||
from typing import List, Optional, Tuple
|
||||
from typing import Iterable, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
# this model must need this dependency
|
||||
@ -56,9 +56,8 @@ from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||
from vllm.model_executor.layers.sampler import Sampler
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
VocabParallelEmbedding)
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.model_executor.weight_utils import (default_weight_loader,
|
||||
hf_model_weights_iterator)
|
||||
from vllm.sequence import SamplerOutput
|
||||
|
||||
|
||||
@ -348,16 +347,9 @@ class OLMoForCausalLM(nn.Module):
|
||||
next_tokens = self.sampler(logits, sampling_metadata)
|
||||
return next_tokens
|
||||
|
||||
def load_weights(
|
||||
self,
|
||||
model_name_or_path: str,
|
||||
cache_dir: Optional[str] = None,
|
||||
load_format: str = "auto",
|
||||
revision: Optional[str] = None,
|
||||
):
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
params_dict = dict(self.named_parameters(remove_duplicate=False))
|
||||
for name, loaded_weight in hf_model_weights_iterator(
|
||||
model_name_or_path, cache_dir, load_format, revision):
|
||||
for name, loaded_weight in weights:
|
||||
# attention
|
||||
if ".att" in name:
|
||||
name = name.replace(".att", ".attn.att")
|
||||
|
@ -17,7 +17,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Inference-only OPT model compatible with HuggingFace weights."""
|
||||
from typing import List, Optional
|
||||
from typing import Iterable, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
@ -35,9 +35,8 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
from vllm.model_executor.layers.sampler import Sampler
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
VocabParallelEmbedding)
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.model_executor.weight_utils import (default_weight_loader,
|
||||
hf_model_weights_iterator)
|
||||
from vllm.sequence import SamplerOutput
|
||||
|
||||
|
||||
@ -315,11 +314,7 @@ class OPTForCausalLM(nn.Module):
|
||||
next_tokens = self.sampler(logits, sampling_metadata)
|
||||
return next_tokens
|
||||
|
||||
def load_weights(self,
|
||||
model_name_or_path: str,
|
||||
cache_dir: Optional[str] = None,
|
||||
load_format: str = "auto",
|
||||
revision: Optional[str] = None):
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
stacked_params_mapping = [
|
||||
# (param_name, shard_name, shard_id)
|
||||
("qkv_proj", "q_proj", "q"),
|
||||
@ -327,8 +322,7 @@ class OPTForCausalLM(nn.Module):
|
||||
("qkv_proj", "v_proj", "v"),
|
||||
]
|
||||
params_dict = dict(self.named_parameters(remove_duplicate=False))
|
||||
for name, loaded_weight in hf_model_weights_iterator(
|
||||
model_name_or_path, cache_dir, load_format, revision):
|
||||
for name, loaded_weight in weights:
|
||||
if "lm_head.weight" in name:
|
||||
continue
|
||||
if name.startswith("decoder."):
|
||||
|
@ -4,7 +4,7 @@
|
||||
# Copyright (c) OrionStar Inc.
|
||||
# LICENSE: https://huggingface.co/OrionStarAI/Orion-14B-Base/blob/main/LICENSE
|
||||
"""Inference-only Orion-14B model compatible with HuggingFace weights."""
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
from typing import Any, Dict, Iterable, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
@ -22,9 +22,8 @@ from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||
from vllm.model_executor.layers.sampler import Sampler
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
ParallelLMHead, VocabParallelEmbedding)
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.model_executor.weight_utils import (default_weight_loader,
|
||||
hf_model_weights_iterator)
|
||||
from vllm.sequence import SamplerOutput
|
||||
|
||||
|
||||
@ -280,11 +279,7 @@ class OrionForCausalLM(nn.Module):
|
||||
next_tokens = self.sampler(logits, sampling_metadata)
|
||||
return next_tokens
|
||||
|
||||
def load_weights(self,
|
||||
model_name_or_path: str,
|
||||
cache_dir: Optional[str] = None,
|
||||
load_format: str = "auto",
|
||||
revision: Optional[str] = None):
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
stacked_params_mapping = [
|
||||
# (param_name, shard_name, shard_id)
|
||||
("qkv_proj", "q_proj", "q"),
|
||||
@ -294,8 +289,7 @@ class OrionForCausalLM(nn.Module):
|
||||
("gate_up_proj", "up_proj", 1),
|
||||
]
|
||||
params_dict = dict(self.named_parameters())
|
||||
for name, loaded_weight in hf_model_weights_iterator(
|
||||
model_name_or_path, cache_dir, load_format, revision):
|
||||
for name, loaded_weight in weights:
|
||||
if "rotary_emb.inv_freq" in name:
|
||||
continue
|
||||
if ("rotary_emb.cos_cached" in name
|
||||
|
@ -35,7 +35,7 @@
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
"""Inference-only Phi-1.5 model compatible with HuggingFace weights."""
|
||||
from typing import List, Optional
|
||||
from typing import Iterable, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
@ -53,9 +53,8 @@ from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||
from vllm.model_executor.layers.sampler import Sampler
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
ParallelLMHead, VocabParallelEmbedding)
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.model_executor.weight_utils import (default_weight_loader,
|
||||
hf_model_weights_iterator)
|
||||
from vllm.sequence import SamplerOutput
|
||||
|
||||
|
||||
@ -265,11 +264,7 @@ class PhiForCausalLM(nn.Module):
|
||||
next_tokens = self.sampler(logits, sampling_metadata)
|
||||
return next_tokens
|
||||
|
||||
def load_weights(self,
|
||||
model_name_or_path: str,
|
||||
cache_dir: Optional[str] = None,
|
||||
load_format: str = "auto",
|
||||
revision: Optional[str] = None):
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
stacked_params_mapping = [
|
||||
# (param_name, shard_name, shard_id)
|
||||
("qkv_proj", "q_proj", "q"),
|
||||
@ -278,8 +273,7 @@ class PhiForCausalLM(nn.Module):
|
||||
]
|
||||
params_dict = dict(self.named_parameters())
|
||||
|
||||
for name, loaded_weight in hf_model_weights_iterator(
|
||||
model_name_or_path, cache_dir, load_format, revision):
|
||||
for name, loaded_weight in weights:
|
||||
if "rotary_emb.inv_freq" in name:
|
||||
continue
|
||||
|
||||
|
@ -4,7 +4,7 @@
|
||||
# Copyright (c) Alibaba Cloud.
|
||||
# LICENSE: https://huggingface.co/Qwen/Qwen-7B/blob/main/LICENSE
|
||||
"""Inference-only QWen model compatible with HuggingFace weights."""
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
from typing import Any, Dict, Iterable, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
@ -23,9 +23,8 @@ from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||
from vllm.model_executor.layers.sampler import Sampler
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
ParallelLMHead, VocabParallelEmbedding)
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.model_executor.weight_utils import (default_weight_loader,
|
||||
hf_model_weights_iterator)
|
||||
from vllm.sequence import SamplerOutput
|
||||
|
||||
|
||||
@ -253,19 +252,14 @@ class QWenLMHeadModel(nn.Module):
|
||||
next_tokens = self.sampler(logits, sampling_metadata)
|
||||
return next_tokens
|
||||
|
||||
def load_weights(self,
|
||||
model_name_or_path: str,
|
||||
cache_dir: Optional[str] = None,
|
||||
load_format: str = "auto",
|
||||
revision: Optional[str] = None):
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
stacked_params_mapping = [
|
||||
# (param_name, shard_name, shard_id)
|
||||
("gate_up_proj", "w2", 0),
|
||||
("gate_up_proj", "w1", 1),
|
||||
]
|
||||
params_dict = dict(self.named_parameters())
|
||||
for name, loaded_weight in hf_model_weights_iterator(
|
||||
model_name_or_path, cache_dir, load_format, revision):
|
||||
for name, loaded_weight in weights:
|
||||
if "rotary_emb.inv_freq" in name:
|
||||
continue
|
||||
for (param_name, weight_name, shard_id) in stacked_params_mapping:
|
||||
|
@ -22,7 +22,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Inference-only Qwen2 model compatible with HuggingFace weights."""
|
||||
from typing import List, Optional, Tuple
|
||||
from typing import Iterable, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
@ -42,9 +42,8 @@ from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||
from vllm.model_executor.layers.sampler import Sampler
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
ParallelLMHead, VocabParallelEmbedding)
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.model_executor.weight_utils import (default_weight_loader,
|
||||
hf_model_weights_iterator)
|
||||
from vllm.sequence import SamplerOutput
|
||||
|
||||
|
||||
@ -331,11 +330,7 @@ class Qwen2ForCausalLM(nn.Module):
|
||||
next_tokens = self.sampler(logits, sampling_metadata)
|
||||
return next_tokens
|
||||
|
||||
def load_weights(self,
|
||||
model_name_or_path: str,
|
||||
cache_dir: Optional[str] = None,
|
||||
load_format: str = "auto",
|
||||
revision: Optional[str] = None):
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
stacked_params_mapping = [
|
||||
# (param_name, shard_name, shard_id)
|
||||
("qkv_proj", "q_proj", "q"),
|
||||
@ -345,8 +340,7 @@ class Qwen2ForCausalLM(nn.Module):
|
||||
("gate_up_proj", "up_proj", 1),
|
||||
]
|
||||
params_dict = dict(self.named_parameters(remove_duplicate=False))
|
||||
for name, loaded_weight in hf_model_weights_iterator(
|
||||
model_name_or_path, cache_dir, load_format, revision):
|
||||
for name, loaded_weight in weights:
|
||||
if "rotary_emb.inv_freq" in name:
|
||||
continue
|
||||
if self.config.tie_word_embeddings and "lm_head.weight" in name:
|
||||
|
@ -22,7 +22,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Inference-only Qwen2MoE model compatible with HuggingFace weights."""
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import Any, Dict, Iterable, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
@ -46,9 +46,8 @@ from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||
from vllm.model_executor.layers.sampler import Sampler
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
ParallelLMHead, VocabParallelEmbedding)
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.model_executor.weight_utils import (default_weight_loader,
|
||||
hf_model_weights_iterator)
|
||||
from vllm.sequence import SamplerOutput
|
||||
|
||||
|
||||
@ -366,6 +365,8 @@ class Qwen2MoeModel(nn.Module):
|
||||
|
||||
class Qwen2MoeForCausalLM(nn.Module):
|
||||
|
||||
fall_back_to_pt_during_load = False
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: PretrainedConfig,
|
||||
@ -404,11 +405,7 @@ class Qwen2MoeForCausalLM(nn.Module):
|
||||
next_tokens = self.sampler(logits, sampling_metadata)
|
||||
return next_tokens
|
||||
|
||||
def load_weights(self,
|
||||
model_name_or_path: str,
|
||||
cache_dir: Optional[str] = None,
|
||||
load_format: str = "auto",
|
||||
revision: Optional[str] = None):
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
stacked_params_mapping = [
|
||||
# (param_name, shard_name, shard_id)
|
||||
("qkv_proj", "q_proj", "q"),
|
||||
@ -419,12 +416,7 @@ class Qwen2MoeForCausalLM(nn.Module):
|
||||
]
|
||||
|
||||
params_dict = dict(self.named_parameters())
|
||||
for name, loaded_weight in hf_model_weights_iterator(
|
||||
model_name_or_path,
|
||||
cache_dir,
|
||||
load_format,
|
||||
revision,
|
||||
fall_back_to_pt=False):
|
||||
for name, loaded_weight in weights:
|
||||
if "rotary_emb.inv_freq" in name:
|
||||
continue
|
||||
for (param_name, weight_name, shard_id) in stacked_params_mapping:
|
||||
|
@ -19,7 +19,7 @@
|
||||
# https://huggingface.co/stabilityai/stablelm-3b-4e1t/blob/main/config.json
|
||||
"""Inference-only StabeLM (https://github.com/Stability-AI/StableLM)
|
||||
model compatible with HuggingFace weights."""
|
||||
from typing import List, Optional, Tuple
|
||||
from typing import Iterable, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
@ -37,9 +37,8 @@ from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||
from vllm.model_executor.layers.sampler import Sampler
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
ParallelLMHead, VocabParallelEmbedding)
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.model_executor.weight_utils import (default_weight_loader,
|
||||
hf_model_weights_iterator)
|
||||
from vllm.sequence import SamplerOutput
|
||||
|
||||
|
||||
@ -262,11 +261,7 @@ class StablelmForCausalLM(nn.Module):
|
||||
next_tokens = self.sampler(logits, sampling_metadata)
|
||||
return next_tokens
|
||||
|
||||
def load_weights(self,
|
||||
model_name_or_path: str,
|
||||
cache_dir: Optional[str] = None,
|
||||
load_format: str = "auto",
|
||||
revision: Optional[str] = None):
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
stacked_params_mapping = [
|
||||
# (param_name, shard_name, shard_id)
|
||||
("qkv_proj", "q_proj", "q"),
|
||||
@ -276,8 +271,7 @@ class StablelmForCausalLM(nn.Module):
|
||||
("gate_up_proj", "up_proj", 1),
|
||||
]
|
||||
params_dict = dict(self.named_parameters())
|
||||
for name, loaded_weight in hf_model_weights_iterator(
|
||||
model_name_or_path, cache_dir, load_format, revision):
|
||||
for name, loaded_weight in weights:
|
||||
if "rotary_emb.inv_freq" in name:
|
||||
continue
|
||||
if ("rotary_emb.cos_cached" in name
|
||||
|
@ -18,7 +18,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
""" PyTorch Starcoder2 model."""
|
||||
from typing import List, Optional
|
||||
from typing import Iterable, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
@ -36,9 +36,8 @@ from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||
from vllm.model_executor.layers.sampler import Sampler
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.model_executor.weight_utils import (default_weight_loader,
|
||||
hf_model_weights_iterator)
|
||||
from vllm.sequence import SamplerOutput
|
||||
|
||||
|
||||
@ -274,11 +273,7 @@ class Starcoder2ForCausalLM(nn.Module):
|
||||
next_tokens = self.sampler(logits, sampling_metadata)
|
||||
return next_tokens
|
||||
|
||||
def load_weights(self,
|
||||
model_name_or_path: str,
|
||||
cache_dir: Optional[str] = None,
|
||||
load_format: str = "auto",
|
||||
revision: Optional[str] = None):
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
stacked_params_mapping = [
|
||||
# (param_name, shard_name, shard_id)
|
||||
("qkv_proj", "q_proj", "q"),
|
||||
@ -287,8 +282,7 @@ class Starcoder2ForCausalLM(nn.Module):
|
||||
]
|
||||
|
||||
params_dict = dict(self.named_parameters(remove_duplicate=False))
|
||||
for name, loaded_weight in hf_model_weights_iterator(
|
||||
model_name_or_path, cache_dir, load_format, revision):
|
||||
for name, loaded_weight in weights:
|
||||
if "rotary_emb.inv_freq" in name:
|
||||
continue
|
||||
|
||||
|
@ -20,7 +20,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Inference-only Xverse model compatible with HuggingFace weights."""
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
from typing import Any, Dict, Iterable, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
@ -40,9 +40,8 @@ from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||
from vllm.model_executor.layers.sampler import Sampler
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
ParallelLMHead, VocabParallelEmbedding)
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.model_executor.weight_utils import (default_weight_loader,
|
||||
hf_model_weights_iterator)
|
||||
from vllm.sequence import SamplerOutput
|
||||
|
||||
|
||||
@ -331,11 +330,7 @@ class XverseForCausalLM(nn.Module):
|
||||
next_tokens = self.sampler(logits, sampling_metadata)
|
||||
return next_tokens
|
||||
|
||||
def load_weights(self,
|
||||
model_name_or_path: str,
|
||||
cache_dir: Optional[str] = None,
|
||||
load_format: str = "auto",
|
||||
revision: Optional[str] = None):
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
stacked_params_mapping = [
|
||||
("qkv_proj", "q_proj", "q"),
|
||||
("qkv_proj", "k_proj", "k"),
|
||||
@ -344,8 +339,7 @@ class XverseForCausalLM(nn.Module):
|
||||
("gate_up_proj", "up_proj", 1),
|
||||
]
|
||||
params_dict = dict(self.named_parameters())
|
||||
for name, loaded_weight in hf_model_weights_iterator(
|
||||
model_name_or_path, cache_dir, load_format, revision):
|
||||
for name, loaded_weight in weights:
|
||||
if ("rotary_emb.inv_freq" in name
|
||||
or "rotary_emb.cos_cached" in name
|
||||
or "rotary_emb.sin_cached" in name):
|
||||
|
@ -1,8 +1,10 @@
|
||||
import os
|
||||
from typing import Optional, Union
|
||||
|
||||
from transformers import (AutoTokenizer, PreTrainedTokenizer,
|
||||
PreTrainedTokenizerFast)
|
||||
|
||||
from vllm.config import VLLM_USE_MODELSCOPE
|
||||
from vllm.logger import init_logger
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.transformers_utils.tokenizers import BaichuanTokenizer
|
||||
@ -57,9 +59,26 @@ def get_tokenizer(
|
||||
tokenizer_mode: str = "auto",
|
||||
trust_remote_code: bool = False,
|
||||
tokenizer_revision: Optional[str] = None,
|
||||
download_dir: Optional[str] = None,
|
||||
**kwargs,
|
||||
) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
|
||||
"""Gets a tokenizer for the given model name via Huggingface."""
|
||||
"""Gets a tokenizer for the given model name via Huggingface/modelscope."""
|
||||
if VLLM_USE_MODELSCOPE:
|
||||
# download model from ModelScope hub,
|
||||
# lazy import so that modelscope is not required for normal use.
|
||||
# pylint: disable=C.
|
||||
from modelscope.hub.snapshot_download import snapshot_download
|
||||
|
||||
# Only set the tokenizer here, model will be downloaded on the workers.
|
||||
if not os.path.exists(tokenizer_name):
|
||||
tokenizer_path = snapshot_download(
|
||||
model_id=tokenizer_name,
|
||||
cache_dir=download_dir,
|
||||
revision=tokenizer_revision,
|
||||
# Ignore weights - we only need the tokenizer.
|
||||
ignore_file_pattern=["*.pt", "*.safetensors", "*.bin"])
|
||||
tokenizer_name = tokenizer_path
|
||||
|
||||
if tokenizer_mode == "slow":
|
||||
if kwargs.get("use_fast", False):
|
||||
raise ValueError(
|
||||
|
@ -3,8 +3,8 @@ from typing import Dict, List, Optional, Tuple
|
||||
import torch
|
||||
|
||||
from vllm.attention import AttentionMetadata, get_attn_backend
|
||||
from vllm.config import (DeviceConfig, LoRAConfig, ModelConfig, ParallelConfig,
|
||||
SchedulerConfig)
|
||||
from vllm.config import (DeviceConfig, LoadConfig, LoRAConfig, ModelConfig,
|
||||
ParallelConfig, SchedulerConfig)
|
||||
from vllm.distributed import broadcast_tensor_dict
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor import SamplingMetadata
|
||||
@ -26,6 +26,7 @@ class CPUModelRunner:
|
||||
parallel_config: ParallelConfig,
|
||||
scheduler_config: SchedulerConfig,
|
||||
device_config: DeviceConfig,
|
||||
load_config: LoadConfig,
|
||||
lora_config: Optional[LoRAConfig],
|
||||
kv_cache_dtype: Optional[str] = "auto",
|
||||
is_driver_worker: bool = False,
|
||||
@ -36,6 +37,7 @@ class CPUModelRunner:
|
||||
self.parallel_config = parallel_config
|
||||
self.scheduler_config = scheduler_config
|
||||
self.lora_config = lora_config
|
||||
self.load_config = load_config
|
||||
self.is_driver_worker = is_driver_worker
|
||||
|
||||
# model_config can be None in tests/samplers/test_sampler.py.
|
||||
@ -55,8 +57,10 @@ class CPUModelRunner:
|
||||
self.model_config.dtype if model_config is not None else None)
|
||||
|
||||
def load_model(self) -> None:
|
||||
self.model = get_model(self.model_config,
|
||||
self.device_config,
|
||||
self.model = get_model(model_config=self.model_config,
|
||||
load_config=self.load_config,
|
||||
device_config=self.device_config,
|
||||
vision_language_config=None,
|
||||
lora_config=self.lora_config,
|
||||
parallel_config=self.parallel_config,
|
||||
scheduler_config=self.scheduler_config)
|
||||
|
@ -5,8 +5,8 @@ import torch
|
||||
import torch.distributed
|
||||
|
||||
from vllm.attention import get_attn_backend
|
||||
from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig,
|
||||
ParallelConfig, SchedulerConfig)
|
||||
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
|
||||
ModelConfig, ParallelConfig, SchedulerConfig)
|
||||
from vllm.distributed import (broadcast_tensor_dict,
|
||||
ensure_model_parallel_initialized,
|
||||
init_distributed_environment)
|
||||
@ -117,6 +117,7 @@ class CPUWorker(LoraNotSupportedWorkerBase):
|
||||
scheduler_config: SchedulerConfig,
|
||||
device_config: DeviceConfig,
|
||||
cache_config: CacheConfig,
|
||||
load_config: LoadConfig,
|
||||
local_rank: int,
|
||||
rank: int,
|
||||
distributed_init_method: str,
|
||||
@ -129,6 +130,7 @@ class CPUWorker(LoraNotSupportedWorkerBase):
|
||||
self.scheduler_config = scheduler_config
|
||||
self.device_config = device_config
|
||||
self.cache_config = cache_config
|
||||
self.load_config = load_config
|
||||
self.local_rank = local_rank
|
||||
self.rank = rank
|
||||
self.distributed_init_method = distributed_init_method
|
||||
@ -141,6 +143,7 @@ class CPUWorker(LoraNotSupportedWorkerBase):
|
||||
parallel_config,
|
||||
scheduler_config,
|
||||
device_config,
|
||||
load_config=self.load_config,
|
||||
lora_config=self.lora_config,
|
||||
kv_cache_dtype=kv_cache_dtype,
|
||||
is_driver_worker=is_driver_worker)
|
||||
|
@ -9,9 +9,8 @@ import torch.nn as nn
|
||||
|
||||
from vllm.attention import (AttentionMetadata, AttentionMetadataPerStage,
|
||||
get_attn_backend)
|
||||
from vllm.config import (DeviceConfig, LoRAConfig, ModelConfig, ParallelConfig,
|
||||
SchedulerConfig, TensorizerConfig,
|
||||
VisionLanguageConfig)
|
||||
from vllm.config import (DeviceConfig, LoadConfig, LoRAConfig, ModelConfig,
|
||||
ParallelConfig, SchedulerConfig, VisionLanguageConfig)
|
||||
from vllm.distributed import broadcast_tensor_dict, with_pynccl_for_all_reduce
|
||||
from vllm.distributed.device_communicators import (custom_all_reduce,
|
||||
pynccl_utils)
|
||||
@ -108,17 +107,17 @@ class ModelRunner:
|
||||
parallel_config: ParallelConfig,
|
||||
scheduler_config: SchedulerConfig,
|
||||
device_config: DeviceConfig,
|
||||
load_config: LoadConfig,
|
||||
lora_config: Optional[LoRAConfig],
|
||||
kv_cache_dtype: Optional[str] = "auto",
|
||||
is_driver_worker: bool = False,
|
||||
vision_language_config: Optional[VisionLanguageConfig] = None,
|
||||
tensorizer_config: Optional[TensorizerConfig] = None,
|
||||
):
|
||||
self.model_config = model_config
|
||||
self.parallel_config = parallel_config
|
||||
self.scheduler_config = scheduler_config
|
||||
self.lora_config = lora_config
|
||||
self.tensorizer_config = tensorizer_config
|
||||
self.load_config = load_config
|
||||
self.is_driver_worker = is_driver_worker
|
||||
|
||||
# model_config can be None in tests/samplers/test_sampler.py.
|
||||
@ -156,13 +155,13 @@ class ModelRunner:
|
||||
def load_model(self) -> None:
|
||||
with CudaMemoryProfiler() as m:
|
||||
self.model = get_model(
|
||||
self.model_config,
|
||||
self.device_config,
|
||||
model_config=self.model_config,
|
||||
device_config=self.device_config,
|
||||
load_config=self.load_config,
|
||||
lora_config=self.lora_config,
|
||||
vision_language_config=self.vision_language_config,
|
||||
parallel_config=self.parallel_config,
|
||||
scheduler_config=self.scheduler_config,
|
||||
tensorizer_config=self.tensorizer_config,
|
||||
)
|
||||
|
||||
self.model_memory_usage = m.consumed_memory
|
||||
|
@ -6,7 +6,7 @@ from vllm.config import (DeviceConfig, ModelConfig, ParallelConfig,
|
||||
SchedulerConfig)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor import SamplingMetadata
|
||||
from vllm.model_executor.neuron_model_loader import get_neuron_model
|
||||
from vllm.model_executor.model_loader.neuron import get_neuron_model
|
||||
from vllm.sampling_params import SamplingParams, SamplingType
|
||||
from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata
|
||||
from vllm.utils import (async_tensor_h2d, is_pin_memory_available,
|
||||
|
@ -6,8 +6,8 @@ from typing import Dict, List, Optional, Set, Tuple
|
||||
import torch
|
||||
import torch.distributed
|
||||
|
||||
from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig,
|
||||
ParallelConfig, SchedulerConfig, TensorizerConfig,
|
||||
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
|
||||
ModelConfig, ParallelConfig, SchedulerConfig,
|
||||
VisionLanguageConfig)
|
||||
from vllm.distributed import (broadcast_tensor_dict,
|
||||
ensure_model_parallel_initialized,
|
||||
@ -38,12 +38,12 @@ class Worker(WorkerBase):
|
||||
scheduler_config: SchedulerConfig,
|
||||
device_config: DeviceConfig,
|
||||
cache_config: CacheConfig,
|
||||
load_config: LoadConfig,
|
||||
local_rank: int,
|
||||
rank: int,
|
||||
distributed_init_method: str,
|
||||
lora_config: Optional[LoRAConfig] = None,
|
||||
vision_language_config: Optional[VisionLanguageConfig] = None,
|
||||
tensorizer_config: Optional[TensorizerConfig] = None,
|
||||
is_driver_worker: bool = False,
|
||||
) -> None:
|
||||
self.model_config = model_config
|
||||
@ -55,7 +55,7 @@ class Worker(WorkerBase):
|
||||
self.rank = rank
|
||||
self.distributed_init_method = distributed_init_method
|
||||
self.lora_config = lora_config
|
||||
self.tensorizer_config = tensorizer_config
|
||||
self.load_config = load_config
|
||||
self.is_driver_worker = is_driver_worker
|
||||
if self.is_driver_worker:
|
||||
assert self.rank == 0, "The driver worker must have rank 0."
|
||||
@ -70,11 +70,11 @@ class Worker(WorkerBase):
|
||||
parallel_config,
|
||||
scheduler_config,
|
||||
device_config,
|
||||
load_config=load_config,
|
||||
lora_config=self.lora_config,
|
||||
kv_cache_dtype=self.cache_config.cache_dtype,
|
||||
is_driver_worker=is_driver_worker,
|
||||
vision_language_config=vision_language_config,
|
||||
tensorizer_config=tensorizer_config,
|
||||
)
|
||||
# Uninitialized cache engine. Will be initialized by
|
||||
# initialize_cache.
|
||||
|
Loading…
x
Reference in New Issue
Block a user