[Core][VLM] Test registration for OOT multimodal models (#8717)

Co-authored-by: DarkLight1337 <tlleungac@connect.ust.hk>
Roger Wang 2024-10-04 10:38:25 -07:00 committed by GitHub
parent e5dc713c23
commit 26aa325f4f
12 changed files with 227 additions and 49 deletions


@@ -85,16 +85,16 @@ When it comes to the linear layers, we provide the following options to parallel
* :code:`ReplicatedLinear`: Replicates the inputs and weights across multiple GPUs. No memory saving.
* :code:`RowParallelLinear`: The input tensor is partitioned along the hidden dimension. The weight matrix is partitioned along the rows (input dimension). An *all-reduce* operation is performed after the matrix multiplication to reduce the results. Typically used for the second FFN layer and the output linear transformation of the attention layer.
* :code:`ColumnParallelLinear`: The input tensor is replicated. The weight matrix is partitioned along the columns (output dimension). The result is partitioned along the column dimension. Typically used for the first FFN layer and the separated QKV transformation of the attention layer in the original Transformer.
* :code:`MergedColumnParallelLinear`: Column-parallel linear that merges multiple `ColumnParallelLinear` operators. Typically used for the first FFN layer with weighted activation functions (e.g., SiLU). This class handles the sharded weight loading logic of multiple weight matrices.
* :code:`MergedColumnParallelLinear`: Column-parallel linear that merges multiple :code:`ColumnParallelLinear` operators. Typically used for the first FFN layer with weighted activation functions (e.g., SiLU). This class handles the sharded weight loading logic of multiple weight matrices.
* :code:`QKVParallelLinear`: Parallel linear layer for the query, key, and value projections of the multi-head and grouped-query attention mechanisms. When the number of key/value heads is less than the world size, this class replicates the key/value heads properly. This class handles the weight loading and replication of the weight matrices.
Note that all the linear layers above take `linear_method` as an input. vLLM will set this parameter according to different quantization schemes to support weight quantization.
Note that all the linear layers above take :code:`linear_method` as an input. vLLM will set this parameter according to different quantization schemes to support weight quantization.
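For illustration, the projections of one decoder layer typically combine these classes roughly as follows (a minimal sketch, not part of this commit; the :code:`config` attribute names are assumptions that depend on your model):

.. code-block:: python

    import torch.nn as nn

    from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
                                                   QKVParallelLinear,
                                                   RowParallelLinear)


    class MyDecoderLayerProjections(nn.Module):
        """Sketch of the projection layers of one decoder layer."""

        def __init__(self, config):
            super().__init__()
            head_dim = config.hidden_size // config.num_attention_heads
            # Column-parallel QKV projection; replicates KV heads when there
            # are fewer of them than tensor-parallel ranks.
            self.qkv_proj = QKVParallelLinear(config.hidden_size,
                                              head_dim,
                                              config.num_attention_heads,
                                              config.num_key_value_heads,
                                              bias=False)
            # Row-parallel attention output; all-reduces the partial results.
            self.o_proj = RowParallelLinear(config.num_attention_heads * head_dim,
                                            config.hidden_size,
                                            bias=False)
            # Merged column-parallel gate/up projection for a SiLU-gated FFN.
            self.gate_up_proj = MergedColumnParallelLinear(
                config.hidden_size, [config.intermediate_size] * 2, bias=False)
            # Row-parallel second FFN projection.
            self.down_proj = RowParallelLinear(config.intermediate_size,
                                               config.hidden_size,
                                               bias=False)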
4. Implement the weight loading logic
-------------------------------------
You now need to implement the :code:`load_weights` method in your :code:`*ForCausalLM` class.
This method should load the weights from the HuggingFace checkpoint file and assign them to the corresponding layers in your model. Specifically, for `MergedColumnParallelLinear` and `QKVParallelLinear` layers, if the original model has separate weight matrices, you need to load the different parts separately.
This method should load the weights from the HuggingFace checkpoint file and assign them to the corresponding layers in your model. Specifically, for :code:`MergedColumnParallelLinear` and :code:`QKVParallelLinear` layers, if the original model has separate weight matrices, you need to load the different parts separately.
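For illustration, a minimal :code:`load_weights` sketch following that pattern might look like the following (not part of this commit; the merged parameter names and shard ids are assumptions that depend on your model's layer names):

.. code-block:: python

    from typing import Iterable, Tuple

    import torch

    from vllm.model_executor.model_loader.weight_utils import default_weight_loader


    # Method on your *ForCausalLM class.
    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        # Map separate checkpoint shards onto the merged vLLM parameters.
        stacked_params_mapping = [
            # (merged param name, checkpoint shard name, shard id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]
        params_dict = dict(self.named_parameters())
        for name, loaded_weight in weights:
            for param_name, shard_name, shard_id in stacked_params_mapping:
                if shard_name not in name:
                    continue
                # Load this shard into its slice of the merged parameter.
                param = params_dict[name.replace(shard_name, param_name)]
                param.weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Fall back to a plain copy for unsharded parameters.
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)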
5. Register your model
----------------------
@@ -114,6 +114,18 @@ Just add the following lines in your code:
from your_code import YourModelForCausalLM
ModelRegistry.register_model("YourModelForCausalLM", YourModelForCausalLM)
If your model imports modules that initialize CUDA, consider instead lazy-importing it to avoid an error like :code:`RuntimeError: Cannot re-initialize CUDA in forked subprocess`:
.. code-block:: python
from vllm import ModelRegistry
ModelRegistry.register_model("YourModelForCausalLM", "your_code:YourModelForCausalLM")
.. important::
If your model is a multimodal model, make sure the model class implements the :class:`~vllm.model_executor.models.interfaces.SupportsMultiModal` interface.
Read more about that :ref:`here <enabling_multimodal_inputs>`.
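For illustration, declaring the interface is a matter of adding it to the model's base classes (a minimal sketch; the class name is a placeholder):

.. code-block:: python

    from torch import nn

    from vllm.model_executor.models.interfaces import SupportsMultiModal


    class YourModelForCausalLM(nn.Module, SupportsMultiModal):
        ...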
If you are running the API server with :code:`vllm serve <args>`, you can wrap the entrypoint with the following code:
.. code-block:: python
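    # Reconstructed sketch of the elided snippet (an assumption, not verbatim
    # from this commit): register the model, then hand off to the
    # OpenAI-compatible server entrypoint.
    from vllm import ModelRegistry
    from your_code import YourModelForCausalLM

    ModelRegistry.register_model("YourModelForCausalLM", YourModelForCausalLM)

    if __name__ == "__main__":
        import runpy
        runpy.run_module("vllm.entrypoints.openai.api_server",
                         run_name="__main__")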

find_cuda_init.py Normal file

@@ -0,0 +1,33 @@
import importlib
import traceback
from typing import Callable
from unittest.mock import patch


def find_cuda_init(fn: Callable[[], object]) -> None:
    """
    Helper function to debug CUDA re-initialization errors.

    If `fn` initializes CUDA, prints the stack trace of how this happens.
    """
    from torch.cuda import _lazy_init

    stack = None

    def wrapper():
        nonlocal stack
        stack = traceback.extract_stack()
        return _lazy_init()

    with patch("torch.cuda._lazy_init", wrapper):
        fn()

    if stack is not None:
        print("==== CUDA Initialized ====")
        print("".join(traceback.format_list(stack)).strip())
        print("==========================")


if __name__ == "__main__":
    find_cuda_init(
        lambda: importlib.import_module("vllm.model_executor.models.llava"))


@@ -879,15 +879,16 @@ def num_gpus_available():
temp_dir = tempfile.gettempdir()
_dummy_path = os.path.join(temp_dir, "dummy_opt")
_dummy_opt_path = os.path.join(temp_dir, "dummy_opt")
_dummy_llava_path = os.path.join(temp_dir, "dummy_llava")
@pytest.fixture
def dummy_opt_path():
json_path = os.path.join(_dummy_path, "config.json")
if not os.path.exists(_dummy_path):
json_path = os.path.join(_dummy_opt_path, "config.json")
if not os.path.exists(_dummy_opt_path):
snapshot_download(repo_id="facebook/opt-125m",
local_dir=_dummy_path,
local_dir=_dummy_opt_path,
ignore_patterns=[
"*.bin", "*.bin.index.json", "*.pt", "*.h5",
"*.msgpack"
@@ -898,4 +899,23 @@ def dummy_opt_path():
config["architectures"] = ["MyOPTForCausalLM"]
with open(json_path, "w") as f:
json.dump(config, f)
return _dummy_path
return _dummy_opt_path
@pytest.fixture
def dummy_llava_path():
    json_path = os.path.join(_dummy_llava_path, "config.json")
    if not os.path.exists(_dummy_llava_path):
        snapshot_download(repo_id="llava-hf/llava-1.5-7b-hf",
                          local_dir=_dummy_llava_path,
                          ignore_patterns=[
                              "*.bin", "*.bin.index.json", "*.pt", "*.h5",
                              "*.msgpack"
                          ])
        assert os.path.exists(json_path)
        with open(json_path, "r") as f:
            config = json.load(f)
        config["architectures"] = ["MyLlava"]
        with open(json_path, "w") as f:
            json.dump(config, f)
    return _dummy_llava_path


@@ -21,7 +21,9 @@ def server():
"--dtype",
"bfloat16",
"--max-model-len",
"4096",
"2048",
"--max-num-seqs",
"5",
"--enforce-eager",
]


@@ -23,9 +23,16 @@ TEST_IMAGE_URLS = [
@pytest.fixture(scope="module")
def server():
args = [
"--dtype", "bfloat16", "--max-model-len", "4096", "--max-num-seqs",
"5", "--enforce-eager", "--trust-remote-code", "--limit-mm-per-prompt",
f"image={MAXIMUM_IMAGES}"
"--dtype",
"bfloat16",
"--max-model-len",
"2048",
"--max-num-seqs",
"5",
"--enforce-eager",
"--trust-remote-code",
"--limit-mm-per-prompt",
f"image={MAXIMUM_IMAGES}",
]
with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:


@@ -3,6 +3,7 @@ import os
import pytest
from vllm import LLM, SamplingParams
from vllm.assets.image import ImageAsset
from ..utils import fork_new_process_for_each_test
@@ -29,3 +30,40 @@ def test_oot_registration(dummy_opt_path):
# make sure only the first token is generated
rest = generated_text.replace(first_token, "")
assert rest == ""
image = ImageAsset("cherry_blossom").pil_image.convert("RGB")


@fork_new_process_for_each_test
def test_oot_multimodal_registration(dummy_llava_path):
    os.environ["VLLM_PLUGINS"] = "register_dummy_model"
    prompts = [{
        "prompt": "What's in the image?<image>",
        "multi_modal_data": {
            "image": image
        },
    }, {
        "prompt": "Describe the image<image>",
        "multi_modal_data": {
            "image": image
        },
    }]

    sampling_params = SamplingParams(temperature=0)
    llm = LLM(model=dummy_llava_path,
              load_format="dummy",
              max_num_seqs=1,
              trust_remote_code=True,
              gpu_memory_utilization=0.98,
              max_model_len=4096,
              enforce_eager=True,
              limit_mm_per_prompt={"image": 1})
    first_token = llm.get_tokenizer().decode(0)
    outputs = llm.generate(prompts, sampling_params)

    for output in outputs:
        generated_text = output.outputs[0].text
        # make sure only the first token is generated
        rest = generated_text.replace(first_token, "")
        assert rest == ""


@@ -1,26 +1,14 @@
from typing import Optional
import torch
from vllm import ModelRegistry
from vllm.model_executor.models.opt import OPTForCausalLM
from vllm.model_executor.sampling_metadata import SamplingMetadata
class MyOPTForCausalLM(OPTForCausalLM):
def compute_logits(
self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> Optional[torch.Tensor]:
# this dummy model always predicts the first token
logits = super().compute_logits(hidden_states, sampling_metadata)
if logits is not None:
logits.zero_()
logits[:, 0] += 1.0
return logits
def register():
# register our dummy model
# Test directly passing the model
from .my_opt import MyOPTForCausalLM
if "MyOPTForCausalLM" not in ModelRegistry.get_supported_archs():
ModelRegistry.register_model("MyOPTForCausalLM", MyOPTForCausalLM)
# Test passing lazy model
if "MyLlava" not in ModelRegistry.get_supported_archs():
ModelRegistry.register_model("MyLlava",
"vllm_add_dummy_model.my_llava:MyLlava")


@@ -0,0 +1,28 @@
from typing import Optional

import torch

from vllm.inputs import INPUT_REGISTRY
from vllm.model_executor.models.llava import (LlavaForConditionalGeneration,
                                              dummy_data_for_llava,
                                              get_max_llava_image_tokens,
                                              input_processor_for_llava)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY


@MULTIMODAL_REGISTRY.register_image_input_mapper()
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_llava_image_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_llava)
@INPUT_REGISTRY.register_input_processor(input_processor_for_llava)
class MyLlava(LlavaForConditionalGeneration):

    def compute_logits(
            self, hidden_states: torch.Tensor,
            sampling_metadata: SamplingMetadata) -> Optional[torch.Tensor]:
        # this dummy model always predicts the first token
        logits = super().compute_logits(hidden_states, sampling_metadata)
        if logits is not None:
            logits.zero_()
            logits[:, 0] += 1.0
        return logits


@@ -0,0 +1,19 @@
from typing import Optional

import torch

from vllm.model_executor.models.opt import OPTForCausalLM
from vllm.model_executor.sampling_metadata import SamplingMetadata


class MyOPTForCausalLM(OPTForCausalLM):

    def compute_logits(
            self, hidden_states: torch.Tensor,
            sampling_metadata: SamplingMetadata) -> Optional[torch.Tensor]:
        # this dummy model always predicts the first token
        logits = super().compute_logits(hidden_states, sampling_metadata)
        if logits is not None:
            logits.zero_()
            logits[:, 0] += 1.0
        return logits


@@ -183,6 +183,8 @@ class EngineArgs:
def __post_init__(self):
if self.tokenizer is None:
self.tokenizer = self.model
from vllm.plugins import load_general_plugins
load_general_plugins()
@staticmethod
def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:


@@ -290,9 +290,6 @@ class LLMEngine:
model_config.mm_processor_kwargs,
)
# TODO(woosuk): Print more configs in debug mode.
from vllm.plugins import load_general_plugins
load_general_plugins()
self.model_config = model_config
self.cache_config = cache_config
self.lora_config = lora_config


@@ -125,9 +125,10 @@ _MODELS = {
**_CONDITIONAL_GENERATION_MODELS,
}
# Architecture -> type.
# Architecture -> type or (module, class).
# out of tree models
_OOT_MODELS: Dict[str, Type[nn.Module]] = {}
_OOT_MODELS_LAZY: Dict[str, Tuple[str, str]] = {}
# Models not supported by ROCm.
_ROCM_UNSUPPORTED_MODELS: List[str] = []
@@ -159,17 +160,24 @@ class ModelRegistry:
@staticmethod
def _get_module_cls_name(model_arch: str) -> Tuple[str, str]:
module_relname, cls_name = _MODELS[model_arch]
return f"vllm.model_executor.models.{module_relname}", cls_name
if model_arch in _MODELS:
module_relname, cls_name = _MODELS[model_arch]
return f"vllm.model_executor.models.{module_relname}", cls_name
if model_arch in _OOT_MODELS_LAZY:
return _OOT_MODELS_LAZY[model_arch]
raise KeyError(model_arch)
@staticmethod
@lru_cache(maxsize=128)
def _try_get_model_stateful(model_arch: str) -> Optional[Type[nn.Module]]:
if model_arch not in _MODELS:
try:
mod_name, cls_name = ModelRegistry._get_module_cls_name(model_arch)
except KeyError:
return None
module_name, cls_name = ModelRegistry._get_module_cls_name(model_arch)
module = importlib.import_module(module_name)
module = importlib.import_module(mod_name)
return getattr(module, cls_name, None)
@staticmethod
@@ -219,14 +227,35 @@
return list(_MODELS.keys()) + list(_OOT_MODELS.keys())
@staticmethod
def register_model(model_arch: str, model_cls: Type[nn.Module]):
def register_model(model_arch: str, model_cls: Union[Type[nn.Module],
str]):
"""
Register an external model to be used in vLLM.
:code:`model_cls` can be either:
- A :class:`torch.nn.Module` class directly referencing the model.
- A string in the format :code:`<module>:<class>` which can be used to
lazily import the model. This is useful to avoid initializing CUDA
when importing the model and thus the related error
:code:`RuntimeError: Cannot re-initialize CUDA in forked subprocess`.
"""
if model_arch in _MODELS:
logger.warning(
"Model architecture %s is already registered, and will be "
"overwritten by the new model class %s.", model_arch,
model_cls.__name__)
model_cls)
_OOT_MODELS[model_arch] = model_cls
if isinstance(model_cls, str):
split_str = model_cls.split(":")
if len(split_str) != 2:
msg = "Expected a string in the format `<module>:<class>`"
raise ValueError(msg)
module_name, cls_name = split_str
_OOT_MODELS_LAZY[model_arch] = module_name, cls_name
else:
_OOT_MODELS[model_arch] = model_cls
@staticmethod
@lru_cache(maxsize=128)
@@ -248,13 +277,16 @@
if model is not None:
return func(model)
if model_arch not in _MODELS and default is not None:
return default
try:
mod_name, cls_name = ModelRegistry._get_module_cls_name(model_arch)
except KeyError:
if default is not None:
return default
module_name, cls_name = ModelRegistry._get_module_cls_name(model_arch)
raise
valid_name_characters = string.ascii_letters + string.digits + "._"
if any(s not in valid_name_characters for s in module_name):
if any(s not in valid_name_characters for s in mod_name):
raise ValueError(f"Unsafe module name detected for {model_arch}")
if any(s not in valid_name_characters for s in cls_name):
raise ValueError(f"Unsafe class name detected for {model_arch}")
@@ -266,7 +298,7 @@
err_id = uuid.uuid4()
stmts = ";".join([
f"from {module_name} import {cls_name}",
f"from {mod_name} import {cls_name}",
f"from {func.__module__} import {func.__name__}",
f"assert {func.__name__}({cls_name}), '{err_id}'",
])