Add LoRA support for Mixtral (#2831)
* add mixtral lora support * formatting * fix incorrectly ported logic * polish tests * minor fixes and refactoring * minor fixes * formatting * rename and remove redundant logic * refactoring * refactoring * minor fix * minor refactoring * fix code smell
This commit is contained in:
parent
317b29de0f
commit
2a543d6efe
@ -121,6 +121,11 @@ def sql_lora_files():
|
|||||||
return snapshot_download(repo_id="yard1/llama-2-7b-sql-lora-test")
|
return snapshot_download(repo_id="yard1/llama-2-7b-sql-lora-test")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="session")
|
||||||
|
def mixtral_lora_files():
|
||||||
|
return snapshot_download(repo_id="terrysun/mixtral-lora-adapter")
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def llama_2_7b_engine_extra_embeddings() -> nn.Module:
|
def llama_2_7b_engine_extra_embeddings() -> nn.Module:
|
||||||
cleanup()
|
cleanup()
|
||||||
|
@ -11,25 +11,35 @@ from vllm.lora.layers import (ColumnParallelLinearWithLoRA,
|
|||||||
RowParallelLinearWithLoRA,
|
RowParallelLinearWithLoRA,
|
||||||
MergedColumnParallelLinearWithLoRA)
|
MergedColumnParallelLinearWithLoRA)
|
||||||
from vllm.lora.lora import LoRALayerWeights, PackedLoRALayerWeights
|
from vllm.lora.lora import LoRALayerWeights, PackedLoRALayerWeights
|
||||||
from vllm.lora.models import (EMBEDDING_MODULES, LoRAModel, LoRAModelManager,
|
from vllm.lora.models import (LoRAModel, LoRAModelManager,
|
||||||
LRUCacheLoRAModelManager, LoRAMapping)
|
LRUCacheLoRAModelManager, LoRAMapping)
|
||||||
from vllm.lora.request import LoRARequest
|
from vllm.lora.request import LoRARequest
|
||||||
from vllm.lora.worker_manager import (LRUCacheWorkerLoRAManager,
|
from vllm.lora.worker_manager import (LRUCacheWorkerLoRAManager,
|
||||||
WorkerLoRAManager)
|
WorkerLoRAManager)
|
||||||
from vllm.model_executor.layers.linear import RowParallelLinear
|
from vllm.model_executor.layers.linear import RowParallelLinear
|
||||||
|
|
||||||
|
EMBEDDING_MODULES = {
|
||||||
|
"embed_tokens": "input_embeddings",
|
||||||
|
"lm_head": "output_embeddings",
|
||||||
|
}
|
||||||
|
|
||||||
|
EMBEDDING_PADDING_MODULES = ["lm_head"]
|
||||||
|
|
||||||
|
|
||||||
def test_from_lora_tensors(sql_lora_files):
|
def test_from_lora_tensors(sql_lora_files):
|
||||||
tensors = load_file(
|
tensors = load_file(
|
||||||
os.path.join(sql_lora_files, "adapter_model.safetensors"))
|
os.path.join(sql_lora_files, "adapter_model.safetensors"))
|
||||||
new_embeddings = load_file(
|
new_embeddings = load_file(
|
||||||
os.path.join(sql_lora_files, "new_embeddings.safetensors"))
|
os.path.join(sql_lora_files, "new_embeddings.safetensors"))
|
||||||
lora_model = LoRAModel.from_lora_tensors(1,
|
lora_model = LoRAModel.from_lora_tensors(
|
||||||
8,
|
1,
|
||||||
16,
|
8,
|
||||||
tensors,
|
16,
|
||||||
"cuda",
|
tensors,
|
||||||
embeddings=new_embeddings)
|
"cuda",
|
||||||
|
embeddings=new_embeddings,
|
||||||
|
embedding_modules=EMBEDDING_MODULES,
|
||||||
|
embedding_padding_modules=EMBEDDING_PADDING_MODULES)
|
||||||
for module_name, lora in lora_model.loras.items():
|
for module_name, lora in lora_model.loras.items():
|
||||||
assert lora.module_name == module_name
|
assert lora.module_name == module_name
|
||||||
assert lora.rank == 8
|
assert lora.rank == 8
|
||||||
@ -90,14 +100,11 @@ def create_packed_lora(
|
|||||||
|
|
||||||
def test_replace_submodules(dist_init, dummy_model):
|
def test_replace_submodules(dist_init, dummy_model):
|
||||||
model = dummy_model
|
model = dummy_model
|
||||||
manager = LoRAModelManager(model,
|
model.supported_lora_modules = ["dense1", "layer1.dense2"]
|
||||||
1,
|
model.packed_modules_mapping = {}
|
||||||
1,
|
manager = LoRAModelManager(
|
||||||
1,
|
model, 1, 1, 1,
|
||||||
LoRAConfig(max_lora_rank=8,
|
LoRAConfig(max_lora_rank=8, max_cpu_loras=8, max_loras=8))
|
||||||
max_cpu_loras=8,
|
|
||||||
max_loras=8),
|
|
||||||
lora_target_modules=["dense1", "layer1.dense2"])
|
|
||||||
model = manager.model
|
model = manager.model
|
||||||
|
|
||||||
assert isinstance(model.get_submodule("dense1"),
|
assert isinstance(model.get_submodule("dense1"),
|
||||||
@ -111,16 +118,14 @@ def test_replace_submodules(dist_init, dummy_model):
|
|||||||
|
|
||||||
def test_lora_model_manager(dist_init, dummy_model):
|
def test_lora_model_manager(dist_init, dummy_model):
|
||||||
model = dummy_model
|
model = dummy_model
|
||||||
|
model.supported_lora_modules = ["dense1", "dense2", "lm_head"]
|
||||||
|
model.packed_modules_mapping = {}
|
||||||
model_lora1 = create_lora(1, model, ["layer1.dense1", "dense2", "lm_head"])
|
model_lora1 = create_lora(1, model, ["layer1.dense1", "dense2", "lm_head"])
|
||||||
model_lora2 = create_lora(2, model, ["dense1", "dense2", "lm_head"])
|
model_lora2 = create_lora(2, model, ["dense1", "dense2", "lm_head"])
|
||||||
model_lora3 = create_lora(3, model, ["dense1", "dense2", "lm_head"])
|
model_lora3 = create_lora(3, model, ["dense1", "dense2", "lm_head"])
|
||||||
manager = LoRAModelManager(
|
manager = LoRAModelManager(
|
||||||
model,
|
model, 2, 2, 2,
|
||||||
2,
|
LoRAConfig(max_lora_rank=8, max_cpu_loras=3, max_loras=2))
|
||||||
2,
|
|
||||||
2,
|
|
||||||
LoRAConfig(max_lora_rank=8, max_cpu_loras=3, max_loras=2),
|
|
||||||
lora_target_modules=["dense1", "dense2", "lm_head"])
|
|
||||||
assert all(x is None for x in manager.lora_index_to_id)
|
assert all(x is None for x in manager.lora_index_to_id)
|
||||||
assert manager.add_lora(model_lora1)
|
assert manager.add_lora(model_lora1)
|
||||||
assert manager.activate_lora(1)
|
assert manager.activate_lora(1)
|
||||||
@ -159,16 +164,14 @@ def test_lora_model_manager(dist_init, dummy_model):
|
|||||||
|
|
||||||
def test_lora_lru_cache_model_manager(dist_init, dummy_model):
|
def test_lora_lru_cache_model_manager(dist_init, dummy_model):
|
||||||
model = dummy_model
|
model = dummy_model
|
||||||
|
model.supported_lora_modules = ["dense1", "dense2", "lm_head"]
|
||||||
|
model.packed_modules_mapping = {}
|
||||||
model_lora1 = create_lora(1, model, ["layer1.dense1", "dense2", "lm_head"])
|
model_lora1 = create_lora(1, model, ["layer1.dense1", "dense2", "lm_head"])
|
||||||
model_lora2 = create_lora(2, model, ["dense1", "dense2", "lm_head"])
|
model_lora2 = create_lora(2, model, ["dense1", "dense2", "lm_head"])
|
||||||
model_lora3 = create_lora(3, model, ["dense1", "dense2", "lm_head"])
|
model_lora3 = create_lora(3, model, ["dense1", "dense2", "lm_head"])
|
||||||
manager = LRUCacheLoRAModelManager(
|
manager = LRUCacheLoRAModelManager(
|
||||||
model,
|
model, 2, 2, 2,
|
||||||
2,
|
LoRAConfig(max_lora_rank=8, max_cpu_loras=3, max_loras=2))
|
||||||
2,
|
|
||||||
2,
|
|
||||||
LoRAConfig(max_lora_rank=8, max_cpu_loras=3, max_loras=2),
|
|
||||||
lora_target_modules=["dense1", "dense2", "lm_head"])
|
|
||||||
assert all(x is None for x in manager.lora_index_to_id)
|
assert all(x is None for x in manager.lora_index_to_id)
|
||||||
assert manager.add_lora(model_lora1)
|
assert manager.add_lora(model_lora1)
|
||||||
assert manager.activate_lora(1)
|
assert manager.activate_lora(1)
|
||||||
@ -212,14 +215,15 @@ def test_lru_lora_model_manager(dist_init, dummy_model):
|
|||||||
# This tests just the LRU cache functionality, everything else is
|
# This tests just the LRU cache functionality, everything else is
|
||||||
# tested in test_lora_model_manager
|
# tested in test_lora_model_manager
|
||||||
model = dummy_model
|
model = dummy_model
|
||||||
|
model.supported_lora_modules = ["dense1", "dense2", "lm_head"]
|
||||||
|
model.packed_modules_mapping = {}
|
||||||
model_lora1 = create_lora(1, model, ["layer1.dense1", "dense2", "lm_head"])
|
model_lora1 = create_lora(1, model, ["layer1.dense1", "dense2", "lm_head"])
|
||||||
model_lora2 = create_lora(2, model, ["dense1", "dense2", "lm_head"])
|
model_lora2 = create_lora(2, model, ["dense1", "dense2", "lm_head"])
|
||||||
model_lora3 = create_lora(3, model, ["dense1", "dense2", "lm_head"])
|
model_lora3 = create_lora(3, model, ["dense1", "dense2", "lm_head"])
|
||||||
model_lora4 = create_lora(4, model, ["dense1", "dense2", "lm_head"])
|
model_lora4 = create_lora(4, model, ["dense1", "dense2", "lm_head"])
|
||||||
manager = LRUCacheLoRAModelManager(
|
manager = LRUCacheLoRAModelManager(
|
||||||
model, 2, 2, 2,
|
model, 2, 2, 2,
|
||||||
LoRAConfig(max_lora_rank=8, max_cpu_loras=2, max_loras=2),
|
LoRAConfig(max_lora_rank=8, max_cpu_loras=2, max_loras=2))
|
||||||
["dense1", "dense2", "lm_head"])
|
|
||||||
|
|
||||||
assert all(x is None for x in manager.lora_index_to_id)
|
assert all(x is None for x in manager.lora_index_to_id)
|
||||||
|
|
||||||
@ -289,8 +293,9 @@ def test_lru_cache_worker_lora_manager(llama_2_7b_model_extra_embeddings,
|
|||||||
sql_lora_files):
|
sql_lora_files):
|
||||||
lora_config = LoRAConfig(max_lora_rank=8, max_cpu_loras=4, max_loras=4)
|
lora_config = LoRAConfig(max_lora_rank=8, max_cpu_loras=4, max_loras=4)
|
||||||
worker_lora_manager = LRUCacheWorkerLoRAManager(
|
worker_lora_manager = LRUCacheWorkerLoRAManager(
|
||||||
4, 2, llama_2_7b_model_extra_embeddings.config.vocab_size, lora_config,
|
4, 2, llama_2_7b_model_extra_embeddings.unpadded_vocab_size -
|
||||||
torch.device("cuda"))
|
lora_config.lora_extra_vocab_size, lora_config, torch.device("cuda"),
|
||||||
|
EMBEDDING_MODULES, EMBEDDING_PADDING_MODULES)
|
||||||
worker_lora_manager.create_lora_manager(llama_2_7b_model_extra_embeddings)
|
worker_lora_manager.create_lora_manager(llama_2_7b_model_extra_embeddings)
|
||||||
|
|
||||||
mapping = LoRAMapping([], [])
|
mapping = LoRAMapping([], [])
|
||||||
@ -362,8 +367,9 @@ def test_worker_lora_manager(llama_2_7b_model_extra_embeddings,
|
|||||||
# Should remove every LoRA not specified in the request.
|
# Should remove every LoRA not specified in the request.
|
||||||
lora_config = LoRAConfig(max_lora_rank=8, max_cpu_loras=4, max_loras=4)
|
lora_config = LoRAConfig(max_lora_rank=8, max_cpu_loras=4, max_loras=4)
|
||||||
worker_lora_manager = WorkerLoRAManager(
|
worker_lora_manager = WorkerLoRAManager(
|
||||||
4, 2, llama_2_7b_model_extra_embeddings.config.vocab_size, lora_config,
|
4, 2, llama_2_7b_model_extra_embeddings.unpadded_vocab_size -
|
||||||
torch.device("cuda"))
|
lora_config.lora_extra_vocab_size, lora_config, torch.device("cuda"),
|
||||||
|
EMBEDDING_MODULES, EMBEDDING_PADDING_MODULES)
|
||||||
worker_lora_manager.create_lora_manager(llama_2_7b_model_extra_embeddings)
|
worker_lora_manager.create_lora_manager(llama_2_7b_model_extra_embeddings)
|
||||||
|
|
||||||
mapping = LoRAMapping([], [])
|
mapping = LoRAMapping([], [])
|
||||||
@ -428,6 +434,13 @@ def test_worker_lora_manager(llama_2_7b_model_extra_embeddings,
|
|||||||
|
|
||||||
def test_packed_loras(dist_init, dummy_model_gate_up):
|
def test_packed_loras(dist_init, dummy_model_gate_up):
|
||||||
model = dummy_model_gate_up
|
model = dummy_model_gate_up
|
||||||
|
model.supported_lora_modules = ["gate_up_proj"]
|
||||||
|
model.packed_modules_mapping = {
|
||||||
|
"gate_up_proj": [
|
||||||
|
"gate_proj",
|
||||||
|
"up_proj",
|
||||||
|
],
|
||||||
|
}
|
||||||
model_lora = create_packed_lora(
|
model_lora = create_packed_lora(
|
||||||
1,
|
1,
|
||||||
model,
|
model,
|
||||||
@ -443,8 +456,7 @@ def test_packed_loras(dist_init, dummy_model_gate_up):
|
|||||||
|
|
||||||
manager = LoRAModelManager(
|
manager = LoRAModelManager(
|
||||||
model, 2, 2, 2,
|
model, 2, 2, 2,
|
||||||
LoRAConfig(max_lora_rank=8, max_cpu_loras=2, max_loras=2),
|
LoRAConfig(max_lora_rank=8, max_cpu_loras=2, max_loras=2))
|
||||||
["gate_up_proj"])
|
|
||||||
model = manager.model
|
model = manager.model
|
||||||
|
|
||||||
assert isinstance(model.get_submodule("gate_up_proj"),
|
assert isinstance(model.get_submodule("gate_up_proj"),
|
||||||
|
53
tests/lora/test_mixtral.py
Normal file
53
tests/lora/test_mixtral.py
Normal file
@ -0,0 +1,53 @@
|
|||||||
|
import pytest
|
||||||
|
import torch
|
||||||
|
|
||||||
|
import vllm
|
||||||
|
from vllm.lora.request import LoRARequest
|
||||||
|
|
||||||
|
MODEL_PATH = "mistralai/Mixtral-8x7B-Instruct-v0.1"
|
||||||
|
|
||||||
|
|
||||||
|
def do_sample(llm, lora_path: str, lora_id: int):
|
||||||
|
prompts = [
|
||||||
|
"[system] Given a target sentence construct the underlying meaning representation\nof the input sentence as a single function with attributes and attribute\nvalues. This function should describe the target string accurately and the\nfunction must be one of the following ['inform', 'request', 'give_opinion',\n'confirm', 'verify_attribute', 'suggest', 'request_explanation',\n'recommend', 'request_attribute'].\n\nThe attributes must be one of the following:\n['name', 'exp_release_date', 'release_year', 'developer', 'esrb', 'rating',\n'genres', 'player_perspective', 'has_multiplayer', 'platforms',\n'available_on_steam', 'has_linux_release', 'has_mac_release', 'specifier'] [/system] [user] Here is the target sentence:\nSpellForce 3 is a pretty bad game. The developer Grimlore Games is clearly a bunch of no-talent hacks, and 2017 was a terrible year for games anyway. [/user] [assistant]",
|
||||||
|
"[system] Given a target sentence construct the underlying meaning representation\nof the input sentence as a single function with attributes and attribute\nvalues. This function should describe the target string accurately and the\nfunction must be one of the following ['inform', 'request', 'give_opinion',\n'confirm', 'verify_attribute', 'suggest', 'request_explanation',\n'recommend', 'request_attribute'].\n\nThe attributes must be one of the following:\n['name', 'exp_release_date', 'release_year', 'developer', 'esrb', 'rating',\n'genres', 'player_perspective', 'has_multiplayer', 'platforms',\n'available_on_steam', 'has_linux_release', 'has_mac_release', 'specifier'] [/system] [user] Here is the target sentence:\nI wanted to like Grimlore Games' 2017 entry, but in SpellForce 3 they just didn't get anything right. [/user] [assistant]",
|
||||||
|
"[system] Given a target sentence construct the underlying meaning representation\nof the input sentence as a single function with attributes and attribute\nvalues. This function should describe the target string accurately and the\nfunction must be one of the following ['inform', 'request', 'give_opinion',\n'confirm', 'verify_attribute', 'suggest', 'request_explanation',\n'recommend', 'request_attribute'].\n\nThe attributes must be one of the following:\n['name', 'exp_release_date', 'release_year', 'developer', 'esrb', 'rating',\n'genres', 'player_perspective', 'has_multiplayer', 'platforms',\n'available_on_steam', 'has_linux_release', 'has_mac_release', 'specifier'] [/system] [user] Here is the target sentence:\nBioShock is a good role-playing, action-adventure, shooter that released for PlayStation, Xbox, and PC in 2007. It is available on Steam, and it has a Mac release but not a Linux release. [/user] [assistant]",
|
||||||
|
]
|
||||||
|
sampling_params = vllm.SamplingParams(temperature=0, max_tokens=256)
|
||||||
|
outputs = llm.generate(
|
||||||
|
prompts,
|
||||||
|
sampling_params,
|
||||||
|
lora_request=LoRARequest(str(lora_id), lora_id, lora_path)
|
||||||
|
if lora_id else None)
|
||||||
|
# Print the outputs.
|
||||||
|
generated_texts = []
|
||||||
|
for output in outputs:
|
||||||
|
prompt = output.prompt
|
||||||
|
generated_text = output.outputs[0].text.strip()
|
||||||
|
generated_texts.append(generated_text)
|
||||||
|
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
|
||||||
|
return generated_texts
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("tp_size", [4])
|
||||||
|
def test_mixtral_lora(mixtral_lora_files, tp_size):
|
||||||
|
if torch.cuda.device_count() < tp_size:
|
||||||
|
pytest.skip(f"Not enough GPUs for tensor parallelism {tp_size}")
|
||||||
|
|
||||||
|
llm = vllm.LLM(MODEL_PATH,
|
||||||
|
enable_lora=True,
|
||||||
|
max_num_seqs=16,
|
||||||
|
max_loras=4,
|
||||||
|
tensor_parallel_size=tp_size,
|
||||||
|
worker_use_ray=True)
|
||||||
|
|
||||||
|
expected_lora_output = [
|
||||||
|
"give_opinion(name[SpellForce 3], release_year[2017], developer[Grimlore Games], rating[poor])",
|
||||||
|
"give_opinion(name[SpellForce 3], release_year[2017], developer[Grimlore Games], rating[poor])",
|
||||||
|
"inform(name[BioShock], release_year[2007], rating[good], genres[action-adventure, role-playing, shooter], platforms[PlayStation, Xbox, PC], available_on_steam[yes], has_linux_release[no], has_mac_release[yes])",
|
||||||
|
]
|
||||||
|
|
||||||
|
assert do_sample(llm, mixtral_lora_files,
|
||||||
|
lora_id=1) == expected_lora_output
|
||||||
|
assert do_sample(llm, mixtral_lora_files,
|
||||||
|
lora_id=2) == expected_lora_output
|
@ -4,8 +4,7 @@ import logging
|
|||||||
import math
|
import math
|
||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
from typing import (Any, Callable, Dict, Hashable, List, Optional, Tuple, Type,
|
from typing import (Any, Callable, Dict, Hashable, List, Optional, Tuple, Type)
|
||||||
Union)
|
|
||||||
|
|
||||||
import safetensors.torch
|
import safetensors.torch
|
||||||
import torch
|
import torch
|
||||||
@ -20,36 +19,6 @@ from vllm.lora.utils import parse_fine_tuned_lora_name, replace_submodule
|
|||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
# TODO: The mappings below should be moved to individual model classes.
|
|
||||||
|
|
||||||
PACKED_MODULES_CFG = {
|
|
||||||
"qkv_proj": [
|
|
||||||
"q_proj",
|
|
||||||
"k_proj",
|
|
||||||
"v_proj",
|
|
||||||
],
|
|
||||||
"gate_up_proj": [
|
|
||||||
"gate_proj",
|
|
||||||
"up_proj",
|
|
||||||
],
|
|
||||||
}
|
|
||||||
|
|
||||||
TARGET_MODULES_QKV = [
|
|
||||||
"qkv_proj",
|
|
||||||
"o_proj",
|
|
||||||
"gate_up_proj",
|
|
||||||
"down_proj",
|
|
||||||
"embed_tokens",
|
|
||||||
"lm_head",
|
|
||||||
]
|
|
||||||
|
|
||||||
EMBEDDING_MODULES = {
|
|
||||||
"embed_tokens": "input_embeddings",
|
|
||||||
"lm_head": "output_embeddings",
|
|
||||||
}
|
|
||||||
|
|
||||||
EMBEDDING_PADDING_MODULES = ["lm_head"]
|
|
||||||
|
|
||||||
_GLOBAL_LORA_ID = 0
|
_GLOBAL_LORA_ID = 0
|
||||||
|
|
||||||
|
|
||||||
@ -169,6 +138,8 @@ class LoRAModel:
|
|||||||
dtype: Optional[torch.dtype] = None,
|
dtype: Optional[torch.dtype] = None,
|
||||||
embeddings: Optional[Dict[str, torch.Tensor]] = None,
|
embeddings: Optional[Dict[str, torch.Tensor]] = None,
|
||||||
target_embedding_padding: Optional[int] = None,
|
target_embedding_padding: Optional[int] = None,
|
||||||
|
embedding_modules: Optional[Dict[str, str]] = None,
|
||||||
|
embedding_padding_modules: Optional[List[str]] = None,
|
||||||
) -> "LoRAModel":
|
) -> "LoRAModel":
|
||||||
"""Create a LoRAModel from a dictionary of tensors."""
|
"""Create a LoRAModel from a dictionary of tensors."""
|
||||||
pin_memory = str(device) == "cpu" and not in_wsl()
|
pin_memory = str(device) == "cpu" and not in_wsl()
|
||||||
@ -179,11 +150,11 @@ class LoRAModel:
|
|||||||
lora_embeddings_tensor = None
|
lora_embeddings_tensor = None
|
||||||
if embeddings:
|
if embeddings:
|
||||||
embeddings_module = next(
|
embeddings_module = next(
|
||||||
(k for k in EMBEDDING_MODULES if k in module_name),
|
(k for k in embedding_modules if k in module_name),
|
||||||
None)
|
None)
|
||||||
if embeddings_module:
|
if embeddings_module:
|
||||||
lora_embeddings_tensor = embeddings[
|
lora_embeddings_tensor = embeddings[
|
||||||
EMBEDDING_MODULES[embeddings_module]].to(
|
embedding_modules[embeddings_module]].to(
|
||||||
device=device, dtype=dtype)
|
device=device, dtype=dtype)
|
||||||
if pin_memory:
|
if pin_memory:
|
||||||
lora_embeddings_tensor = (
|
lora_embeddings_tensor = (
|
||||||
@ -201,7 +172,7 @@ class LoRAModel:
|
|||||||
loras[module_name].lora_b = tensor.to(device=device,
|
loras[module_name].lora_b = tensor.to(device=device,
|
||||||
dtype=dtype).t()
|
dtype=dtype).t()
|
||||||
if any(name in module_name
|
if any(name in module_name
|
||||||
for name in EMBEDDING_PADDING_MODULES
|
for name in embedding_padding_modules
|
||||||
) and target_embedding_padding is not None:
|
) and target_embedding_padding is not None:
|
||||||
lora_b = loras[module_name].lora_b
|
lora_b = loras[module_name].lora_b
|
||||||
assert target_embedding_padding >= lora_b.shape[1]
|
assert target_embedding_padding >= lora_b.shape[1]
|
||||||
@ -218,12 +189,15 @@ class LoRAModel:
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_local_checkpoint(
|
def from_local_checkpoint(
|
||||||
cls,
|
cls,
|
||||||
lora_dir: str,
|
lora_dir: str,
|
||||||
lora_model_id: Optional[int] = None,
|
lora_model_id: Optional[int] = None,
|
||||||
device: str = "cuda",
|
device: str = "cuda",
|
||||||
dtype: Optional[torch.dtype] = None,
|
dtype: Optional[torch.dtype] = None,
|
||||||
target_embedding_padding: Optional[int] = None) -> "LoRAModel":
|
target_embedding_padding: Optional[int] = None,
|
||||||
|
embedding_modules: Optional[Dict[str, str]] = None,
|
||||||
|
embedding_padding_modules: Optional[List[str]] = None,
|
||||||
|
) -> "LoRAModel":
|
||||||
"""Create a LoRAModel from a local checkpoint."""
|
"""Create a LoRAModel from a local checkpoint."""
|
||||||
lora_config_path = os.path.join(lora_dir, "adapter_config.json")
|
lora_config_path = os.path.join(lora_dir, "adapter_config.json")
|
||||||
lora_tensor_path = os.path.join(lora_dir, "adapter_model.safetensors")
|
lora_tensor_path = os.path.join(lora_dir, "adapter_model.safetensors")
|
||||||
@ -260,6 +234,8 @@ class LoRAModel:
|
|||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
embeddings=embeddings,
|
embeddings=embeddings,
|
||||||
target_embedding_padding=target_embedding_padding,
|
target_embedding_padding=target_embedding_padding,
|
||||||
|
embedding_modules=embedding_modules,
|
||||||
|
embedding_padding_modules=embedding_padding_modules,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -273,8 +249,6 @@ class LoRAModelManager:
|
|||||||
max_num_batched_tokens: int,
|
max_num_batched_tokens: int,
|
||||||
vocab_size: int,
|
vocab_size: int,
|
||||||
lora_config: LoRAConfig,
|
lora_config: LoRAConfig,
|
||||||
lora_target_modules: Union[str, List[str]] = TARGET_MODULES_QKV,
|
|
||||||
packed_modules_mapping: Dict[str, List[str]] = PACKED_MODULES_CFG,
|
|
||||||
):
|
):
|
||||||
"""Create a LoRAModelManager and adapter for a given model.
|
"""Create a LoRAModelManager and adapter for a given model.
|
||||||
|
|
||||||
@ -286,13 +260,6 @@ class LoRAModelManager:
|
|||||||
in a single batch.
|
in a single batch.
|
||||||
vocab_size: the vocab size of the model.
|
vocab_size: the vocab size of the model.
|
||||||
lora_config: the LoRA configuration.
|
lora_config: the LoRA configuration.
|
||||||
lora_target_modules: the target modules patterns to be adapted.
|
|
||||||
Support both single module name and a list of module names.
|
|
||||||
packed_modules_mapping: the mapping for packed modules. vLLM
|
|
||||||
packs some modules into one module, e.g., qkv_proj
|
|
||||||
is packed of q_proj, k_proj, and v_proj. These modules
|
|
||||||
have a single layer in the original model, but they are split
|
|
||||||
into multiple layers in the adapted model.
|
|
||||||
"""
|
"""
|
||||||
self.lora_config = lora_config
|
self.lora_config = lora_config
|
||||||
self.max_num_seqs = max_num_seqs
|
self.max_num_seqs = max_num_seqs
|
||||||
@ -320,11 +287,11 @@ class LoRAModelManager:
|
|||||||
self.indices_len = [None] * 4
|
self.indices_len = [None] * 4
|
||||||
|
|
||||||
self.model: nn.Module = model
|
self.model: nn.Module = model
|
||||||
self.lora_target_modules: List[str] = ([
|
if hasattr(self.model, "supported_lora_modules"):
|
||||||
lora_target_modules
|
self.supported_lora_modules = copy.deepcopy(
|
||||||
] if isinstance(lora_target_modules, str) else lora_target_modules)
|
self.model.supported_lora_modules)
|
||||||
self.lora_target_modules = copy.deepcopy(lora_target_modules)
|
self.packed_modules_mapping = copy.deepcopy(
|
||||||
self.packed_modules_mapping = copy.deepcopy(packed_modules_mapping)
|
self.model.packed_modules_mapping)
|
||||||
self.packed_modules: Dict[str, List[str]] = {}
|
self.packed_modules: Dict[str, List[str]] = {}
|
||||||
self.modules: Dict[str, "BaseLayerWithLoRA"] = {}
|
self.modules: Dict[str, "BaseLayerWithLoRA"] = {}
|
||||||
self._registered_loras: Dict[int, LoRAModel] = {}
|
self._registered_loras: Dict[int, LoRAModel] = {}
|
||||||
@ -468,7 +435,11 @@ class LoRAModelManager:
|
|||||||
assert isinstance(module, BaseLayerWithLoRA)
|
assert isinstance(module, BaseLayerWithLoRA)
|
||||||
self.modules[module_name] = module
|
self.modules[module_name] = module
|
||||||
|
|
||||||
def create_dummy_lora(self, lora_id: int, rank: int) -> LoRAModel:
|
def create_dummy_lora(
|
||||||
|
self,
|
||||||
|
lora_id: int,
|
||||||
|
rank: int,
|
||||||
|
embedding_modules: Optional[Dict[str, str]] = None) -> LoRAModel:
|
||||||
"""Create zero-initialized LoRAModel for warmup."""
|
"""Create zero-initialized LoRAModel for warmup."""
|
||||||
model = LoRAModel(lora_id, rank, {})
|
model = LoRAModel(lora_id, rank, {})
|
||||||
for module_name, module in self.model.named_modules():
|
for module_name, module in self.model.named_modules():
|
||||||
@ -477,7 +448,7 @@ class LoRAModelManager:
|
|||||||
continue
|
continue
|
||||||
parts = module_name.split(".")
|
parts = module_name.split(".")
|
||||||
if module_name not in self.packed_modules:
|
if module_name not in self.packed_modules:
|
||||||
if parts[-1] in EMBEDDING_MODULES:
|
if parts[-1] in embedding_modules:
|
||||||
input_dim = (module.base_layer.org_vocab_size +
|
input_dim = (module.base_layer.org_vocab_size +
|
||||||
self.lora_config.lora_extra_vocab_size if
|
self.lora_config.lora_extra_vocab_size if
|
||||||
hasattr(module.base_layer, "org_vocab_size")
|
hasattr(module.base_layer, "org_vocab_size")
|
||||||
@ -531,7 +502,7 @@ class LoRAModelManager:
|
|||||||
re.match(
|
re.match(
|
||||||
r".*\.{target_module}$".format(target_module=target_module),
|
r".*\.{target_module}$".format(target_module=target_module),
|
||||||
module_name) or target_module == module_name
|
module_name) or target_module == module_name
|
||||||
for target_module in self.lora_target_modules)
|
for target_module in self.supported_lora_modules)
|
||||||
|
|
||||||
def _register_packed_modules(self, module_full_name: str) -> None:
|
def _register_packed_modules(self, module_full_name: str) -> None:
|
||||||
parts = module_full_name.split(".")
|
parts = module_full_name.split(".")
|
||||||
@ -586,12 +557,9 @@ class LRUCacheLoRAModelManager(LoRAModelManager):
|
|||||||
max_num_batched_tokens: int,
|
max_num_batched_tokens: int,
|
||||||
vocab_size: int,
|
vocab_size: int,
|
||||||
lora_config: LoRAConfig,
|
lora_config: LoRAConfig,
|
||||||
lora_target_modules: Union[str, List[str]] = TARGET_MODULES_QKV,
|
|
||||||
packed_modules_mapping: Dict[str, List[str]] = PACKED_MODULES_CFG,
|
|
||||||
):
|
):
|
||||||
super().__init__(model, max_num_seqs, max_num_batched_tokens,
|
super().__init__(model, max_num_seqs, max_num_batched_tokens,
|
||||||
vocab_size, lora_config, lora_target_modules,
|
vocab_size, lora_config)
|
||||||
packed_modules_mapping)
|
|
||||||
self._registered_loras: LoRALRUCache = LoRALRUCache(
|
self._registered_loras: LoRALRUCache = LoRALRUCache(
|
||||||
self.capacity, self.deactivate_lora)
|
self.capacity, self.deactivate_lora)
|
||||||
self._active_loras: LoRALRUCache = LoRALRUCache(
|
self._active_loras: LoRALRUCache = LoRALRUCache(
|
||||||
@ -637,11 +605,10 @@ def create_lora_manager(
|
|||||||
max_num_batched_tokens: int,
|
max_num_batched_tokens: int,
|
||||||
vocab_size: int,
|
vocab_size: int,
|
||||||
lora_config: LoRAConfig,
|
lora_config: LoRAConfig,
|
||||||
target_modules: Union[str, List[str]] = TARGET_MODULES_QKV,
|
|
||||||
lora_manager_cls: Type[LoRAModelManager] = LoRAModelManager,
|
lora_manager_cls: Type[LoRAModelManager] = LoRAModelManager,
|
||||||
**kwargs) -> LoRAModelManager:
|
**kwargs) -> LoRAModelManager:
|
||||||
"""Create a LoRA adapter for a given model."""
|
"""Create a LoRA adapter for a given model."""
|
||||||
if not getattr(model, "supports_lora", False):
|
if not hasattr(model, "supported_lora_modules"):
|
||||||
raise ValueError(f"Model {type(model)} is not supported for LoRA.")
|
raise ValueError(f"Model {type(model)} is not supported for LoRA.")
|
||||||
lora_manager = lora_manager_cls(
|
lora_manager = lora_manager_cls(
|
||||||
model=model,
|
model=model,
|
||||||
@ -649,6 +616,5 @@ def create_lora_manager(
|
|||||||
max_num_batched_tokens=max_num_batched_tokens,
|
max_num_batched_tokens=max_num_batched_tokens,
|
||||||
vocab_size=vocab_size,
|
vocab_size=vocab_size,
|
||||||
lora_config=lora_config,
|
lora_config=lora_config,
|
||||||
lora_target_modules=target_modules,
|
|
||||||
**kwargs)
|
**kwargs)
|
||||||
return lora_manager
|
return lora_manager
|
||||||
|
@ -1,10 +1,10 @@
|
|||||||
import logging
|
import logging
|
||||||
from abc import ABC, abstractmethod, abstractproperty
|
from abc import ABC, abstractmethod, abstractproperty
|
||||||
from typing import Any, List, Optional, Set, Type, Union
|
from typing import Any, Dict, List, Optional, Set, Type
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from vllm.lora.models import (TARGET_MODULES_QKV, LoRAModel, LoRAModelManager,
|
from vllm.lora.models import (LoRAModel, LoRAModelManager,
|
||||||
LRUCacheLoRAModelManager, create_lora_manager)
|
LRUCacheLoRAModelManager, create_lora_manager)
|
||||||
from vllm.lora.request import LoRARequest
|
from vllm.lora.request import LoRARequest
|
||||||
from vllm.lora.layers import LoRAMapping
|
from vllm.lora.layers import LoRAMapping
|
||||||
@ -13,7 +13,7 @@ from vllm.config import LoRAConfig
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class WorkerLoRAManager(ABC):
|
class AbstractWorkerLoRAManager(ABC):
|
||||||
"""Abstract class for managing LoRA models on the worker side."""
|
"""Abstract class for managing LoRA models on the worker side."""
|
||||||
|
|
||||||
def __init__(self, max_num_seqs: int, max_num_batched_tokens: int,
|
def __init__(self, max_num_seqs: int, max_num_batched_tokens: int,
|
||||||
@ -33,7 +33,6 @@ class WorkerLoRAManager(ABC):
|
|||||||
def create_lora_manager(
|
def create_lora_manager(
|
||||||
self,
|
self,
|
||||||
model: torch.nn.Module,
|
model: torch.nn.Module,
|
||||||
target_modules: Union[str, List[str]] = TARGET_MODULES_QKV,
|
|
||||||
) -> Any:
|
) -> Any:
|
||||||
...
|
...
|
||||||
|
|
||||||
@ -63,7 +62,7 @@ class WorkerLoRAManager(ABC):
|
|||||||
...
|
...
|
||||||
|
|
||||||
|
|
||||||
class WorkerLoRAManager(WorkerLoRAManager):
|
class WorkerLoRAManager(AbstractWorkerLoRAManager):
|
||||||
"""WorkerLoRAManager that manages LoRA models on the worker side.
|
"""WorkerLoRAManager that manages LoRA models on the worker side.
|
||||||
|
|
||||||
Every request, the requested LoRAs will be loaded (unless they are already
|
Every request, the requested LoRAs will be loaded (unless they are already
|
||||||
@ -78,10 +77,14 @@ class WorkerLoRAManager(WorkerLoRAManager):
|
|||||||
vocab_size: int,
|
vocab_size: int,
|
||||||
lora_config: LoRAConfig,
|
lora_config: LoRAConfig,
|
||||||
device: torch.device,
|
device: torch.device,
|
||||||
|
embedding_modules: Dict[str, str],
|
||||||
|
embedding_padding_modules: List[str],
|
||||||
lora_model_cls: Type[LoRAModel] = LoRAModel,
|
lora_model_cls: Type[LoRAModel] = LoRAModel,
|
||||||
):
|
):
|
||||||
self._lora_manager: Optional[LoRAModelManager] = None
|
self._lora_manager: Optional[LoRAModelManager] = None
|
||||||
self._lora_model_cls = lora_model_cls
|
self._lora_model_cls = lora_model_cls
|
||||||
|
self.embedding_modules = embedding_modules
|
||||||
|
self.embedding_padding_modules = embedding_padding_modules
|
||||||
super().__init__(max_num_seqs, max_num_batched_tokens, vocab_size,
|
super().__init__(max_num_seqs, max_num_batched_tokens, vocab_size,
|
||||||
lora_config, device)
|
lora_config, device)
|
||||||
|
|
||||||
@ -92,13 +95,11 @@ class WorkerLoRAManager(WorkerLoRAManager):
|
|||||||
def create_lora_manager(
|
def create_lora_manager(
|
||||||
self,
|
self,
|
||||||
model: torch.nn.Module,
|
model: torch.nn.Module,
|
||||||
target_modules: Union[str, List[str]] = TARGET_MODULES_QKV,
|
|
||||||
) -> Any:
|
) -> Any:
|
||||||
lora_manager = create_lora_manager(
|
lora_manager = create_lora_manager(
|
||||||
model,
|
model,
|
||||||
max_num_seqs=self.max_num_seqs,
|
max_num_seqs=self.max_num_seqs,
|
||||||
max_num_batched_tokens=self.max_num_batched_tokens,
|
max_num_batched_tokens=self.max_num_batched_tokens,
|
||||||
target_modules=target_modules,
|
|
||||||
vocab_size=self.vocab_size,
|
vocab_size=self.vocab_size,
|
||||||
lora_config=self.lora_config,
|
lora_config=self.lora_config,
|
||||||
lora_manager_cls=self._lora_manager_cls,
|
lora_manager_cls=self._lora_manager_cls,
|
||||||
@ -142,6 +143,8 @@ class WorkerLoRAManager(WorkerLoRAManager):
|
|||||||
dtype=self.lora_config.lora_dtype,
|
dtype=self.lora_config.lora_dtype,
|
||||||
target_embedding_padding=self.vocab_size +
|
target_embedding_padding=self.vocab_size +
|
||||||
self.lora_config.lora_extra_vocab_size,
|
self.lora_config.lora_extra_vocab_size,
|
||||||
|
embedding_modules=self.embedding_modules,
|
||||||
|
embedding_padding_modules=self.embedding_padding_modules,
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
@ -162,7 +165,7 @@ class WorkerLoRAManager(WorkerLoRAManager):
|
|||||||
return False
|
return False
|
||||||
return self._lora_manager.add_lora(
|
return self._lora_manager.add_lora(
|
||||||
self._lora_manager.create_dummy_lora(lora_request.lora_int_id,
|
self._lora_manager.create_dummy_lora(lora_request.lora_int_id,
|
||||||
rank))
|
rank, self.embedding_modules))
|
||||||
|
|
||||||
def add_lora(self, lora_request: LoRARequest) -> bool:
|
def add_lora(self, lora_request: LoRARequest) -> bool:
|
||||||
if lora_request.lora_int_id in self.list_loras():
|
if lora_request.lora_int_id in self.list_loras():
|
||||||
@ -195,11 +198,9 @@ class LRUCacheWorkerLoRAManager(WorkerLoRAManager):
|
|||||||
def create_lora_manager(
|
def create_lora_manager(
|
||||||
self,
|
self,
|
||||||
model: torch.nn.Module,
|
model: torch.nn.Module,
|
||||||
target_modules: Union[str, List[str]] = TARGET_MODULES_QKV,
|
|
||||||
) -> Any:
|
) -> Any:
|
||||||
lora_manager = create_lora_manager(
|
lora_manager = create_lora_manager(
|
||||||
model,
|
model,
|
||||||
target_modules=target_modules,
|
|
||||||
lora_manager_cls=self._lora_manager_cls,
|
lora_manager_cls=self._lora_manager_cls,
|
||||||
max_num_seqs=self.max_num_seqs,
|
max_num_seqs=self.max_num_seqs,
|
||||||
vocab_size=self.vocab_size,
|
vocab_size=self.vocab_size,
|
||||||
|
@ -66,7 +66,7 @@ def get_model(model_config: ModelConfig,
|
|||||||
# Create a model instance.
|
# Create a model instance.
|
||||||
# The weights will be initialized as empty tensors.
|
# The weights will be initialized as empty tensors.
|
||||||
with torch.device(device_config.device):
|
with torch.device(device_config.device):
|
||||||
if getattr(model_class, "supports_lora", False):
|
if hasattr(model_class, "supported_lora_modules"):
|
||||||
model = model_class(model_config.hf_config, linear_method,
|
model = model_class(model_config.hf_config, linear_method,
|
||||||
lora_config)
|
lora_config)
|
||||||
elif lora_config:
|
elif lora_config:
|
||||||
|
@ -269,7 +269,32 @@ class LlamaModel(nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
class LlamaForCausalLM(nn.Module):
|
class LlamaForCausalLM(nn.Module):
|
||||||
supports_lora = True
|
packed_modules_mapping = {
|
||||||
|
"qkv_proj": [
|
||||||
|
"q_proj",
|
||||||
|
"k_proj",
|
||||||
|
"v_proj",
|
||||||
|
],
|
||||||
|
"gate_up_proj": [
|
||||||
|
"gate_proj",
|
||||||
|
"up_proj",
|
||||||
|
],
|
||||||
|
}
|
||||||
|
|
||||||
|
# LoRA specific attributes
|
||||||
|
supported_lora_modules = [
|
||||||
|
"qkv_proj",
|
||||||
|
"o_proj",
|
||||||
|
"gate_up_proj",
|
||||||
|
"down_proj",
|
||||||
|
"embed_tokens",
|
||||||
|
"lm_head",
|
||||||
|
]
|
||||||
|
embedding_modules = {
|
||||||
|
"embed_tokens": "input_embeddings",
|
||||||
|
"lm_head": "output_embeddings",
|
||||||
|
}
|
||||||
|
embedding_padding_modules = ["lm_head"]
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@ -281,11 +306,11 @@ class LlamaForCausalLM(nn.Module):
|
|||||||
self.config = config
|
self.config = config
|
||||||
self.linear_method = linear_method
|
self.linear_method = linear_method
|
||||||
self.model = LlamaModel(config, linear_method, lora_config=lora_config)
|
self.model = LlamaModel(config, linear_method, lora_config=lora_config)
|
||||||
unpadded_vocab_size = config.vocab_size
|
self.unpadded_vocab_size = config.vocab_size
|
||||||
if lora_config:
|
if lora_config:
|
||||||
unpadded_vocab_size += lora_config.lora_extra_vocab_size
|
self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
|
||||||
self.lm_head = ParallelLMHead(
|
self.lm_head = ParallelLMHead(
|
||||||
unpadded_vocab_size,
|
self.unpadded_vocab_size,
|
||||||
config.hidden_size,
|
config.hidden_size,
|
||||||
org_num_embeddings=config.vocab_size,
|
org_num_embeddings=config.vocab_size,
|
||||||
padding_size=DEFAULT_VOCAB_PADDING_SIZE
|
padding_size=DEFAULT_VOCAB_PADDING_SIZE
|
||||||
@ -293,7 +318,7 @@ class LlamaForCausalLM(nn.Module):
|
|||||||
# compatibility
|
# compatibility
|
||||||
if not lora_config else lora_config.lora_vocab_padding_size,
|
if not lora_config else lora_config.lora_vocab_padding_size,
|
||||||
)
|
)
|
||||||
self.sampler = Sampler(unpadded_vocab_size, config.vocab_size)
|
self.sampler = Sampler(self.unpadded_vocab_size, config.vocab_size)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
|
@ -265,7 +265,32 @@ class MistralModel(nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
class MistralForCausalLM(nn.Module):
|
class MistralForCausalLM(nn.Module):
|
||||||
supports_lora = True
|
packed_modules_mapping = {
|
||||||
|
"qkv_proj": [
|
||||||
|
"q_proj",
|
||||||
|
"k_proj",
|
||||||
|
"v_proj",
|
||||||
|
],
|
||||||
|
"gate_up_proj": [
|
||||||
|
"gate_proj",
|
||||||
|
"up_proj",
|
||||||
|
],
|
||||||
|
}
|
||||||
|
|
||||||
|
# LoRA specific attributes
|
||||||
|
supported_lora_modules = [
|
||||||
|
"qkv_proj",
|
||||||
|
"o_proj",
|
||||||
|
"gate_up_proj",
|
||||||
|
"down_proj",
|
||||||
|
"embed_tokens",
|
||||||
|
"lm_head",
|
||||||
|
]
|
||||||
|
embedding_modules = {
|
||||||
|
"embed_tokens": "input_embeddings",
|
||||||
|
"lm_head": "output_embeddings",
|
||||||
|
}
|
||||||
|
embedding_padding_modules = ["lm_head"]
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
@ -27,6 +27,7 @@ import torch
|
|||||||
from torch import nn
|
from torch import nn
|
||||||
from transformers import MixtralConfig
|
from transformers import MixtralConfig
|
||||||
|
|
||||||
|
from vllm.config import LoRAConfig
|
||||||
from vllm.model_executor.input_metadata import InputMetadata
|
from vllm.model_executor.input_metadata import InputMetadata
|
||||||
from vllm.model_executor.layers.attention import PagedAttention
|
from vllm.model_executor.layers.attention import PagedAttention
|
||||||
from vllm.model_executor.layers.fused_moe import fused_moe
|
from vllm.model_executor.layers.fused_moe import fused_moe
|
||||||
@ -38,7 +39,7 @@ from vllm.model_executor.layers.linear import (LinearMethodBase,
|
|||||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||||
from vllm.model_executor.layers.sampler import Sampler
|
from vllm.model_executor.layers.sampler import Sampler
|
||||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||||
VocabParallelEmbedding, ParallelLMHead)
|
VocabParallelEmbedding, ParallelLMHead, DEFAULT_VOCAB_PADDING_SIZE)
|
||||||
from vllm.model_executor.parallel_utils.communication_op import (
|
from vllm.model_executor.parallel_utils.communication_op import (
|
||||||
tensor_model_parallel_all_reduce)
|
tensor_model_parallel_all_reduce)
|
||||||
from vllm.model_executor.parallel_utils.parallel_state import (
|
from vllm.model_executor.parallel_utils.parallel_state import (
|
||||||
@ -292,6 +293,7 @@ class MixtralModel(nn.Module):
|
|||||||
self.embed_tokens = VocabParallelEmbedding(
|
self.embed_tokens = VocabParallelEmbedding(
|
||||||
config.vocab_size,
|
config.vocab_size,
|
||||||
config.hidden_size,
|
config.hidden_size,
|
||||||
|
org_num_embeddings=self.org_vocab_size,
|
||||||
)
|
)
|
||||||
self.layers = nn.ModuleList([
|
self.layers = nn.ModuleList([
|
||||||
MixtralDecoderLayer(config, linear_method=linear_method)
|
MixtralDecoderLayer(config, linear_method=linear_method)
|
||||||
@ -318,18 +320,50 @@ class MixtralModel(nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
class MixtralForCausalLM(nn.Module):
|
class MixtralForCausalLM(nn.Module):
|
||||||
|
packed_modules_mapping = {
|
||||||
|
"qkv_proj": [
|
||||||
|
"q_proj",
|
||||||
|
"k_proj",
|
||||||
|
"v_proj",
|
||||||
|
],
|
||||||
|
}
|
||||||
|
|
||||||
|
# LoRA specific attributes
|
||||||
|
supported_lora_modules = [
|
||||||
|
"qkv_proj",
|
||||||
|
"o_proj",
|
||||||
|
"embed_tokens",
|
||||||
|
"lm_head",
|
||||||
|
]
|
||||||
|
embedding_modules = {
|
||||||
|
"embed_tokens": "input_embeddings",
|
||||||
|
"lm_head": "output_embeddings",
|
||||||
|
}
|
||||||
|
embedding_padding_modules = ["lm_head"]
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
config: MixtralConfig,
|
config: MixtralConfig,
|
||||||
linear_method: Optional[LinearMethodBase] = None,
|
linear_method: Optional[LinearMethodBase] = None,
|
||||||
|
lora_config: Optional[LoRAConfig] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.config = config
|
self.config = config
|
||||||
self.linear_method = linear_method
|
self.linear_method = linear_method
|
||||||
self.model = MixtralModel(config, linear_method)
|
self.model = MixtralModel(config, linear_method)
|
||||||
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
|
self.unpadded_vocab_size = config.vocab_size
|
||||||
self.sampler = Sampler(config.vocab_size)
|
if lora_config:
|
||||||
|
self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
|
||||||
|
self.lm_head = ParallelLMHead(
|
||||||
|
self.unpadded_vocab_size,
|
||||||
|
config.hidden_size,
|
||||||
|
org_num_embeddings=config.vocab_size,
|
||||||
|
padding_size=DEFAULT_VOCAB_PADDING_SIZE
|
||||||
|
# We need bigger padding if using lora for kernel
|
||||||
|
# compatibility
|
||||||
|
if not lora_config else lora_config.lora_vocab_padding_size,
|
||||||
|
)
|
||||||
|
self.sampler = Sampler(self.unpadded_vocab_size, config.vocab_size)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
|
@ -86,11 +86,20 @@ class ModelRunner:
|
|||||||
vocab_size = self.model.config.vocab_size
|
vocab_size = self.model.config.vocab_size
|
||||||
|
|
||||||
if self.lora_config:
|
if self.lora_config:
|
||||||
|
assert hasattr(
|
||||||
|
self.model, "supported_lora_modules"
|
||||||
|
) and self.model.supported_lora_modules, "Model does not support LoRA"
|
||||||
|
assert hasattr(
|
||||||
|
self.model,
|
||||||
|
"embedding_modules"), "Model does not have embedding_modules"
|
||||||
|
assert hasattr(self.model, "embedding_padding_modules"
|
||||||
|
), "Model does not have embedding_padding_modules"
|
||||||
self.lora_manager = LRUCacheWorkerLoRAManager(
|
self.lora_manager = LRUCacheWorkerLoRAManager(
|
||||||
self.scheduler_config.max_num_seqs,
|
self.scheduler_config.max_num_seqs,
|
||||||
self.scheduler_config.max_num_batched_tokens +
|
self.scheduler_config.max_num_batched_tokens +
|
||||||
self.scheduler_config.max_paddings, vocab_size,
|
self.scheduler_config.max_paddings, vocab_size,
|
||||||
self.lora_config, self.device)
|
self.lora_config, self.device, self.model.embedding_modules,
|
||||||
|
self.model.embedding_padding_modules)
|
||||||
self.model = self.lora_manager.create_lora_manager(self.model)
|
self.model = self.lora_manager.create_lora_manager(self.model)
|
||||||
|
|
||||||
def set_block_size(self, block_size: int) -> None:
|
def set_block_size(self, block_size: int) -> None:
|
||||||
|
Loading…
x
Reference in New Issue
Block a user