From 2a543d6efecc4e0fe391cbccb68d99ab42e37c33 Mon Sep 17 00:00:00 2001 From: Terry <149540247+tterrysun@users.noreply.github.com> Date: Tue, 13 Feb 2024 15:55:45 -0800 Subject: [PATCH] Add LoRA support for Mixtral (#2831) * add mixtral lora support * formatting * fix incorrectly ported logic * polish tests * minor fixes and refactoring * minor fixes * formatting * rename and remove redundant logic * refactoring * refactoring * minor fix * minor refactoring * fix code smell --- tests/lora/conftest.py | 5 ++ tests/lora/test_lora_manager.py | 82 +++++++++++++---------- tests/lora/test_mixtral.py | 53 +++++++++++++++ vllm/lora/models.py | 96 +++++++++------------------ vllm/lora/worker_manager.py | 21 +++--- vllm/model_executor/model_loader.py | 2 +- vllm/model_executor/models/llama.py | 35 ++++++++-- vllm/model_executor/models/mistral.py | 27 +++++++- vllm/model_executor/models/mixtral.py | 40 ++++++++++- vllm/worker/model_runner.py | 11 ++- 10 files changed, 251 insertions(+), 121 deletions(-) create mode 100644 tests/lora/test_mixtral.py diff --git a/tests/lora/conftest.py b/tests/lora/conftest.py index 163c3c70..0ca07153 100644 --- a/tests/lora/conftest.py +++ b/tests/lora/conftest.py @@ -121,6 +121,11 @@ def sql_lora_files(): return snapshot_download(repo_id="yard1/llama-2-7b-sql-lora-test") +@pytest.fixture(scope="session") +def mixtral_lora_files(): + return snapshot_download(repo_id="terrysun/mixtral-lora-adapter") + + @pytest.fixture def llama_2_7b_engine_extra_embeddings() -> nn.Module: cleanup() diff --git a/tests/lora/test_lora_manager.py b/tests/lora/test_lora_manager.py index 78a4a5bc..2d4fc085 100644 --- a/tests/lora/test_lora_manager.py +++ b/tests/lora/test_lora_manager.py @@ -11,25 +11,35 @@ from vllm.lora.layers import (ColumnParallelLinearWithLoRA, RowParallelLinearWithLoRA, MergedColumnParallelLinearWithLoRA) from vllm.lora.lora import LoRALayerWeights, PackedLoRALayerWeights -from vllm.lora.models import (EMBEDDING_MODULES, LoRAModel, LoRAModelManager, +from vllm.lora.models import (LoRAModel, LoRAModelManager, LRUCacheLoRAModelManager, LoRAMapping) from vllm.lora.request import LoRARequest from vllm.lora.worker_manager import (LRUCacheWorkerLoRAManager, WorkerLoRAManager) from vllm.model_executor.layers.linear import RowParallelLinear +EMBEDDING_MODULES = { + "embed_tokens": "input_embeddings", + "lm_head": "output_embeddings", +} + +EMBEDDING_PADDING_MODULES = ["lm_head"] + def test_from_lora_tensors(sql_lora_files): tensors = load_file( os.path.join(sql_lora_files, "adapter_model.safetensors")) new_embeddings = load_file( os.path.join(sql_lora_files, "new_embeddings.safetensors")) - lora_model = LoRAModel.from_lora_tensors(1, - 8, - 16, - tensors, - "cuda", - embeddings=new_embeddings) + lora_model = LoRAModel.from_lora_tensors( + 1, + 8, + 16, + tensors, + "cuda", + embeddings=new_embeddings, + embedding_modules=EMBEDDING_MODULES, + embedding_padding_modules=EMBEDDING_PADDING_MODULES) for module_name, lora in lora_model.loras.items(): assert lora.module_name == module_name assert lora.rank == 8 @@ -90,14 +100,11 @@ def create_packed_lora( def test_replace_submodules(dist_init, dummy_model): model = dummy_model - manager = LoRAModelManager(model, - 1, - 1, - 1, - LoRAConfig(max_lora_rank=8, - max_cpu_loras=8, - max_loras=8), - lora_target_modules=["dense1", "layer1.dense2"]) + model.supported_lora_modules = ["dense1", "layer1.dense2"] + model.packed_modules_mapping = {} + manager = LoRAModelManager( + model, 1, 1, 1, + LoRAConfig(max_lora_rank=8, max_cpu_loras=8, max_loras=8)) model = manager.model assert isinstance(model.get_submodule("dense1"), @@ -111,16 +118,14 @@ def test_replace_submodules(dist_init, dummy_model): def test_lora_model_manager(dist_init, dummy_model): model = dummy_model + model.supported_lora_modules = ["dense1", "dense2", "lm_head"] + model.packed_modules_mapping = {} model_lora1 = create_lora(1, model, ["layer1.dense1", "dense2", "lm_head"]) model_lora2 = create_lora(2, model, ["dense1", "dense2", "lm_head"]) model_lora3 = create_lora(3, model, ["dense1", "dense2", "lm_head"]) manager = LoRAModelManager( - model, - 2, - 2, - 2, - LoRAConfig(max_lora_rank=8, max_cpu_loras=3, max_loras=2), - lora_target_modules=["dense1", "dense2", "lm_head"]) + model, 2, 2, 2, + LoRAConfig(max_lora_rank=8, max_cpu_loras=3, max_loras=2)) assert all(x is None for x in manager.lora_index_to_id) assert manager.add_lora(model_lora1) assert manager.activate_lora(1) @@ -159,16 +164,14 @@ def test_lora_model_manager(dist_init, dummy_model): def test_lora_lru_cache_model_manager(dist_init, dummy_model): model = dummy_model + model.supported_lora_modules = ["dense1", "dense2", "lm_head"] + model.packed_modules_mapping = {} model_lora1 = create_lora(1, model, ["layer1.dense1", "dense2", "lm_head"]) model_lora2 = create_lora(2, model, ["dense1", "dense2", "lm_head"]) model_lora3 = create_lora(3, model, ["dense1", "dense2", "lm_head"]) manager = LRUCacheLoRAModelManager( - model, - 2, - 2, - 2, - LoRAConfig(max_lora_rank=8, max_cpu_loras=3, max_loras=2), - lora_target_modules=["dense1", "dense2", "lm_head"]) + model, 2, 2, 2, + LoRAConfig(max_lora_rank=8, max_cpu_loras=3, max_loras=2)) assert all(x is None for x in manager.lora_index_to_id) assert manager.add_lora(model_lora1) assert manager.activate_lora(1) @@ -212,14 +215,15 @@ def test_lru_lora_model_manager(dist_init, dummy_model): # This tests just the LRU cache functionality, everything else is # tested in test_lora_model_manager model = dummy_model + model.supported_lora_modules = ["dense1", "dense2", "lm_head"] + model.packed_modules_mapping = {} model_lora1 = create_lora(1, model, ["layer1.dense1", "dense2", "lm_head"]) model_lora2 = create_lora(2, model, ["dense1", "dense2", "lm_head"]) model_lora3 = create_lora(3, model, ["dense1", "dense2", "lm_head"]) model_lora4 = create_lora(4, model, ["dense1", "dense2", "lm_head"]) manager = LRUCacheLoRAModelManager( model, 2, 2, 2, - LoRAConfig(max_lora_rank=8, max_cpu_loras=2, max_loras=2), - ["dense1", "dense2", "lm_head"]) + LoRAConfig(max_lora_rank=8, max_cpu_loras=2, max_loras=2)) assert all(x is None for x in manager.lora_index_to_id) @@ -289,8 +293,9 @@ def test_lru_cache_worker_lora_manager(llama_2_7b_model_extra_embeddings, sql_lora_files): lora_config = LoRAConfig(max_lora_rank=8, max_cpu_loras=4, max_loras=4) worker_lora_manager = LRUCacheWorkerLoRAManager( - 4, 2, llama_2_7b_model_extra_embeddings.config.vocab_size, lora_config, - torch.device("cuda")) + 4, 2, llama_2_7b_model_extra_embeddings.unpadded_vocab_size - + lora_config.lora_extra_vocab_size, lora_config, torch.device("cuda"), + EMBEDDING_MODULES, EMBEDDING_PADDING_MODULES) worker_lora_manager.create_lora_manager(llama_2_7b_model_extra_embeddings) mapping = LoRAMapping([], []) @@ -362,8 +367,9 @@ def test_worker_lora_manager(llama_2_7b_model_extra_embeddings, # Should remove every LoRA not specified in the request. lora_config = LoRAConfig(max_lora_rank=8, max_cpu_loras=4, max_loras=4) worker_lora_manager = WorkerLoRAManager( - 4, 2, llama_2_7b_model_extra_embeddings.config.vocab_size, lora_config, - torch.device("cuda")) + 4, 2, llama_2_7b_model_extra_embeddings.unpadded_vocab_size - + lora_config.lora_extra_vocab_size, lora_config, torch.device("cuda"), + EMBEDDING_MODULES, EMBEDDING_PADDING_MODULES) worker_lora_manager.create_lora_manager(llama_2_7b_model_extra_embeddings) mapping = LoRAMapping([], []) @@ -428,6 +434,13 @@ def test_worker_lora_manager(llama_2_7b_model_extra_embeddings, def test_packed_loras(dist_init, dummy_model_gate_up): model = dummy_model_gate_up + model.supported_lora_modules = ["gate_up_proj"] + model.packed_modules_mapping = { + "gate_up_proj": [ + "gate_proj", + "up_proj", + ], + } model_lora = create_packed_lora( 1, model, @@ -443,8 +456,7 @@ def test_packed_loras(dist_init, dummy_model_gate_up): manager = LoRAModelManager( model, 2, 2, 2, - LoRAConfig(max_lora_rank=8, max_cpu_loras=2, max_loras=2), - ["gate_up_proj"]) + LoRAConfig(max_lora_rank=8, max_cpu_loras=2, max_loras=2)) model = manager.model assert isinstance(model.get_submodule("gate_up_proj"), diff --git a/tests/lora/test_mixtral.py b/tests/lora/test_mixtral.py new file mode 100644 index 00000000..e45fb92a --- /dev/null +++ b/tests/lora/test_mixtral.py @@ -0,0 +1,53 @@ +import pytest +import torch + +import vllm +from vllm.lora.request import LoRARequest + +MODEL_PATH = "mistralai/Mixtral-8x7B-Instruct-v0.1" + + +def do_sample(llm, lora_path: str, lora_id: int): + prompts = [ + "[system] Given a target sentence construct the underlying meaning representation\nof the input sentence as a single function with attributes and attribute\nvalues. This function should describe the target string accurately and the\nfunction must be one of the following ['inform', 'request', 'give_opinion',\n'confirm', 'verify_attribute', 'suggest', 'request_explanation',\n'recommend', 'request_attribute'].\n\nThe attributes must be one of the following:\n['name', 'exp_release_date', 'release_year', 'developer', 'esrb', 'rating',\n'genres', 'player_perspective', 'has_multiplayer', 'platforms',\n'available_on_steam', 'has_linux_release', 'has_mac_release', 'specifier'] [/system] [user] Here is the target sentence:\nSpellForce 3 is a pretty bad game. The developer Grimlore Games is clearly a bunch of no-talent hacks, and 2017 was a terrible year for games anyway. [/user] [assistant]", + "[system] Given a target sentence construct the underlying meaning representation\nof the input sentence as a single function with attributes and attribute\nvalues. This function should describe the target string accurately and the\nfunction must be one of the following ['inform', 'request', 'give_opinion',\n'confirm', 'verify_attribute', 'suggest', 'request_explanation',\n'recommend', 'request_attribute'].\n\nThe attributes must be one of the following:\n['name', 'exp_release_date', 'release_year', 'developer', 'esrb', 'rating',\n'genres', 'player_perspective', 'has_multiplayer', 'platforms',\n'available_on_steam', 'has_linux_release', 'has_mac_release', 'specifier'] [/system] [user] Here is the target sentence:\nI wanted to like Grimlore Games' 2017 entry, but in SpellForce 3 they just didn't get anything right. [/user] [assistant]", + "[system] Given a target sentence construct the underlying meaning representation\nof the input sentence as a single function with attributes and attribute\nvalues. This function should describe the target string accurately and the\nfunction must be one of the following ['inform', 'request', 'give_opinion',\n'confirm', 'verify_attribute', 'suggest', 'request_explanation',\n'recommend', 'request_attribute'].\n\nThe attributes must be one of the following:\n['name', 'exp_release_date', 'release_year', 'developer', 'esrb', 'rating',\n'genres', 'player_perspective', 'has_multiplayer', 'platforms',\n'available_on_steam', 'has_linux_release', 'has_mac_release', 'specifier'] [/system] [user] Here is the target sentence:\nBioShock is a good role-playing, action-adventure, shooter that released for PlayStation, Xbox, and PC in 2007. It is available on Steam, and it has a Mac release but not a Linux release. [/user] [assistant]", + ] + sampling_params = vllm.SamplingParams(temperature=0, max_tokens=256) + outputs = llm.generate( + prompts, + sampling_params, + lora_request=LoRARequest(str(lora_id), lora_id, lora_path) + if lora_id else None) + # Print the outputs. + generated_texts = [] + for output in outputs: + prompt = output.prompt + generated_text = output.outputs[0].text.strip() + generated_texts.append(generated_text) + print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") + return generated_texts + + +@pytest.mark.parametrize("tp_size", [4]) +def test_mixtral_lora(mixtral_lora_files, tp_size): + if torch.cuda.device_count() < tp_size: + pytest.skip(f"Not enough GPUs for tensor parallelism {tp_size}") + + llm = vllm.LLM(MODEL_PATH, + enable_lora=True, + max_num_seqs=16, + max_loras=4, + tensor_parallel_size=tp_size, + worker_use_ray=True) + + expected_lora_output = [ + "give_opinion(name[SpellForce 3], release_year[2017], developer[Grimlore Games], rating[poor])", + "give_opinion(name[SpellForce 3], release_year[2017], developer[Grimlore Games], rating[poor])", + "inform(name[BioShock], release_year[2007], rating[good], genres[action-adventure, role-playing, shooter], platforms[PlayStation, Xbox, PC], available_on_steam[yes], has_linux_release[no], has_mac_release[yes])", + ] + + assert do_sample(llm, mixtral_lora_files, + lora_id=1) == expected_lora_output + assert do_sample(llm, mixtral_lora_files, + lora_id=2) == expected_lora_output diff --git a/vllm/lora/models.py b/vllm/lora/models.py index 6c78c4a2..7386d21c 100644 --- a/vllm/lora/models.py +++ b/vllm/lora/models.py @@ -4,8 +4,7 @@ import logging import math import os import re -from typing import (Any, Callable, Dict, Hashable, List, Optional, Tuple, Type, - Union) +from typing import (Any, Callable, Dict, Hashable, List, Optional, Tuple, Type) import safetensors.torch import torch @@ -20,36 +19,6 @@ from vllm.lora.utils import parse_fine_tuned_lora_name, replace_submodule logger = logging.getLogger(__name__) -# TODO: The mappings below should be moved to individual model classes. - -PACKED_MODULES_CFG = { - "qkv_proj": [ - "q_proj", - "k_proj", - "v_proj", - ], - "gate_up_proj": [ - "gate_proj", - "up_proj", - ], -} - -TARGET_MODULES_QKV = [ - "qkv_proj", - "o_proj", - "gate_up_proj", - "down_proj", - "embed_tokens", - "lm_head", -] - -EMBEDDING_MODULES = { - "embed_tokens": "input_embeddings", - "lm_head": "output_embeddings", -} - -EMBEDDING_PADDING_MODULES = ["lm_head"] - _GLOBAL_LORA_ID = 0 @@ -169,6 +138,8 @@ class LoRAModel: dtype: Optional[torch.dtype] = None, embeddings: Optional[Dict[str, torch.Tensor]] = None, target_embedding_padding: Optional[int] = None, + embedding_modules: Optional[Dict[str, str]] = None, + embedding_padding_modules: Optional[List[str]] = None, ) -> "LoRAModel": """Create a LoRAModel from a dictionary of tensors.""" pin_memory = str(device) == "cpu" and not in_wsl() @@ -179,11 +150,11 @@ class LoRAModel: lora_embeddings_tensor = None if embeddings: embeddings_module = next( - (k for k in EMBEDDING_MODULES if k in module_name), + (k for k in embedding_modules if k in module_name), None) if embeddings_module: lora_embeddings_tensor = embeddings[ - EMBEDDING_MODULES[embeddings_module]].to( + embedding_modules[embeddings_module]].to( device=device, dtype=dtype) if pin_memory: lora_embeddings_tensor = ( @@ -201,7 +172,7 @@ class LoRAModel: loras[module_name].lora_b = tensor.to(device=device, dtype=dtype).t() if any(name in module_name - for name in EMBEDDING_PADDING_MODULES + for name in embedding_padding_modules ) and target_embedding_padding is not None: lora_b = loras[module_name].lora_b assert target_embedding_padding >= lora_b.shape[1] @@ -218,12 +189,15 @@ class LoRAModel: @classmethod def from_local_checkpoint( - cls, - lora_dir: str, - lora_model_id: Optional[int] = None, - device: str = "cuda", - dtype: Optional[torch.dtype] = None, - target_embedding_padding: Optional[int] = None) -> "LoRAModel": + cls, + lora_dir: str, + lora_model_id: Optional[int] = None, + device: str = "cuda", + dtype: Optional[torch.dtype] = None, + target_embedding_padding: Optional[int] = None, + embedding_modules: Optional[Dict[str, str]] = None, + embedding_padding_modules: Optional[List[str]] = None, + ) -> "LoRAModel": """Create a LoRAModel from a local checkpoint.""" lora_config_path = os.path.join(lora_dir, "adapter_config.json") lora_tensor_path = os.path.join(lora_dir, "adapter_model.safetensors") @@ -260,6 +234,8 @@ class LoRAModel: dtype=dtype, embeddings=embeddings, target_embedding_padding=target_embedding_padding, + embedding_modules=embedding_modules, + embedding_padding_modules=embedding_padding_modules, ) @@ -273,8 +249,6 @@ class LoRAModelManager: max_num_batched_tokens: int, vocab_size: int, lora_config: LoRAConfig, - lora_target_modules: Union[str, List[str]] = TARGET_MODULES_QKV, - packed_modules_mapping: Dict[str, List[str]] = PACKED_MODULES_CFG, ): """Create a LoRAModelManager and adapter for a given model. @@ -286,13 +260,6 @@ class LoRAModelManager: in a single batch. vocab_size: the vocab size of the model. lora_config: the LoRA configuration. - lora_target_modules: the target modules patterns to be adapted. - Support both single module name and a list of module names. - packed_modules_mapping: the mapping for packed modules. vLLM - packs some modules into one module, e.g., qkv_proj - is packed of q_proj, k_proj, and v_proj. These modules - have a single layer in the original model, but they are split - into multiple layers in the adapted model. """ self.lora_config = lora_config self.max_num_seqs = max_num_seqs @@ -320,11 +287,11 @@ class LoRAModelManager: self.indices_len = [None] * 4 self.model: nn.Module = model - self.lora_target_modules: List[str] = ([ - lora_target_modules - ] if isinstance(lora_target_modules, str) else lora_target_modules) - self.lora_target_modules = copy.deepcopy(lora_target_modules) - self.packed_modules_mapping = copy.deepcopy(packed_modules_mapping) + if hasattr(self.model, "supported_lora_modules"): + self.supported_lora_modules = copy.deepcopy( + self.model.supported_lora_modules) + self.packed_modules_mapping = copy.deepcopy( + self.model.packed_modules_mapping) self.packed_modules: Dict[str, List[str]] = {} self.modules: Dict[str, "BaseLayerWithLoRA"] = {} self._registered_loras: Dict[int, LoRAModel] = {} @@ -468,7 +435,11 @@ class LoRAModelManager: assert isinstance(module, BaseLayerWithLoRA) self.modules[module_name] = module - def create_dummy_lora(self, lora_id: int, rank: int) -> LoRAModel: + def create_dummy_lora( + self, + lora_id: int, + rank: int, + embedding_modules: Optional[Dict[str, str]] = None) -> LoRAModel: """Create zero-initialized LoRAModel for warmup.""" model = LoRAModel(lora_id, rank, {}) for module_name, module in self.model.named_modules(): @@ -477,7 +448,7 @@ class LoRAModelManager: continue parts = module_name.split(".") if module_name not in self.packed_modules: - if parts[-1] in EMBEDDING_MODULES: + if parts[-1] in embedding_modules: input_dim = (module.base_layer.org_vocab_size + self.lora_config.lora_extra_vocab_size if hasattr(module.base_layer, "org_vocab_size") @@ -531,7 +502,7 @@ class LoRAModelManager: re.match( r".*\.{target_module}$".format(target_module=target_module), module_name) or target_module == module_name - for target_module in self.lora_target_modules) + for target_module in self.supported_lora_modules) def _register_packed_modules(self, module_full_name: str) -> None: parts = module_full_name.split(".") @@ -586,12 +557,9 @@ class LRUCacheLoRAModelManager(LoRAModelManager): max_num_batched_tokens: int, vocab_size: int, lora_config: LoRAConfig, - lora_target_modules: Union[str, List[str]] = TARGET_MODULES_QKV, - packed_modules_mapping: Dict[str, List[str]] = PACKED_MODULES_CFG, ): super().__init__(model, max_num_seqs, max_num_batched_tokens, - vocab_size, lora_config, lora_target_modules, - packed_modules_mapping) + vocab_size, lora_config) self._registered_loras: LoRALRUCache = LoRALRUCache( self.capacity, self.deactivate_lora) self._active_loras: LoRALRUCache = LoRALRUCache( @@ -637,11 +605,10 @@ def create_lora_manager( max_num_batched_tokens: int, vocab_size: int, lora_config: LoRAConfig, - target_modules: Union[str, List[str]] = TARGET_MODULES_QKV, lora_manager_cls: Type[LoRAModelManager] = LoRAModelManager, **kwargs) -> LoRAModelManager: """Create a LoRA adapter for a given model.""" - if not getattr(model, "supports_lora", False): + if not hasattr(model, "supported_lora_modules"): raise ValueError(f"Model {type(model)} is not supported for LoRA.") lora_manager = lora_manager_cls( model=model, @@ -649,6 +616,5 @@ def create_lora_manager( max_num_batched_tokens=max_num_batched_tokens, vocab_size=vocab_size, lora_config=lora_config, - lora_target_modules=target_modules, **kwargs) return lora_manager diff --git a/vllm/lora/worker_manager.py b/vllm/lora/worker_manager.py index a507c085..7e92bc93 100644 --- a/vllm/lora/worker_manager.py +++ b/vllm/lora/worker_manager.py @@ -1,10 +1,10 @@ import logging from abc import ABC, abstractmethod, abstractproperty -from typing import Any, List, Optional, Set, Type, Union +from typing import Any, Dict, List, Optional, Set, Type import torch -from vllm.lora.models import (TARGET_MODULES_QKV, LoRAModel, LoRAModelManager, +from vllm.lora.models import (LoRAModel, LoRAModelManager, LRUCacheLoRAModelManager, create_lora_manager) from vllm.lora.request import LoRARequest from vllm.lora.layers import LoRAMapping @@ -13,7 +13,7 @@ from vllm.config import LoRAConfig logger = logging.getLogger(__name__) -class WorkerLoRAManager(ABC): +class AbstractWorkerLoRAManager(ABC): """Abstract class for managing LoRA models on the worker side.""" def __init__(self, max_num_seqs: int, max_num_batched_tokens: int, @@ -33,7 +33,6 @@ class WorkerLoRAManager(ABC): def create_lora_manager( self, model: torch.nn.Module, - target_modules: Union[str, List[str]] = TARGET_MODULES_QKV, ) -> Any: ... @@ -63,7 +62,7 @@ class WorkerLoRAManager(ABC): ... -class WorkerLoRAManager(WorkerLoRAManager): +class WorkerLoRAManager(AbstractWorkerLoRAManager): """WorkerLoRAManager that manages LoRA models on the worker side. Every request, the requested LoRAs will be loaded (unless they are already @@ -78,10 +77,14 @@ class WorkerLoRAManager(WorkerLoRAManager): vocab_size: int, lora_config: LoRAConfig, device: torch.device, + embedding_modules: Dict[str, str], + embedding_padding_modules: List[str], lora_model_cls: Type[LoRAModel] = LoRAModel, ): self._lora_manager: Optional[LoRAModelManager] = None self._lora_model_cls = lora_model_cls + self.embedding_modules = embedding_modules + self.embedding_padding_modules = embedding_padding_modules super().__init__(max_num_seqs, max_num_batched_tokens, vocab_size, lora_config, device) @@ -92,13 +95,11 @@ class WorkerLoRAManager(WorkerLoRAManager): def create_lora_manager( self, model: torch.nn.Module, - target_modules: Union[str, List[str]] = TARGET_MODULES_QKV, ) -> Any: lora_manager = create_lora_manager( model, max_num_seqs=self.max_num_seqs, max_num_batched_tokens=self.max_num_batched_tokens, - target_modules=target_modules, vocab_size=self.vocab_size, lora_config=self.lora_config, lora_manager_cls=self._lora_manager_cls, @@ -142,6 +143,8 @@ class WorkerLoRAManager(WorkerLoRAManager): dtype=self.lora_config.lora_dtype, target_embedding_padding=self.vocab_size + self.lora_config.lora_extra_vocab_size, + embedding_modules=self.embedding_modules, + embedding_padding_modules=self.embedding_padding_modules, ) except Exception as e: raise RuntimeError( @@ -162,7 +165,7 @@ class WorkerLoRAManager(WorkerLoRAManager): return False return self._lora_manager.add_lora( self._lora_manager.create_dummy_lora(lora_request.lora_int_id, - rank)) + rank, self.embedding_modules)) def add_lora(self, lora_request: LoRARequest) -> bool: if lora_request.lora_int_id in self.list_loras(): @@ -195,11 +198,9 @@ class LRUCacheWorkerLoRAManager(WorkerLoRAManager): def create_lora_manager( self, model: torch.nn.Module, - target_modules: Union[str, List[str]] = TARGET_MODULES_QKV, ) -> Any: lora_manager = create_lora_manager( model, - target_modules=target_modules, lora_manager_cls=self._lora_manager_cls, max_num_seqs=self.max_num_seqs, vocab_size=self.vocab_size, diff --git a/vllm/model_executor/model_loader.py b/vllm/model_executor/model_loader.py index 4b1e13d9..ebe092b5 100644 --- a/vllm/model_executor/model_loader.py +++ b/vllm/model_executor/model_loader.py @@ -66,7 +66,7 @@ def get_model(model_config: ModelConfig, # Create a model instance. # The weights will be initialized as empty tensors. with torch.device(device_config.device): - if getattr(model_class, "supports_lora", False): + if hasattr(model_class, "supported_lora_modules"): model = model_class(model_config.hf_config, linear_method, lora_config) elif lora_config: diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index e5a1abeb..860a8f26 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -269,7 +269,32 @@ class LlamaModel(nn.Module): class LlamaForCausalLM(nn.Module): - supports_lora = True + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + "gate_up_proj": [ + "gate_proj", + "up_proj", + ], + } + + # LoRA specific attributes + supported_lora_modules = [ + "qkv_proj", + "o_proj", + "gate_up_proj", + "down_proj", + "embed_tokens", + "lm_head", + ] + embedding_modules = { + "embed_tokens": "input_embeddings", + "lm_head": "output_embeddings", + } + embedding_padding_modules = ["lm_head"] def __init__( self, @@ -281,11 +306,11 @@ class LlamaForCausalLM(nn.Module): self.config = config self.linear_method = linear_method self.model = LlamaModel(config, linear_method, lora_config=lora_config) - unpadded_vocab_size = config.vocab_size + self.unpadded_vocab_size = config.vocab_size if lora_config: - unpadded_vocab_size += lora_config.lora_extra_vocab_size + self.unpadded_vocab_size += lora_config.lora_extra_vocab_size self.lm_head = ParallelLMHead( - unpadded_vocab_size, + self.unpadded_vocab_size, config.hidden_size, org_num_embeddings=config.vocab_size, padding_size=DEFAULT_VOCAB_PADDING_SIZE @@ -293,7 +318,7 @@ class LlamaForCausalLM(nn.Module): # compatibility if not lora_config else lora_config.lora_vocab_padding_size, ) - self.sampler = Sampler(unpadded_vocab_size, config.vocab_size) + self.sampler = Sampler(self.unpadded_vocab_size, config.vocab_size) def forward( self, diff --git a/vllm/model_executor/models/mistral.py b/vllm/model_executor/models/mistral.py index 01cde678..2347ed75 100644 --- a/vllm/model_executor/models/mistral.py +++ b/vllm/model_executor/models/mistral.py @@ -265,7 +265,32 @@ class MistralModel(nn.Module): class MistralForCausalLM(nn.Module): - supports_lora = True + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + "gate_up_proj": [ + "gate_proj", + "up_proj", + ], + } + + # LoRA specific attributes + supported_lora_modules = [ + "qkv_proj", + "o_proj", + "gate_up_proj", + "down_proj", + "embed_tokens", + "lm_head", + ] + embedding_modules = { + "embed_tokens": "input_embeddings", + "lm_head": "output_embeddings", + } + embedding_padding_modules = ["lm_head"] def __init__( self, diff --git a/vllm/model_executor/models/mixtral.py b/vllm/model_executor/models/mixtral.py index aeb9d087..6cb1d849 100644 --- a/vllm/model_executor/models/mixtral.py +++ b/vllm/model_executor/models/mixtral.py @@ -27,6 +27,7 @@ import torch from torch import nn from transformers import MixtralConfig +from vllm.config import LoRAConfig from vllm.model_executor.input_metadata import InputMetadata from vllm.model_executor.layers.attention import PagedAttention from vllm.model_executor.layers.fused_moe import fused_moe @@ -38,7 +39,7 @@ from vllm.model_executor.layers.linear import (LinearMethodBase, from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( - VocabParallelEmbedding, ParallelLMHead) + VocabParallelEmbedding, ParallelLMHead, DEFAULT_VOCAB_PADDING_SIZE) from vllm.model_executor.parallel_utils.communication_op import ( tensor_model_parallel_all_reduce) from vllm.model_executor.parallel_utils.parallel_state import ( @@ -292,6 +293,7 @@ class MixtralModel(nn.Module): self.embed_tokens = VocabParallelEmbedding( config.vocab_size, config.hidden_size, + org_num_embeddings=self.org_vocab_size, ) self.layers = nn.ModuleList([ MixtralDecoderLayer(config, linear_method=linear_method) @@ -318,18 +320,50 @@ class MixtralModel(nn.Module): class MixtralForCausalLM(nn.Module): + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + } + + # LoRA specific attributes + supported_lora_modules = [ + "qkv_proj", + "o_proj", + "embed_tokens", + "lm_head", + ] + embedding_modules = { + "embed_tokens": "input_embeddings", + "lm_head": "output_embeddings", + } + embedding_padding_modules = ["lm_head"] def __init__( self, config: MixtralConfig, linear_method: Optional[LinearMethodBase] = None, + lora_config: Optional[LoRAConfig] = None, ) -> None: super().__init__() self.config = config self.linear_method = linear_method self.model = MixtralModel(config, linear_method) - self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) - self.sampler = Sampler(config.vocab_size) + self.unpadded_vocab_size = config.vocab_size + if lora_config: + self.unpadded_vocab_size += lora_config.lora_extra_vocab_size + self.lm_head = ParallelLMHead( + self.unpadded_vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + padding_size=DEFAULT_VOCAB_PADDING_SIZE + # We need bigger padding if using lora for kernel + # compatibility + if not lora_config else lora_config.lora_vocab_padding_size, + ) + self.sampler = Sampler(self.unpadded_vocab_size, config.vocab_size) def forward( self, diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 62f75308..065d5899 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -86,11 +86,20 @@ class ModelRunner: vocab_size = self.model.config.vocab_size if self.lora_config: + assert hasattr( + self.model, "supported_lora_modules" + ) and self.model.supported_lora_modules, "Model does not support LoRA" + assert hasattr( + self.model, + "embedding_modules"), "Model does not have embedding_modules" + assert hasattr(self.model, "embedding_padding_modules" + ), "Model does not have embedding_padding_modules" self.lora_manager = LRUCacheWorkerLoRAManager( self.scheduler_config.max_num_seqs, self.scheduler_config.max_num_batched_tokens + self.scheduler_config.max_paddings, vocab_size, - self.lora_config, self.device) + self.lora_config, self.device, self.model.embedding_modules, + self.model.embedding_padding_modules) self.model = self.lora_manager.create_lora_manager(self.model) def set_block_size(self, block_size: int) -> None: