[Misc][LoRA] Replace hardcoded cuda device with configurable argument (#10223)

Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
Jee Jee Li 2024-11-12 11:10:15 +08:00 committed by GitHub
parent eea55cca5b
commit 7f5edb5900
6 changed files with 174 additions and 80 deletions

View File

@@ -51,6 +51,7 @@ TOLERANCES = {
 CUDA_DEVICES = [
     f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
 ]
 # We will launch different triton kernels between the prefill and decode
 # stages, so we need to verify this. prefill stage(True) or decode stage(False)
 STAGES = [True, False]
@@ -120,7 +121,8 @@ def populate_loras(
             subloras: List[LoRALayerWeights] = []
             sublora_len = layer_weights.shape[0] // repeats
             for i in range(repeats):
-                sublora = DummyLoRAManager().init_random_lora(
+                sublora = DummyLoRAManager(
+                    layer_weights.device).init_random_lora(
                     module_name=f"fake_{i}",
                     weight=layer_weights,
                     generate_embeddings_tensor=generate_embeddings_tensor,
@@ -152,6 +154,7 @@ def create_random_inputs(
     input_size: Tuple[int, ...],
     input_range: Tuple[float, float],
     input_type: torch.dtype = torch.int,
+    device: torch.device = "cuda"
 ) -> Tuple[List[torch.Tensor], List[int], List[int]]:
     """Creates random inputs.
@@ -173,10 +176,14 @@ def create_random_inputs(
     for _ in range(num_inputs):
         if input_type == torch.int:
             inputs.append(
-                torch.randint(low=int(low), high=int(high), size=input_size))
+                torch.randint(low=int(low),
+                              high=int(high),
+                              size=input_size,
+                              device=device))
         else:
             inputs.append(
-                torch.rand(size=input_size, dtype=input_type) * high + low)
+                torch.rand(size=input_size, dtype=input_type, device=device) *
+                high + low)
         lora_id = random.choice(active_lora_ids)
         index_mapping += [lora_id] * input_size[0]
@@ -191,6 +198,10 @@ def create_random_inputs(
 @pytest.mark.parametrize("vocab_size", [512, 32000, 64000, 128000])
 @pytest.mark.parametrize("stage", STAGES)
 def test_embeddings(dist_init, num_loras, device, vocab_size, stage) -> None:
+    # For multi-GPU testing of Triton kernel, we must explicitly set the CUDA
+    # device, see: https://github.com/triton-lang/triton/issues/2925
+    # Same below.
+    torch.cuda.set_device(device)
     torch.set_default_device(device)
     max_loras = 8
@@ -225,7 +236,7 @@ def test_embeddings(dist_init, num_loras, device, vocab_size, stage) -> None:
         num_inputs=num_loras * 3,
         input_size=(200, ),
         input_range=(1, vocab_size),
-    )
+        device=device)
     lora_mapping = LoRAMapping(index_mapping,
                                prompt_mapping,
                                is_prefill=stage)
@@ -263,7 +274,7 @@ def test_embeddings(dist_init, num_loras, device, vocab_size, stage) -> None:
         num_inputs=num_loras * 3,
         input_size=(200, ),
         input_range=(1, vocab_size),
-    )
+        device=device)
     lora_mapping = LoRAMapping(index_mapping,
                                prompt_mapping,
                                is_prefill=stage)
@@ -291,6 +302,7 @@ def test_embeddings(dist_init, num_loras, device, vocab_size, stage) -> None:
 def test_embeddings_with_new_embeddings(dist_init, num_loras, device,
                                         vocab_size, stage) -> None:
+    torch.cuda.set_device(device)
     torch.set_default_device(device)
     max_loras = 8
     punica_wrapper = PunicaWrapper(8192, 256, device)
@@ -345,7 +357,7 @@ def test_embeddings_with_new_embeddings(dist_init, num_loras, device,
         num_inputs=num_loras * 3,
         input_size=(200, ),
         input_range=(1, vocab_size),
-    )
+        device=device)
     lora_mapping = LoRAMapping(index_mapping,
                                prompt_mapping,
                                is_prefill=stage)
@@ -400,7 +412,7 @@ def test_embeddings_with_new_embeddings(dist_init, num_loras, device,
         num_inputs=num_loras * 3,
         input_size=(200, ),
         input_range=(1, vocab_size),
-    )
+        device=device)
     original_inputs = deepcopy(inputs)
     lora_mapping = LoRAMapping(index_mapping,
                                prompt_mapping,
@@ -426,6 +438,7 @@ def test_embeddings_with_new_embeddings(dist_init, num_loras, device,
 def test_lm_head_logits_processor(dist_init, num_loras, device, vocab_size,
                                   stage) -> None:
+    torch.cuda.set_device(device)
     torch.set_default_device(device)
     max_loras = 8
     punica_wrapper = PunicaWrapper(8192, 256, device)
@@ -471,7 +484,7 @@ def test_lm_head_logits_processor(dist_init, num_loras, device, vocab_size,
         input_size=(1, 1024),
         input_range=(0, 1),
         input_type=torch.float16,
-    )
+        device=device)
     lora_mapping = LoRAMapping(index_mapping,
                                prompt_mapping,
                                is_prefill=stage)
@@ -520,7 +533,7 @@ def test_lm_head_logits_processor(dist_init, num_loras, device, vocab_size,
         input_size=(1, 1024),
         input_range=(0, 1),
         input_type=torch.float16,
-    )
+        device=device)
     lora_mapping = LoRAMapping(index_mapping,
                                prompt_mapping,
                                is_prefill=stage)
@@ -554,6 +567,7 @@ def test_lm_head_logits_processor(dist_init, num_loras, device, vocab_size,
 @pytest.mark.parametrize("stage", STAGES)
 def test_linear_replicated(dist_init, num_loras, device, stage) -> None:
+    torch.cuda.set_device(device)
     torch.set_default_device(device)
     punica_wrapper = PunicaWrapper(8192, 256, device)
     max_loras = 8
@@ -592,7 +606,7 @@ def test_linear_replicated(dist_init, num_loras, device, stage) -> None:
         input_size=(1, 4096),
         input_range=(0, 1),
         input_type=torch.float16,
-    )
+        device=device)
     lora_mapping = LoRAMapping(index_mapping,
                                prompt_mapping,
                                is_prefill=stage)
@@ -631,7 +645,7 @@ def test_linear_replicated(dist_init, num_loras, device, stage) -> None:
         input_size=(1, 4096),
         input_range=(0, 1),
         input_type=torch.float16,
-    )
+        device=device)
     lora_mapping = LoRAMapping(index_mapping,
                                prompt_mapping,
                                is_prefill=stage)
@@ -658,6 +672,7 @@ def test_linear_replicated(dist_init, num_loras, device, stage) -> None:
 def test_linear_parallel(dist_init, num_loras, orientation, fully_shard,
                          device, stage) -> None:
+    torch.cuda.set_device(device)
     torch.set_default_device(device)
     punica_wrapper = PunicaWrapper(8192, 256, device)
     max_loras = 8
@@ -706,7 +721,7 @@ def test_linear_parallel(dist_init, num_loras, orientation, fully_shard,
         input_size=(1, 4096),
         input_range=(0, 1),
         input_type=torch.float16,
-    )
+        device=device)
     lora_mapping = LoRAMapping(index_mapping,
                                prompt_mapping,
                                is_prefill=stage)
@@ -745,7 +760,7 @@ def test_linear_parallel(dist_init, num_loras, orientation, fully_shard,
         input_size=(1, 4096),
         input_range=(0, 1),
         input_type=torch.float16,
-    )
+        device=device)
     lora_mapping = LoRAMapping(index_mapping,
                                prompt_mapping,
                                is_prefill=stage)
@@ -772,6 +787,7 @@ def test_linear_parallel(dist_init, num_loras, orientation, fully_shard,
 def test_column_parallel_packed(dist_init, num_loras, repeats, fully_shard,
                                 device, stage) -> None:
+    torch.cuda.set_device(device)
     torch.set_default_device(device)
     punica_wrapper = PunicaWrapper(8192, 256, device)
     max_loras = 8
@@ -842,7 +858,7 @@ def test_column_parallel_packed(dist_init, num_loras, repeats, fully_shard,
         input_size=(1, 4096),
         input_range=(0, 1),
         input_type=torch.float16,
-    )
+        device=device)
     lora_mapping = LoRAMapping(index_mapping,
                                prompt_mapping,
                                is_prefill=stage)
@@ -883,7 +899,7 @@ def test_column_parallel_packed(dist_init, num_loras, repeats, fully_shard,
         input_size=(1, 4096),
         input_range=(0, 1),
         input_type=torch.float16,
-    )
+        device=device)
     lora_mapping = LoRAMapping(index_mapping,
                                prompt_mapping,
                                is_prefill=stage)
@@ -962,7 +978,7 @@ def test_rotary_embedding_long_context(dist_init, num_loras, device,
         input_size=(1, max_position),
         input_range=(0, lora_config.lora_extra_vocab_size),
         input_type=torch.float16,
-    )
+        device=device)
     lora_mapping = LoRAMapping(index_mapping, prompt_mapping)
     long_lora_context = LongContextLoRAContext(list(scaling_factors),
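Note on the pattern added throughout this test file: pin the current CUDA device before any Triton kernel can launch (torch.set_default_device alone is not enough, see the triton issue linked above), then build every test tensor on that device through the new device= argument rather than a hardcoded "cuda". Below is a minimal, self-contained sketch of that pattern in plain PyTorch; the helper names are illustrative, not the test's own.

import torch

def setup_device(device: str) -> None:
    # torch.set_default_device() controls where new tensors land, but Triton
    # launches kernels on the *current* CUDA device, so multi-GPU runs also
    # pin it explicitly.
    torch.cuda.set_device(device)
    torch.set_default_device(device)

def make_test_inputs(device: str, vocab_size: int = 512):
    # Mirrors create_random_inputs(..., device=device): tensors are created
    # directly on the requested device instead of a hardcoded "cuda".
    token_ids = torch.randint(low=1, high=vocab_size, size=(200, ),
                              device=device)
    hidden = torch.rand(size=(1, 1024), dtype=torch.float16, device=device)
    return token_ids, hidden

if torch.cuda.is_available():
    dev = f"cuda:{torch.cuda.device_count() - 1}"  # last visible GPU
    setup_device(dev)
    token_ids, hidden = make_test_inputs(dev)
    assert token_ids.device == torch.device(dev)
    assert hidden.device == torch.device(dev)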

View File

@@ -25,8 +25,13 @@ EMBEDDING_MODULES = {
 EMBEDDING_PADDING_MODULES = ["lm_head"]

+CUDA_DEVICES = [
+    f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
+]

-def test_from_lora_tensors(sql_lora_files):
+@pytest.mark.parametrize("device", CUDA_DEVICES)
+def test_from_lora_tensors(sql_lora_files, device):
     tensors = load_file(
         os.path.join(sql_lora_files, "adapter_model.safetensors"))
     new_embeddings = load_file(
@@ -36,7 +41,7 @@ def test_from_lora_tensors(sql_lora_files):
         8,
         16,
         tensors,
-        "cuda",
+        device,
         embeddings=new_embeddings,
         embedding_modules=EMBEDDING_MODULES,
         embedding_padding_modules=EMBEDDING_PADDING_MODULES)
@@ -46,6 +51,8 @@ def test_from_lora_tensors(sql_lora_files):
     assert lora.lora_alpha == 16
     assert lora.lora_a is not None
     assert lora.lora_b is not None
+    assert lora.lora_a.device == torch.device(device)
+    assert lora.lora_b.device == torch.device(device)
     assert (lora.lora_a.shape[1] == lora.lora_b.shape[0]
             ), f"{lora.lora_a.shape=}, {lora.lora_b.shape=}"
     assert lora.lora_a.shape[1] == 8
@@ -60,8 +67,8 @@
         assert lora.embeddings_tensor is None

-def create_lora(lora_id: int, model: nn.Module,
-                sub_modules: List[str]) -> LoRAModel:
+def create_lora(lora_id: int, model: nn.Module, sub_modules: List[str],
+                device: torch.device) -> LoRAModel:
     loras: Dict[str, LoRALayerWeights] = {}
     for name in sub_modules:
         w = model.get_submodule(name).weight
@@ -69,8 +76,8 @@ def create_lora(lora_id: int, model: nn.Module,
             name,
             8,
             16,
-            torch.rand([w.shape[1], 8], device="cuda"),
-            torch.rand([8, w.shape[0]], device="cuda"),
+            torch.rand([w.shape[1], 8], device=device),
+            torch.rand([8, w.shape[0]], device=device),
         )
     return LoRAModel(lora_id, 8, loras)
@@ -80,6 +87,7 @@ def create_packed_lora(
     model: nn.Module,
     module_name,
     replaced_module_names,
+    device: torch.device,
     empty_replaced_module_name=None,
 ) -> LoRAModel:
     w = model.get_submodule(module_name).weight
@@ -91,9 +99,9 @@
             replaced_module_name,
             8,
             16,
-            torch.rand([w.shape[1], 8], device="cuda"),
+            torch.rand([w.shape[1], 8], device=device),
             torch.rand([8, w.shape[0] // len(replaced_module_names)],
-                       device="cuda"),
+                       device=device),
         )
     return LoRAModel(lora_id, 8, loras)
@@ -104,7 +112,8 @@ def test_replace_submodules(dist_init, dummy_model):
     model.packed_modules_mapping = {}
     manager = LoRAModelManager(
         model, 1, 1, 1,
-        LoRAConfig(max_lora_rank=8, max_cpu_loras=8, max_loras=8))
+        LoRAConfig(max_lora_rank=8, max_cpu_loras=8, max_loras=8),
+        torch.device("cuda"))
     model = manager.model

     assert isinstance(model.get_submodule("dense1"),
@@ -116,16 +125,28 @@
                       RowParallelLinearWithLoRA)


-def test_lora_model_manager(dist_init, dummy_model):
+@pytest.mark.parametrize("device", CUDA_DEVICES)
+def test_lora_model_manager(dist_init, dummy_model, device):
     model = dummy_model
     model.supported_lora_modules = ["dense1", "dense2", "lm_head"]
     model.packed_modules_mapping = {}
-    model_lora1 = create_lora(1, model, ["layer1.dense1", "dense2", "lm_head"])
-    model_lora2 = create_lora(2, model, ["dense1", "dense2", "lm_head"])
-    model_lora3 = create_lora(3, model, ["dense1", "dense2", "lm_head"])
-    manager = LoRAModelManager(
-        model, 2, 2, 2,
-        LoRAConfig(max_lora_rank=8, max_cpu_loras=3, max_loras=2))
+    model_lora1 = create_lora(1,
+                              model, ["layer1.dense1", "dense2", "lm_head"],
+                              device=device)
+    model_lora2 = create_lora(2,
+                              model, ["dense1", "dense2", "lm_head"],
+                              device=device)
+    model_lora3 = create_lora(3,
+                              model, ["dense1", "dense2", "lm_head"],
+                              device=device)
+    manager = LoRAModelManager(model,
+                               2,
+                               2,
+                               2,
+                               LoRAConfig(max_lora_rank=8,
+                                          max_cpu_loras=3,
+                                          max_loras=2),
+                               device=device)
     assert all(x is None for x in manager.lora_index_to_id)
     assert manager.add_adapter(model_lora1)
     assert manager.activate_adapter(1)
@@ -161,17 +182,32 @@ def test_lora_model_manager(dist_init, dummy_model):
     assert manager.lora_index_to_id[0] == 3
     assert manager.lora_index_to_id[1] == 2
+    assert manager.device == device
+    assert manager.punica_wrapper.device == device


-def test_lora_lru_cache_model_manager(dist_init, dummy_model):
+@pytest.mark.parametrize("device", CUDA_DEVICES)
+def test_lora_lru_cache_model_manager(dist_init, dummy_model, device):
     model = dummy_model
     model.supported_lora_modules = ["dense1", "dense2", "lm_head"]
     model.packed_modules_mapping = {}
-    model_lora1 = create_lora(1, model, ["layer1.dense1", "dense2", "lm_head"])
-    model_lora2 = create_lora(2, model, ["dense1", "dense2", "lm_head"])
-    model_lora3 = create_lora(3, model, ["dense1", "dense2", "lm_head"])
-    manager = LRUCacheLoRAModelManager(
-        model, 2, 2, 2,
-        LoRAConfig(max_lora_rank=8, max_cpu_loras=3, max_loras=2))
+    model_lora1 = create_lora(1,
+                              model, ["layer1.dense1", "dense2", "lm_head"],
+                              device=device)
+    model_lora2 = create_lora(2,
+                              model, ["dense1", "dense2", "lm_head"],
+                              device=device)
+    model_lora3 = create_lora(3,
+                              model, ["dense1", "dense2", "lm_head"],
+                              device=device)
+    manager = LRUCacheLoRAModelManager(model,
+                                       2,
+                                       2,
+                                       2,
+                                       LoRAConfig(max_lora_rank=8,
+                                                  max_cpu_loras=3,
+                                                  max_loras=2),
+                                       device=device)
     assert all(x is None for x in manager.lora_index_to_id)
     assert manager.add_adapter(model_lora1)
     assert manager.activate_adapter(1)
@@ -238,20 +274,37 @@ def test_lora_lru_cache_model_manager(dist_init, dummy_model):
     with pytest.raises(ValueError):
         assert manager.pin_adapter(3)
+    assert manager.punica_wrapper.device == device
+    assert manager.device == device


-def test_lru_lora_model_manager(dist_init, dummy_model):
+@pytest.mark.parametrize("device", CUDA_DEVICES)
+def test_lru_lora_model_manager(dist_init, dummy_model, device):
     # This tests just the LRU cache functionality, everything else is
     # tested in test_lora_model_manager
     model = dummy_model
     model.supported_lora_modules = ["dense1", "dense2", "lm_head"]
     model.packed_modules_mapping = {}
-    model_lora1 = create_lora(1, model, ["layer1.dense1", "dense2", "lm_head"])
-    model_lora2 = create_lora(2, model, ["dense1", "dense2", "lm_head"])
-    model_lora3 = create_lora(3, model, ["dense1", "dense2", "lm_head"])
-    model_lora4 = create_lora(4, model, ["dense1", "dense2", "lm_head"])
-    manager = LRUCacheLoRAModelManager(
-        model, 2, 2, 2,
-        LoRAConfig(max_lora_rank=8, max_cpu_loras=2, max_loras=2))
+    model_lora1 = create_lora(1,
+                              model, ["layer1.dense1", "dense2", "lm_head"],
+                              device=device)
+    model_lora2 = create_lora(2,
+                              model, ["dense1", "dense2", "lm_head"],
+                              device=device)
+    model_lora3 = create_lora(3,
+                              model, ["dense1", "dense2", "lm_head"],
+                              device=device)
+    model_lora4 = create_lora(4,
+                              model, ["dense1", "dense2", "lm_head"],
+                              device=device)
+    manager = LRUCacheLoRAModelManager(model,
+                                       2,
+                                       2,
+                                       2,
+                                       LoRAConfig(max_lora_rank=8,
+                                                  max_cpu_loras=2,
+                                                  max_loras=2),
+                                       device=device)
     assert all(x is None for x in manager.lora_index_to_id)
@@ -351,14 +404,17 @@ def test_lru_lora_model_manager(dist_init, dummy_model):
     assert manager.remove_oldest_adapter()
     assert set(manager.list_adapters()) == {1}
+    assert manager.punica_wrapper.device == device
+    assert manager.device == device


+@pytest.mark.parametrize("device", CUDA_DEVICES)
 def test_lru_cache_worker_adapter_manager(llama_2_7b_model_extra_embeddings,
-                                          sql_lora_files):
+                                          sql_lora_files, device):
     lora_config = LoRAConfig(max_lora_rank=8, max_cpu_loras=4, max_loras=4)
     worker_adapter_manager = LRUCacheWorkerLoRAManager(
         4, 2, llama_2_7b_model_extra_embeddings.unpadded_vocab_size -
-        lora_config.lora_extra_vocab_size, lora_config, torch.device("cuda"),
+        lora_config.lora_extra_vocab_size, lora_config, device,
         EMBEDDING_MODULES, EMBEDDING_PADDING_MODULES)
     worker_adapter_manager.create_lora_manager(
         llama_2_7b_model_extra_embeddings)
@@ -426,14 +482,19 @@ def test_lru_cache_worker_adapter_manager(llama_2_7b_model_extra_embeddings,
         LoRARequest("14", 14, sql_lora_files)
     ], mapping)
+    assert worker_adapter_manager.device == device
+    assert (worker_adapter_manager._adapter_manager.punica_wrapper.device ==
+            device)


+@pytest.mark.parametrize("device", CUDA_DEVICES)
 def test_worker_adapter_manager(llama_2_7b_model_extra_embeddings,
-                                sql_lora_files):
+                                sql_lora_files, device):
     # Should remove every LoRA not specified in the request.
     lora_config = LoRAConfig(max_lora_rank=8, max_cpu_loras=4, max_loras=4)
     worker_adapter_manager = WorkerLoRAManager(
         4, 2, llama_2_7b_model_extra_embeddings.unpadded_vocab_size -
-        lora_config.lora_extra_vocab_size, lora_config, torch.device("cuda"),
+        lora_config.lora_extra_vocab_size, lora_config, device,
         EMBEDDING_MODULES, EMBEDDING_PADDING_MODULES)
     worker_adapter_manager.create_lora_manager(
         llama_2_7b_model_extra_embeddings)
@@ -497,8 +558,13 @@ def test_worker_adapter_manager(llama_2_7b_model_extra_embeddings,
         LoRARequest("14", 14, sql_lora_files)
     ], mapping)
+    assert worker_adapter_manager.device == device
+    assert (worker_adapter_manager._adapter_manager.punica_wrapper.device ==
+            device)


-def test_packed_loras(dist_init, dummy_model_gate_up):
+@pytest.mark.parametrize("device", CUDA_DEVICES)
+def test_packed_loras(dist_init, dummy_model_gate_up, device):
     model = dummy_model_gate_up
     model.supported_lora_modules = ["gate_up_proj"]
     model.packed_modules_mapping = {
@@ -511,18 +577,25 @@ def test_packed_loras(dist_init, dummy_model_gate_up):
         1,
         model,
         module_name="gate_up_proj",
-        replaced_module_names=["gate_proj", "up_proj"])
+        replaced_module_names=["gate_proj", "up_proj"],
+        device=device)
     model_lora1 = create_packed_lora(
         2,
         model,
         module_name="gate_up_proj",
         replaced_module_names=["gate_proj", "up_proj"],
+        device=device,
         empty_replaced_module_name="gate_proj",
     )
-    manager = LoRAModelManager(
-        model, 2, 2, 2,
-        LoRAConfig(max_lora_rank=8, max_cpu_loras=2, max_loras=2))
+    manager = LoRAModelManager(model,
+                               2,
+                               2,
+                               2,
+                               LoRAConfig(max_lora_rank=8,
+                                          max_cpu_loras=2,
+                                          max_loras=2),
+                               device=device)
     model = manager.model

     assert isinstance(model.get_submodule("gate_up_proj"),
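What the new assertions in this file check is plain dependency injection: a manager constructed for a given device must report that device and must hand it to the PunicaWrapper it owns, so no index buffer silently lands on cuda:0. A self-contained stand-in for that invariant (illustrative classes, not the vLLM ones):

import torch
from typing import Union

DeviceLike = Union[torch.device, str]

class FakePunicaWrapper:
    def __init__(self, max_tokens: int, max_batches: int, device: DeviceLike):
        self.device = device
        # Bookkeeping buffers live on the injected device, never a default.
        self._token_indices = torch.empty(max_tokens,
                                          dtype=torch.long,
                                          device=device)
        self._batch_indices = torch.empty(max_batches,
                                          dtype=torch.long,
                                          device=device)

class FakeLoRAManager:
    def __init__(self, max_tokens: int, max_batches: int, device: DeviceLike):
        self.device = device
        self.punica_wrapper = FakePunicaWrapper(max_tokens, max_batches,
                                                device)

# The tests above assert exactly this pair of properties for every device in
# CUDA_DEVICES; "cpu" is used here only so the sketch runs anywhere.
dev = "cuda:0" if torch.cuda.is_available() else "cpu"
manager = FakeLoRAManager(max_tokens=8192, max_batches=256, device=dev)
assert manager.device == dev
assert manager.punica_wrapper.device == dev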

View File

@@ -7,9 +7,10 @@ from vllm.lora.lora import LoRALayerWeights, PackedLoRALayerWeights
 class DummyLoRAManager:

-    def __init__(self):
+    def __init__(self, device: torch.device = "cuda:0"):
         super().__init__()
         self._loras: Dict[str, LoRALayerWeights] = {}
+        self._device = device

     def set_module_lora(self, module_name: str, lora: LoRALayerWeights):
         self._loras[module_name] = lora
@@ -28,16 +29,16 @@ class DummyLoRAManager:
             lora_alpha=1,
             lora_a=torch.rand([weight.shape[1], rank],
                               dtype=weight.dtype,
-                              device="cuda"),
+                              device=self._device),
             lora_b=torch.rand([rank, weight.shape[0]],
                               dtype=weight.dtype,
-                              device="cuda"),
+                              device=self._device),
         )
         if generate_embeddings_tensor:
             lora.embeddings_tensor = torch.rand(5,
                                                 generate_embeddings_tensor,
                                                 dtype=weight.dtype,
-                                                device="cuda")
+                                                device=self._device)
         self.set_module_lora(module_name, lora)
         return lora
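DummyLoRAManager now takes the device up front (defaulting to "cuda:0", which preserves the old single-GPU behaviour), and populate_loras passes layer_weights.device so the random LoRA tensors always live next to the weight they wrap. A small plain-PyTorch sketch of that follow-the-weight idea; the helper below is illustrative, not the test utility itself:

import torch

def make_random_lora(weight: torch.Tensor, rank: int = 8):
    # Follow the weight: same dtype and same device, no hardcoded "cuda".
    device = weight.device
    lora_a = torch.rand([weight.shape[1], rank],
                        dtype=weight.dtype,
                        device=device)
    lora_b = torch.rand([rank, weight.shape[0]],
                        dtype=weight.dtype,
                        device=device)
    return lora_a, lora_b

w = torch.rand(32, 16)  # lives on CPU here; on a GPU the LoRA follows it
a, b = make_random_lora(w)
assert a.device == w.device and b.device == w.device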

View File

@@ -301,6 +301,7 @@ class LoRAModelManager(AdapterModelManager):
         max_num_batched_tokens: int,
         vocab_size: int,
         lora_config: LoRAConfig,
+        device: torch.device,
     ):
         """Create a LoRAModelManager and adapter for a given model.
@@ -314,6 +315,7 @@
             lora_config: the LoRA configuration.
         """
         self.lora_config = lora_config
+        self.device = device
         self.max_num_seqs = max_num_seqs
         assert self.capacity >= self.lora_slots
         self.max_num_batched_tokens = math.ceil(max_num_batched_tokens / 8) * 8
@@ -322,7 +324,7 @@
         self.long_lora_context: Optional[LongContextLoRAContext] = None
         self.punica_wrapper = PunicaWrapper(max_num_batched_tokens,
                                             max_batches=self.max_num_seqs,
-                                            device="cuda")
+                                            device=self.device)
         # Scaling factor -> offset to the sin_cos_cache to it.
         # Used for long context lora.
         self.scaling_factor_to_offset: Dict[float, int] = {}
@@ -653,16 +655,11 @@ class LoRALRUCache(AdapterLRUCache[LoRAModel]):
 class LRUCacheLoRAModelManager(LoRAModelManager):
     """A model manager that manages multiple LoRAs with LRU cache."""

-    def __init__(
-        self,
-        model: nn.Module,
-        max_num_seqs: int,
-        max_num_batched_tokens: int,
-        vocab_size: int,
-        lora_config: LoRAConfig,
-    ):
+    def __init__(self, model: nn.Module, max_num_seqs: int,
+                 max_num_batched_tokens: int, vocab_size: int,
+                 lora_config: LoRAConfig, device: torch.device):
         super().__init__(model, max_num_seqs, max_num_batched_tokens,
-                         vocab_size, lora_config)
+                         vocab_size, lora_config, device)
         self._registered_adapters: LoRALRUCache = LoRALRUCache(
             self.capacity, self.deactivate_adapter)
         self._active_adapters: LoRALRUCache = LoRALRUCache(
@@ -732,6 +729,7 @@ def create_lora_manager(
     max_num_batched_tokens: int,
     vocab_size: int,
     lora_config: LoRAConfig,
+    device: torch.device,
     lora_manager_cls: Type[LoRAModelManager] = LoRAModelManager,
     **kwargs) -> LoRAModelManager:
     """Create a LoRA adapter for a given model."""
@@ -743,5 +741,6 @@ def create_lora_manager(
         max_num_batched_tokens=max_num_batched_tokens,
         vocab_size=vocab_size,
         lora_config=lora_config,
+        device=device,
         **kwargs)
     return lora_manager
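create_lora_manager gains a required device argument and forwards it to whichever manager class it instantiates, while LRUCacheLoRAModelManager simply threads it through to the base class. A minimal stand-in for that factory shape (illustrative classes, not the vLLM ones):

import torch
from typing import Type

class Manager:
    def __init__(self, capacity: int, device: torch.device):
        self.capacity = capacity
        self.device = device

class LRUManager(Manager):
    """Same constructor contract; an eviction policy would live elsewhere."""

def create_manager(capacity: int,
                   device: torch.device,
                   manager_cls: Type[Manager] = Manager,
                   **kwargs) -> Manager:
    # The factory threads the device through instead of letting the manager
    # fall back to a hardcoded default.
    return manager_cls(capacity=capacity, device=device, **kwargs)

mgr = create_manager(8, device=torch.device("cpu"), manager_cls=LRUManager)
assert type(mgr) is LRUManager and mgr.device.type == "cpu"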

View File

@@ -62,6 +62,7 @@ def convert_mapping(
     max_loras: int,
     vocab_size: int,
     extra_vocab_size: int,
+    device: torch.device,
     long_lora_context: Optional["LongContextLoRAContext"] = None,
 ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor,
            Optional[torch.Tensor], List[int]]:
@@ -104,7 +105,7 @@
     long_lora_offsets: Optional[torch.Tensor] = None
     if long_lora_context:
         long_lora_offsets = torch.zeros(len(index_mapping_indices),
-                                        device="cuda",
+                                        device=device,
                                         dtype=torch.long)
     prompt_mapping: List[int] = [
         lora_index_to_id.index(x) if x > 0 else -1
@@ -131,10 +132,10 @@
     if long_lora_context:
         assert long_lora_offsets is not None
         indices_list.append(long_lora_offsets)
-    indices = torch.tensor(indices_list, dtype=torch.long, device="cuda")
+    indices = torch.tensor(indices_list, dtype=torch.long, device=device)
     prompt_mapping_tensor = torch.tensor(prompt_mapping,
-                                         device="cuda",
-                                         dtype=torch.long)
+                                         dtype=torch.long,
+                                         device=device)
     embeddings_indices = torch.stack([
         indices[2] * extra_vocab_size,
         indices[2] * (vocab_size + extra_vocab_size),
@@ -145,7 +146,7 @@
     sampler_indices_padded = sampler_indices.clone()
     sampler_indices_padded[sampler_indices_padded == -1] = max_loras - 1
     sampler_indices_padded = torch.arange(
-        0, len(sampler_indices_padded), device="cuda", dtype=torch.long) + (
+        0, len(sampler_indices_padded), device=device, dtype=torch.long) + (
             sampler_indices_padded * len(sampler_indices_padded))
     long_lora_indices = None
     long_lora_indices_len: Optional[int] = None
@@ -183,7 +184,7 @@ class PunicaWrapper:
     """

     def __init__(self, max_num_batched_tokens: int, max_batches: int,
-                 device: str):
+                 device: Union[torch.device, str]):
         self._token_lora_indices = torch.empty(max_num_batched_tokens,
                                                dtype=torch.long,
                                                device=device)
@@ -215,6 +216,7 @@
         self._lora_indices_per_batch = torch.empty(max_batches,
                                                    dtype=torch.long,
                                                    device=device)
+        self.device: torch.device = device
         self.max_length: int = 0
         self.token_nums: int = 0
         self.batch_size: int = -1
@@ -263,6 +265,7 @@
             max_loras,
             vocab_size,
             extra_vocab_size,
+            self.device,
             long_lora_context,
         )
         self._token_lora_indices[:base_indices.shape[0]].copy_(base_indices)
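For callers, the visible contract in this file is that PunicaWrapper now accepts Union[torch.device, str], records it as .device, and builds every mapping tensor (including those produced by convert_mapping) on that device. A short usage sketch; the import path is an assumption since file names are not shown in this view, and it needs a CUDA-capable install:

import torch
from vllm.lora.punica import PunicaWrapper  # assumed module path

if torch.cuda.is_available():
    # Pick the last visible GPU rather than assuming "cuda"/"cuda:0".
    device = f"cuda:{torch.cuda.device_count() - 1}"
    wrapper = PunicaWrapper(max_num_batched_tokens=8192,
                            max_batches=256,
                            device=device)
    # The wrapper remembers its device; index buffers and the tensors built
    # during mapping updates are placed there instead of a hardcoded "cuda".
    assert wrapper.device == device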

View File

@@ -73,6 +73,7 @@ class WorkerLoRAManager(AbstractWorkerManager):
             max_num_batched_tokens=self.max_num_batched_tokens,
             vocab_size=self.vocab_size,
             lora_config=self.lora_config,
+            device=self.device,
             lora_manager_cls=self._manager_cls,
         )
         self._adapter_manager = lora_manager
@@ -176,6 +177,7 @@ class LRUCacheWorkerLoRAManager(WorkerLoRAManager):
             max_num_seqs=self.max_num_seqs,
             vocab_size=self.vocab_size,
             lora_config=self.lora_config,
+            device=self.device,
             max_num_batched_tokens=self.max_num_batched_tokens,
         )
         self._adapter_manager = lora_manager