[Misc][LoRA] Replace hardcoded cuda device with configurable argument (#10223)
Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
This commit is contained in:
parent
eea55cca5b
commit
7f5edb5900
@ -51,6 +51,7 @@ TOLERANCES = {
|
||||
CUDA_DEVICES = [
|
||||
f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
|
||||
]
|
||||
|
||||
# We will launch different triton kernels between the prefill and decode
|
||||
# stages, so we need to verify this. prefill stage(True) or decode stage(False)
|
||||
STAGES = [True, False]
|
||||
@ -120,11 +121,12 @@ def populate_loras(
|
||||
subloras: List[LoRALayerWeights] = []
|
||||
sublora_len = layer_weights.shape[0] // repeats
|
||||
for i in range(repeats):
|
||||
sublora = DummyLoRAManager().init_random_lora(
|
||||
module_name=f"fake_{i}",
|
||||
weight=layer_weights,
|
||||
generate_embeddings_tensor=generate_embeddings_tensor,
|
||||
)
|
||||
sublora = DummyLoRAManager(
|
||||
layer_weights.device).init_random_lora(
|
||||
module_name=f"fake_{i}",
|
||||
weight=layer_weights,
|
||||
generate_embeddings_tensor=generate_embeddings_tensor,
|
||||
)
|
||||
sublora.lora_b = sublora.lora_b[:, (sublora_len *
|
||||
i):(sublora_len * (i + 1))]
|
||||
sublora.optimize()
|
||||
@ -152,6 +154,7 @@ def create_random_inputs(
|
||||
input_size: Tuple[int, ...],
|
||||
input_range: Tuple[float, float],
|
||||
input_type: torch.dtype = torch.int,
|
||||
device: torch.device = "cuda"
|
||||
) -> Tuple[List[torch.Tensor], List[int], List[int]]:
|
||||
"""Creates random inputs.
|
||||
|
||||
@ -173,10 +176,14 @@ def create_random_inputs(
|
||||
for _ in range(num_inputs):
|
||||
if input_type == torch.int:
|
||||
inputs.append(
|
||||
torch.randint(low=int(low), high=int(high), size=input_size))
|
||||
torch.randint(low=int(low),
|
||||
high=int(high),
|
||||
size=input_size,
|
||||
device=device))
|
||||
else:
|
||||
inputs.append(
|
||||
torch.rand(size=input_size, dtype=input_type) * high + low)
|
||||
torch.rand(size=input_size, dtype=input_type, device=device) *
|
||||
high + low)
|
||||
|
||||
lora_id = random.choice(active_lora_ids)
|
||||
index_mapping += [lora_id] * input_size[0]
|
||||
@ -191,6 +198,10 @@ def create_random_inputs(
|
||||
@pytest.mark.parametrize("vocab_size", [512, 32000, 64000, 128000])
|
||||
@pytest.mark.parametrize("stage", STAGES)
|
||||
def test_embeddings(dist_init, num_loras, device, vocab_size, stage) -> None:
|
||||
# For multi-GPU testing of Triton kernel, we must explicitly set the CUDA
|
||||
# device, see: https://github.com/triton-lang/triton/issues/2925
|
||||
# Same below.
|
||||
torch.cuda.set_device(device)
|
||||
|
||||
torch.set_default_device(device)
|
||||
max_loras = 8
|
||||
@ -225,7 +236,7 @@ def test_embeddings(dist_init, num_loras, device, vocab_size, stage) -> None:
|
||||
num_inputs=num_loras * 3,
|
||||
input_size=(200, ),
|
||||
input_range=(1, vocab_size),
|
||||
)
|
||||
device=device)
|
||||
lora_mapping = LoRAMapping(index_mapping,
|
||||
prompt_mapping,
|
||||
is_prefill=stage)
|
||||
@ -263,7 +274,7 @@ def test_embeddings(dist_init, num_loras, device, vocab_size, stage) -> None:
|
||||
num_inputs=num_loras * 3,
|
||||
input_size=(200, ),
|
||||
input_range=(1, vocab_size),
|
||||
)
|
||||
device=device)
|
||||
lora_mapping = LoRAMapping(index_mapping,
|
||||
prompt_mapping,
|
||||
is_prefill=stage)
|
||||
@ -291,6 +302,7 @@ def test_embeddings(dist_init, num_loras, device, vocab_size, stage) -> None:
|
||||
def test_embeddings_with_new_embeddings(dist_init, num_loras, device,
|
||||
vocab_size, stage) -> None:
|
||||
|
||||
torch.cuda.set_device(device)
|
||||
torch.set_default_device(device)
|
||||
max_loras = 8
|
||||
punica_wrapper = PunicaWrapper(8192, 256, device)
|
||||
@ -345,7 +357,7 @@ def test_embeddings_with_new_embeddings(dist_init, num_loras, device,
|
||||
num_inputs=num_loras * 3,
|
||||
input_size=(200, ),
|
||||
input_range=(1, vocab_size),
|
||||
)
|
||||
device=device)
|
||||
lora_mapping = LoRAMapping(index_mapping,
|
||||
prompt_mapping,
|
||||
is_prefill=stage)
|
||||
@ -400,7 +412,7 @@ def test_embeddings_with_new_embeddings(dist_init, num_loras, device,
|
||||
num_inputs=num_loras * 3,
|
||||
input_size=(200, ),
|
||||
input_range=(1, vocab_size),
|
||||
)
|
||||
device=device)
|
||||
original_inputs = deepcopy(inputs)
|
||||
lora_mapping = LoRAMapping(index_mapping,
|
||||
prompt_mapping,
|
||||
@ -426,6 +438,7 @@ def test_embeddings_with_new_embeddings(dist_init, num_loras, device,
|
||||
def test_lm_head_logits_processor(dist_init, num_loras, device, vocab_size,
|
||||
stage) -> None:
|
||||
|
||||
torch.cuda.set_device(device)
|
||||
torch.set_default_device(device)
|
||||
max_loras = 8
|
||||
punica_wrapper = PunicaWrapper(8192, 256, device)
|
||||
@ -471,7 +484,7 @@ def test_lm_head_logits_processor(dist_init, num_loras, device, vocab_size,
|
||||
input_size=(1, 1024),
|
||||
input_range=(0, 1),
|
||||
input_type=torch.float16,
|
||||
)
|
||||
device=device)
|
||||
lora_mapping = LoRAMapping(index_mapping,
|
||||
prompt_mapping,
|
||||
is_prefill=stage)
|
||||
@ -520,7 +533,7 @@ def test_lm_head_logits_processor(dist_init, num_loras, device, vocab_size,
|
||||
input_size=(1, 1024),
|
||||
input_range=(0, 1),
|
||||
input_type=torch.float16,
|
||||
)
|
||||
device=device)
|
||||
lora_mapping = LoRAMapping(index_mapping,
|
||||
prompt_mapping,
|
||||
is_prefill=stage)
|
||||
@ -554,6 +567,7 @@ def test_lm_head_logits_processor(dist_init, num_loras, device, vocab_size,
|
||||
@pytest.mark.parametrize("stage", STAGES)
|
||||
def test_linear_replicated(dist_init, num_loras, device, stage) -> None:
|
||||
|
||||
torch.cuda.set_device(device)
|
||||
torch.set_default_device(device)
|
||||
punica_wrapper = PunicaWrapper(8192, 256, device)
|
||||
max_loras = 8
|
||||
@ -592,7 +606,7 @@ def test_linear_replicated(dist_init, num_loras, device, stage) -> None:
|
||||
input_size=(1, 4096),
|
||||
input_range=(0, 1),
|
||||
input_type=torch.float16,
|
||||
)
|
||||
device=device)
|
||||
lora_mapping = LoRAMapping(index_mapping,
|
||||
prompt_mapping,
|
||||
is_prefill=stage)
|
||||
@ -631,7 +645,7 @@ def test_linear_replicated(dist_init, num_loras, device, stage) -> None:
|
||||
input_size=(1, 4096),
|
||||
input_range=(0, 1),
|
||||
input_type=torch.float16,
|
||||
)
|
||||
device=device)
|
||||
lora_mapping = LoRAMapping(index_mapping,
|
||||
prompt_mapping,
|
||||
is_prefill=stage)
|
||||
@ -658,6 +672,7 @@ def test_linear_replicated(dist_init, num_loras, device, stage) -> None:
|
||||
def test_linear_parallel(dist_init, num_loras, orientation, fully_shard,
|
||||
device, stage) -> None:
|
||||
|
||||
torch.cuda.set_device(device)
|
||||
torch.set_default_device(device)
|
||||
punica_wrapper = PunicaWrapper(8192, 256, device)
|
||||
max_loras = 8
|
||||
@ -706,7 +721,7 @@ def test_linear_parallel(dist_init, num_loras, orientation, fully_shard,
|
||||
input_size=(1, 4096),
|
||||
input_range=(0, 1),
|
||||
input_type=torch.float16,
|
||||
)
|
||||
device=device)
|
||||
lora_mapping = LoRAMapping(index_mapping,
|
||||
prompt_mapping,
|
||||
is_prefill=stage)
|
||||
@ -745,7 +760,7 @@ def test_linear_parallel(dist_init, num_loras, orientation, fully_shard,
|
||||
input_size=(1, 4096),
|
||||
input_range=(0, 1),
|
||||
input_type=torch.float16,
|
||||
)
|
||||
device=device)
|
||||
lora_mapping = LoRAMapping(index_mapping,
|
||||
prompt_mapping,
|
||||
is_prefill=stage)
|
||||
@ -772,6 +787,7 @@ def test_linear_parallel(dist_init, num_loras, orientation, fully_shard,
|
||||
def test_column_parallel_packed(dist_init, num_loras, repeats, fully_shard,
|
||||
device, stage) -> None:
|
||||
|
||||
torch.cuda.set_device(device)
|
||||
torch.set_default_device(device)
|
||||
punica_wrapper = PunicaWrapper(8192, 256, device)
|
||||
max_loras = 8
|
||||
@ -842,7 +858,7 @@ def test_column_parallel_packed(dist_init, num_loras, repeats, fully_shard,
|
||||
input_size=(1, 4096),
|
||||
input_range=(0, 1),
|
||||
input_type=torch.float16,
|
||||
)
|
||||
device=device)
|
||||
lora_mapping = LoRAMapping(index_mapping,
|
||||
prompt_mapping,
|
||||
is_prefill=stage)
|
||||
@ -883,7 +899,7 @@ def test_column_parallel_packed(dist_init, num_loras, repeats, fully_shard,
|
||||
input_size=(1, 4096),
|
||||
input_range=(0, 1),
|
||||
input_type=torch.float16,
|
||||
)
|
||||
device=device)
|
||||
lora_mapping = LoRAMapping(index_mapping,
|
||||
prompt_mapping,
|
||||
is_prefill=stage)
|
||||
@ -962,7 +978,7 @@ def test_rotary_embedding_long_context(dist_init, num_loras, device,
|
||||
input_size=(1, max_position),
|
||||
input_range=(0, lora_config.lora_extra_vocab_size),
|
||||
input_type=torch.float16,
|
||||
)
|
||||
device=device)
|
||||
|
||||
lora_mapping = LoRAMapping(index_mapping, prompt_mapping)
|
||||
long_lora_context = LongContextLoRAContext(list(scaling_factors),
|
||||
|
@ -25,8 +25,13 @@ EMBEDDING_MODULES = {
|
||||
|
||||
EMBEDDING_PADDING_MODULES = ["lm_head"]
|
||||
|
||||
CUDA_DEVICES = [
|
||||
f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
|
||||
]
|
||||
|
||||
def test_from_lora_tensors(sql_lora_files):
|
||||
|
||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||
def test_from_lora_tensors(sql_lora_files, device):
|
||||
tensors = load_file(
|
||||
os.path.join(sql_lora_files, "adapter_model.safetensors"))
|
||||
new_embeddings = load_file(
|
||||
@ -36,7 +41,7 @@ def test_from_lora_tensors(sql_lora_files):
|
||||
8,
|
||||
16,
|
||||
tensors,
|
||||
"cuda",
|
||||
device,
|
||||
embeddings=new_embeddings,
|
||||
embedding_modules=EMBEDDING_MODULES,
|
||||
embedding_padding_modules=EMBEDDING_PADDING_MODULES)
|
||||
@ -46,6 +51,8 @@ def test_from_lora_tensors(sql_lora_files):
|
||||
assert lora.lora_alpha == 16
|
||||
assert lora.lora_a is not None
|
||||
assert lora.lora_b is not None
|
||||
assert lora.lora_a.device == torch.device(device)
|
||||
assert lora.lora_b.device == torch.device(device)
|
||||
assert (lora.lora_a.shape[1] == lora.lora_b.shape[0]
|
||||
), f"{lora.lora_a.shape=}, {lora.lora_b.shape=}"
|
||||
assert lora.lora_a.shape[1] == 8
|
||||
@ -60,8 +67,8 @@ def test_from_lora_tensors(sql_lora_files):
|
||||
assert lora.embeddings_tensor is None
|
||||
|
||||
|
||||
def create_lora(lora_id: int, model: nn.Module,
|
||||
sub_modules: List[str]) -> LoRAModel:
|
||||
def create_lora(lora_id: int, model: nn.Module, sub_modules: List[str],
|
||||
device: torch.device) -> LoRAModel:
|
||||
loras: Dict[str, LoRALayerWeights] = {}
|
||||
for name in sub_modules:
|
||||
w = model.get_submodule(name).weight
|
||||
@ -69,8 +76,8 @@ def create_lora(lora_id: int, model: nn.Module,
|
||||
name,
|
||||
8,
|
||||
16,
|
||||
torch.rand([w.shape[1], 8], device="cuda"),
|
||||
torch.rand([8, w.shape[0]], device="cuda"),
|
||||
torch.rand([w.shape[1], 8], device=device),
|
||||
torch.rand([8, w.shape[0]], device=device),
|
||||
)
|
||||
return LoRAModel(lora_id, 8, loras)
|
||||
|
||||
@ -80,6 +87,7 @@ def create_packed_lora(
|
||||
model: nn.Module,
|
||||
module_name,
|
||||
replaced_module_names,
|
||||
device: torch.device,
|
||||
empty_replaced_module_name=None,
|
||||
) -> LoRAModel:
|
||||
w = model.get_submodule(module_name).weight
|
||||
@ -91,9 +99,9 @@ def create_packed_lora(
|
||||
replaced_module_name,
|
||||
8,
|
||||
16,
|
||||
torch.rand([w.shape[1], 8], device="cuda"),
|
||||
torch.rand([w.shape[1], 8], device=device),
|
||||
torch.rand([8, w.shape[0] // len(replaced_module_names)],
|
||||
device="cuda"),
|
||||
device=device),
|
||||
)
|
||||
return LoRAModel(lora_id, 8, loras)
|
||||
|
||||
@ -104,7 +112,8 @@ def test_replace_submodules(dist_init, dummy_model):
|
||||
model.packed_modules_mapping = {}
|
||||
manager = LoRAModelManager(
|
||||
model, 1, 1, 1,
|
||||
LoRAConfig(max_lora_rank=8, max_cpu_loras=8, max_loras=8))
|
||||
LoRAConfig(max_lora_rank=8, max_cpu_loras=8, max_loras=8),
|
||||
torch.device("cuda"))
|
||||
model = manager.model
|
||||
|
||||
assert isinstance(model.get_submodule("dense1"),
|
||||
@ -116,16 +125,28 @@ def test_replace_submodules(dist_init, dummy_model):
|
||||
RowParallelLinearWithLoRA)
|
||||
|
||||
|
||||
def test_lora_model_manager(dist_init, dummy_model):
|
||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||
def test_lora_model_manager(dist_init, dummy_model, device):
|
||||
model = dummy_model
|
||||
model.supported_lora_modules = ["dense1", "dense2", "lm_head"]
|
||||
model.packed_modules_mapping = {}
|
||||
model_lora1 = create_lora(1, model, ["layer1.dense1", "dense2", "lm_head"])
|
||||
model_lora2 = create_lora(2, model, ["dense1", "dense2", "lm_head"])
|
||||
model_lora3 = create_lora(3, model, ["dense1", "dense2", "lm_head"])
|
||||
manager = LoRAModelManager(
|
||||
model, 2, 2, 2,
|
||||
LoRAConfig(max_lora_rank=8, max_cpu_loras=3, max_loras=2))
|
||||
model_lora1 = create_lora(1,
|
||||
model, ["layer1.dense1", "dense2", "lm_head"],
|
||||
device=device)
|
||||
model_lora2 = create_lora(2,
|
||||
model, ["dense1", "dense2", "lm_head"],
|
||||
device=device)
|
||||
model_lora3 = create_lora(3,
|
||||
model, ["dense1", "dense2", "lm_head"],
|
||||
device=device)
|
||||
manager = LoRAModelManager(model,
|
||||
2,
|
||||
2,
|
||||
2,
|
||||
LoRAConfig(max_lora_rank=8,
|
||||
max_cpu_loras=3,
|
||||
max_loras=2),
|
||||
device=device)
|
||||
assert all(x is None for x in manager.lora_index_to_id)
|
||||
assert manager.add_adapter(model_lora1)
|
||||
assert manager.activate_adapter(1)
|
||||
@ -161,17 +182,32 @@ def test_lora_model_manager(dist_init, dummy_model):
|
||||
assert manager.lora_index_to_id[0] == 3
|
||||
assert manager.lora_index_to_id[1] == 2
|
||||
|
||||
assert manager.device == device
|
||||
assert manager.punica_wrapper.device == device
|
||||
|
||||
def test_lora_lru_cache_model_manager(dist_init, dummy_model):
|
||||
|
||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||
def test_lora_lru_cache_model_manager(dist_init, dummy_model, device):
|
||||
model = dummy_model
|
||||
model.supported_lora_modules = ["dense1", "dense2", "lm_head"]
|
||||
model.packed_modules_mapping = {}
|
||||
model_lora1 = create_lora(1, model, ["layer1.dense1", "dense2", "lm_head"])
|
||||
model_lora2 = create_lora(2, model, ["dense1", "dense2", "lm_head"])
|
||||
model_lora3 = create_lora(3, model, ["dense1", "dense2", "lm_head"])
|
||||
manager = LRUCacheLoRAModelManager(
|
||||
model, 2, 2, 2,
|
||||
LoRAConfig(max_lora_rank=8, max_cpu_loras=3, max_loras=2))
|
||||
model_lora1 = create_lora(1,
|
||||
model, ["layer1.dense1", "dense2", "lm_head"],
|
||||
device=device)
|
||||
model_lora2 = create_lora(2,
|
||||
model, ["dense1", "dense2", "lm_head"],
|
||||
device=device)
|
||||
model_lora3 = create_lora(3,
|
||||
model, ["dense1", "dense2", "lm_head"],
|
||||
device=device)
|
||||
manager = LRUCacheLoRAModelManager(model,
|
||||
2,
|
||||
2,
|
||||
2,
|
||||
LoRAConfig(max_lora_rank=8,
|
||||
max_cpu_loras=3,
|
||||
max_loras=2),
|
||||
device=device)
|
||||
assert all(x is None for x in manager.lora_index_to_id)
|
||||
assert manager.add_adapter(model_lora1)
|
||||
assert manager.activate_adapter(1)
|
||||
@ -238,20 +274,37 @@ def test_lora_lru_cache_model_manager(dist_init, dummy_model):
|
||||
with pytest.raises(ValueError):
|
||||
assert manager.pin_adapter(3)
|
||||
|
||||
assert manager.punica_wrapper.device == device
|
||||
assert manager.device == device
|
||||
|
||||
def test_lru_lora_model_manager(dist_init, dummy_model):
|
||||
|
||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||
def test_lru_lora_model_manager(dist_init, dummy_model, device):
|
||||
# This tests just the LRU cache functionality, everything else is
|
||||
# tested in test_lora_model_manager
|
||||
model = dummy_model
|
||||
model.supported_lora_modules = ["dense1", "dense2", "lm_head"]
|
||||
model.packed_modules_mapping = {}
|
||||
model_lora1 = create_lora(1, model, ["layer1.dense1", "dense2", "lm_head"])
|
||||
model_lora2 = create_lora(2, model, ["dense1", "dense2", "lm_head"])
|
||||
model_lora3 = create_lora(3, model, ["dense1", "dense2", "lm_head"])
|
||||
model_lora4 = create_lora(4, model, ["dense1", "dense2", "lm_head"])
|
||||
manager = LRUCacheLoRAModelManager(
|
||||
model, 2, 2, 2,
|
||||
LoRAConfig(max_lora_rank=8, max_cpu_loras=2, max_loras=2))
|
||||
model_lora1 = create_lora(1,
|
||||
model, ["layer1.dense1", "dense2", "lm_head"],
|
||||
device=device)
|
||||
model_lora2 = create_lora(2,
|
||||
model, ["dense1", "dense2", "lm_head"],
|
||||
device=device)
|
||||
model_lora3 = create_lora(3,
|
||||
model, ["dense1", "dense2", "lm_head"],
|
||||
device=device)
|
||||
model_lora4 = create_lora(4,
|
||||
model, ["dense1", "dense2", "lm_head"],
|
||||
device=device)
|
||||
manager = LRUCacheLoRAModelManager(model,
|
||||
2,
|
||||
2,
|
||||
2,
|
||||
LoRAConfig(max_lora_rank=8,
|
||||
max_cpu_loras=2,
|
||||
max_loras=2),
|
||||
device=device)
|
||||
|
||||
assert all(x is None for x in manager.lora_index_to_id)
|
||||
|
||||
@ -351,14 +404,17 @@ def test_lru_lora_model_manager(dist_init, dummy_model):
|
||||
assert manager.remove_oldest_adapter()
|
||||
|
||||
assert set(manager.list_adapters()) == {1}
|
||||
assert manager.punica_wrapper.device == device
|
||||
assert manager.device == device
|
||||
|
||||
|
||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||
def test_lru_cache_worker_adapter_manager(llama_2_7b_model_extra_embeddings,
|
||||
sql_lora_files):
|
||||
sql_lora_files, device):
|
||||
lora_config = LoRAConfig(max_lora_rank=8, max_cpu_loras=4, max_loras=4)
|
||||
worker_adapter_manager = LRUCacheWorkerLoRAManager(
|
||||
4, 2, llama_2_7b_model_extra_embeddings.unpadded_vocab_size -
|
||||
lora_config.lora_extra_vocab_size, lora_config, torch.device("cuda"),
|
||||
lora_config.lora_extra_vocab_size, lora_config, device,
|
||||
EMBEDDING_MODULES, EMBEDDING_PADDING_MODULES)
|
||||
worker_adapter_manager.create_lora_manager(
|
||||
llama_2_7b_model_extra_embeddings)
|
||||
@ -426,14 +482,19 @@ def test_lru_cache_worker_adapter_manager(llama_2_7b_model_extra_embeddings,
|
||||
LoRARequest("14", 14, sql_lora_files)
|
||||
], mapping)
|
||||
|
||||
assert worker_adapter_manager.device == device
|
||||
assert (worker_adapter_manager._adapter_manager.punica_wrapper.device ==
|
||||
device)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||
def test_worker_adapter_manager(llama_2_7b_model_extra_embeddings,
|
||||
sql_lora_files):
|
||||
sql_lora_files, device):
|
||||
# Should remove every LoRA not specified in the request.
|
||||
lora_config = LoRAConfig(max_lora_rank=8, max_cpu_loras=4, max_loras=4)
|
||||
worker_adapter_manager = WorkerLoRAManager(
|
||||
4, 2, llama_2_7b_model_extra_embeddings.unpadded_vocab_size -
|
||||
lora_config.lora_extra_vocab_size, lora_config, torch.device("cuda"),
|
||||
lora_config.lora_extra_vocab_size, lora_config, device,
|
||||
EMBEDDING_MODULES, EMBEDDING_PADDING_MODULES)
|
||||
worker_adapter_manager.create_lora_manager(
|
||||
llama_2_7b_model_extra_embeddings)
|
||||
@ -497,8 +558,13 @@ def test_worker_adapter_manager(llama_2_7b_model_extra_embeddings,
|
||||
LoRARequest("14", 14, sql_lora_files)
|
||||
], mapping)
|
||||
|
||||
assert worker_adapter_manager.device == device
|
||||
assert (worker_adapter_manager._adapter_manager.punica_wrapper.device ==
|
||||
device)
|
||||
|
||||
def test_packed_loras(dist_init, dummy_model_gate_up):
|
||||
|
||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||
def test_packed_loras(dist_init, dummy_model_gate_up, device):
|
||||
model = dummy_model_gate_up
|
||||
model.supported_lora_modules = ["gate_up_proj"]
|
||||
model.packed_modules_mapping = {
|
||||
@ -511,18 +577,25 @@ def test_packed_loras(dist_init, dummy_model_gate_up):
|
||||
1,
|
||||
model,
|
||||
module_name="gate_up_proj",
|
||||
replaced_module_names=["gate_proj", "up_proj"])
|
||||
replaced_module_names=["gate_proj", "up_proj"],
|
||||
device=device)
|
||||
model_lora1 = create_packed_lora(
|
||||
2,
|
||||
model,
|
||||
module_name="gate_up_proj",
|
||||
replaced_module_names=["gate_proj", "up_proj"],
|
||||
device=device,
|
||||
empty_replaced_module_name="gate_proj",
|
||||
)
|
||||
|
||||
manager = LoRAModelManager(
|
||||
model, 2, 2, 2,
|
||||
LoRAConfig(max_lora_rank=8, max_cpu_loras=2, max_loras=2))
|
||||
manager = LoRAModelManager(model,
|
||||
2,
|
||||
2,
|
||||
2,
|
||||
LoRAConfig(max_lora_rank=8,
|
||||
max_cpu_loras=2,
|
||||
max_loras=2),
|
||||
device=device)
|
||||
model = manager.model
|
||||
|
||||
assert isinstance(model.get_submodule("gate_up_proj"),
|
||||
|
@ -7,9 +7,10 @@ from vllm.lora.lora import LoRALayerWeights, PackedLoRALayerWeights
|
||||
|
||||
class DummyLoRAManager:
|
||||
|
||||
def __init__(self):
|
||||
def __init__(self, device: torch.device = "cuda:0"):
|
||||
super().__init__()
|
||||
self._loras: Dict[str, LoRALayerWeights] = {}
|
||||
self._device = device
|
||||
|
||||
def set_module_lora(self, module_name: str, lora: LoRALayerWeights):
|
||||
self._loras[module_name] = lora
|
||||
@ -28,16 +29,16 @@ class DummyLoRAManager:
|
||||
lora_alpha=1,
|
||||
lora_a=torch.rand([weight.shape[1], rank],
|
||||
dtype=weight.dtype,
|
||||
device="cuda"),
|
||||
device=self._device),
|
||||
lora_b=torch.rand([rank, weight.shape[0]],
|
||||
dtype=weight.dtype,
|
||||
device="cuda"),
|
||||
device=self._device),
|
||||
)
|
||||
if generate_embeddings_tensor:
|
||||
lora.embeddings_tensor = torch.rand(5,
|
||||
generate_embeddings_tensor,
|
||||
dtype=weight.dtype,
|
||||
device="cuda")
|
||||
device=self._device)
|
||||
self.set_module_lora(module_name, lora)
|
||||
|
||||
return lora
|
||||
|
@ -301,6 +301,7 @@ class LoRAModelManager(AdapterModelManager):
|
||||
max_num_batched_tokens: int,
|
||||
vocab_size: int,
|
||||
lora_config: LoRAConfig,
|
||||
device: torch.device,
|
||||
):
|
||||
"""Create a LoRAModelManager and adapter for a given model.
|
||||
|
||||
@ -314,6 +315,7 @@ class LoRAModelManager(AdapterModelManager):
|
||||
lora_config: the LoRA configuration.
|
||||
"""
|
||||
self.lora_config = lora_config
|
||||
self.device = device
|
||||
self.max_num_seqs = max_num_seqs
|
||||
assert self.capacity >= self.lora_slots
|
||||
self.max_num_batched_tokens = math.ceil(max_num_batched_tokens / 8) * 8
|
||||
@ -322,7 +324,7 @@ class LoRAModelManager(AdapterModelManager):
|
||||
self.long_lora_context: Optional[LongContextLoRAContext] = None
|
||||
self.punica_wrapper = PunicaWrapper(max_num_batched_tokens,
|
||||
max_batches=self.max_num_seqs,
|
||||
device="cuda")
|
||||
device=self.device)
|
||||
# Scaling factor -> offset to the sin_cos_cache to it.
|
||||
# Used for long context lora.
|
||||
self.scaling_factor_to_offset: Dict[float, int] = {}
|
||||
@ -653,16 +655,11 @@ class LoRALRUCache(AdapterLRUCache[LoRAModel]):
|
||||
class LRUCacheLoRAModelManager(LoRAModelManager):
|
||||
"""A model manager that manages multiple LoRAs with LRU cache."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: nn.Module,
|
||||
max_num_seqs: int,
|
||||
max_num_batched_tokens: int,
|
||||
vocab_size: int,
|
||||
lora_config: LoRAConfig,
|
||||
):
|
||||
def __init__(self, model: nn.Module, max_num_seqs: int,
|
||||
max_num_batched_tokens: int, vocab_size: int,
|
||||
lora_config: LoRAConfig, device: torch.device):
|
||||
super().__init__(model, max_num_seqs, max_num_batched_tokens,
|
||||
vocab_size, lora_config)
|
||||
vocab_size, lora_config, device)
|
||||
self._registered_adapters: LoRALRUCache = LoRALRUCache(
|
||||
self.capacity, self.deactivate_adapter)
|
||||
self._active_adapters: LoRALRUCache = LoRALRUCache(
|
||||
@ -732,6 +729,7 @@ def create_lora_manager(
|
||||
max_num_batched_tokens: int,
|
||||
vocab_size: int,
|
||||
lora_config: LoRAConfig,
|
||||
device: torch.device,
|
||||
lora_manager_cls: Type[LoRAModelManager] = LoRAModelManager,
|
||||
**kwargs) -> LoRAModelManager:
|
||||
"""Create a LoRA adapter for a given model."""
|
||||
@ -743,5 +741,6 @@ def create_lora_manager(
|
||||
max_num_batched_tokens=max_num_batched_tokens,
|
||||
vocab_size=vocab_size,
|
||||
lora_config=lora_config,
|
||||
device=device,
|
||||
**kwargs)
|
||||
return lora_manager
|
||||
|
@ -62,6 +62,7 @@ def convert_mapping(
|
||||
max_loras: int,
|
||||
vocab_size: int,
|
||||
extra_vocab_size: int,
|
||||
device: torch.device,
|
||||
long_lora_context: Optional["LongContextLoRAContext"] = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor,
|
||||
Optional[torch.Tensor], List[int]]:
|
||||
@ -104,7 +105,7 @@ def convert_mapping(
|
||||
long_lora_offsets: Optional[torch.Tensor] = None
|
||||
if long_lora_context:
|
||||
long_lora_offsets = torch.zeros(len(index_mapping_indices),
|
||||
device="cuda",
|
||||
device=device,
|
||||
dtype=torch.long)
|
||||
prompt_mapping: List[int] = [
|
||||
lora_index_to_id.index(x) if x > 0 else -1
|
||||
@ -131,10 +132,10 @@ def convert_mapping(
|
||||
if long_lora_context:
|
||||
assert long_lora_offsets is not None
|
||||
indices_list.append(long_lora_offsets)
|
||||
indices = torch.tensor(indices_list, dtype=torch.long, device="cuda")
|
||||
indices = torch.tensor(indices_list, dtype=torch.long, device=device)
|
||||
prompt_mapping_tensor = torch.tensor(prompt_mapping,
|
||||
device="cuda",
|
||||
dtype=torch.long)
|
||||
dtype=torch.long,
|
||||
device=device)
|
||||
embeddings_indices = torch.stack([
|
||||
indices[2] * extra_vocab_size,
|
||||
indices[2] * (vocab_size + extra_vocab_size),
|
||||
@ -145,7 +146,7 @@ def convert_mapping(
|
||||
sampler_indices_padded = sampler_indices.clone()
|
||||
sampler_indices_padded[sampler_indices_padded == -1] = max_loras - 1
|
||||
sampler_indices_padded = torch.arange(
|
||||
0, len(sampler_indices_padded), device="cuda", dtype=torch.long) + (
|
||||
0, len(sampler_indices_padded), device=device, dtype=torch.long) + (
|
||||
sampler_indices_padded * len(sampler_indices_padded))
|
||||
long_lora_indices = None
|
||||
long_lora_indices_len: Optional[int] = None
|
||||
@ -183,7 +184,7 @@ class PunicaWrapper:
|
||||
"""
|
||||
|
||||
def __init__(self, max_num_batched_tokens: int, max_batches: int,
|
||||
device: str):
|
||||
device: Union[torch.device, str]):
|
||||
self._token_lora_indices = torch.empty(max_num_batched_tokens,
|
||||
dtype=torch.long,
|
||||
device=device)
|
||||
@ -215,6 +216,7 @@ class PunicaWrapper:
|
||||
self._lora_indices_per_batch = torch.empty(max_batches,
|
||||
dtype=torch.long,
|
||||
device=device)
|
||||
self.device: torch.device = device
|
||||
self.max_length: int = 0
|
||||
self.token_nums: int = 0
|
||||
self.batch_size: int = -1
|
||||
@ -263,6 +265,7 @@ class PunicaWrapper:
|
||||
max_loras,
|
||||
vocab_size,
|
||||
extra_vocab_size,
|
||||
self.device,
|
||||
long_lora_context,
|
||||
)
|
||||
self._token_lora_indices[:base_indices.shape[0]].copy_(base_indices)
|
||||
|
@ -73,6 +73,7 @@ class WorkerLoRAManager(AbstractWorkerManager):
|
||||
max_num_batched_tokens=self.max_num_batched_tokens,
|
||||
vocab_size=self.vocab_size,
|
||||
lora_config=self.lora_config,
|
||||
device=self.device,
|
||||
lora_manager_cls=self._manager_cls,
|
||||
)
|
||||
self._adapter_manager = lora_manager
|
||||
@ -176,6 +177,7 @@ class LRUCacheWorkerLoRAManager(WorkerLoRAManager):
|
||||
max_num_seqs=self.max_num_seqs,
|
||||
vocab_size=self.vocab_size,
|
||||
lora_config=self.lora_config,
|
||||
device=self.device,
|
||||
max_num_batched_tokens=self.max_num_batched_tokens,
|
||||
)
|
||||
self._adapter_manager = lora_manager
|
||||
|
Loading…
x
Reference in New Issue
Block a user