[Misc][LoRA] Replace hardcoded cuda device with configurable argument (#10223)
Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
parent eea55cca5b
commit 7f5edb5900
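The change threads an explicit `device` argument from the worker manager through `LoRAModelManager` and `PunicaWrapper` down to the LoRA test helpers, instead of hardcoding `"cuda"` at each tensor allocation. As a minimal illustration of the pattern only (the helper below is hypothetical and not code from this commit):

    import torch

    # Hypothetical helper mirroring the commit's pattern: allocate LoRA weights
    # on an explicitly passed device rather than a hardcoded "cuda".
    def make_lora_weights(in_dim: int, out_dim: int, rank: int,
                          device: torch.device = "cuda:0"):
        lora_a = torch.rand([in_dim, rank], device=device)
        lora_b = torch.rand([rank, out_dim], device=device)
        return lora_a, lora_b

    # On a multi-GPU host the same code can target cuda:1 as easily as cuda:0:
    # lora_a, lora_b = make_lora_weights(4096, 4096, 8, device="cuda:1")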
@@ -51,6 +51,7 @@ TOLERANCES = {
 CUDA_DEVICES = [
     f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
 ]
+
 # We will launch different triton kernels between the prefill and decode
 # stages, so we need to verify this. prefill stage(True) or decode stage(False)
 STAGES = [True, False]
@@ -120,7 +121,8 @@ def populate_loras(
             subloras: List[LoRALayerWeights] = []
             sublora_len = layer_weights.shape[0] // repeats
             for i in range(repeats):
-                sublora = DummyLoRAManager().init_random_lora(
+                sublora = DummyLoRAManager(
+                    layer_weights.device).init_random_lora(
                     module_name=f"fake_{i}",
                     weight=layer_weights,
                     generate_embeddings_tensor=generate_embeddings_tensor,
@@ -152,6 +154,7 @@ def create_random_inputs(
     input_size: Tuple[int, ...],
     input_range: Tuple[float, float],
     input_type: torch.dtype = torch.int,
+    device: torch.device = "cuda"
 ) -> Tuple[List[torch.Tensor], List[int], List[int]]:
     """Creates random inputs.

@@ -173,10 +176,14 @@ def create_random_inputs(
     for _ in range(num_inputs):
         if input_type == torch.int:
             inputs.append(
-                torch.randint(low=int(low), high=int(high), size=input_size))
+                torch.randint(low=int(low),
+                              high=int(high),
+                              size=input_size,
+                              device=device))
         else:
             inputs.append(
-                torch.rand(size=input_size, dtype=input_type) * high + low)
+                torch.rand(size=input_size, dtype=input_type, device=device) *
+                high + low)

         lora_id = random.choice(active_lora_ids)
         index_mapping += [lora_id] * input_size[0]
@@ -191,6 +198,10 @@ def create_random_inputs(
 @pytest.mark.parametrize("vocab_size", [512, 32000, 64000, 128000])
 @pytest.mark.parametrize("stage", STAGES)
 def test_embeddings(dist_init, num_loras, device, vocab_size, stage) -> None:
+    # For multi-GPU testing of Triton kernel, we must explicitly set the CUDA
+    # device, see: https://github.com/triton-lang/triton/issues/2925
+    # Same below.
+    torch.cuda.set_device(device)

     torch.set_default_device(device)
     max_loras = 8
@@ -225,7 +236,7 @@ def test_embeddings(dist_init, num_loras, device, vocab_size, stage) -> None:
             num_inputs=num_loras * 3,
             input_size=(200, ),
             input_range=(1, vocab_size),
-        )
+            device=device)
         lora_mapping = LoRAMapping(index_mapping,
                                    prompt_mapping,
                                    is_prefill=stage)
@@ -263,7 +274,7 @@ def test_embeddings(dist_init, num_loras, device, vocab_size, stage) -> None:
             num_inputs=num_loras * 3,
             input_size=(200, ),
             input_range=(1, vocab_size),
-        )
+            device=device)
         lora_mapping = LoRAMapping(index_mapping,
                                    prompt_mapping,
                                    is_prefill=stage)
@@ -291,6 +302,7 @@ def test_embeddings(dist_init, num_loras, device, vocab_size, stage) -> None:
 def test_embeddings_with_new_embeddings(dist_init, num_loras, device,
                                         vocab_size, stage) -> None:

+    torch.cuda.set_device(device)
     torch.set_default_device(device)
     max_loras = 8
     punica_wrapper = PunicaWrapper(8192, 256, device)
@@ -345,7 +357,7 @@ def test_embeddings_with_new_embeddings(dist_init, num_loras, device,
             num_inputs=num_loras * 3,
             input_size=(200, ),
             input_range=(1, vocab_size),
-        )
+            device=device)
         lora_mapping = LoRAMapping(index_mapping,
                                    prompt_mapping,
                                    is_prefill=stage)
@@ -400,7 +412,7 @@ def test_embeddings_with_new_embeddings(dist_init, num_loras, device,
             num_inputs=num_loras * 3,
             input_size=(200, ),
             input_range=(1, vocab_size),
-        )
+            device=device)
         original_inputs = deepcopy(inputs)
         lora_mapping = LoRAMapping(index_mapping,
                                    prompt_mapping,
@@ -426,6 +438,7 @@ def test_embeddings_with_new_embeddings(dist_init, num_loras, device,
 def test_lm_head_logits_processor(dist_init, num_loras, device, vocab_size,
                                   stage) -> None:

+    torch.cuda.set_device(device)
     torch.set_default_device(device)
     max_loras = 8
     punica_wrapper = PunicaWrapper(8192, 256, device)
@@ -471,7 +484,7 @@ def test_lm_head_logits_processor(dist_init, num_loras, device, vocab_size,
             input_size=(1, 1024),
             input_range=(0, 1),
             input_type=torch.float16,
-        )
+            device=device)
         lora_mapping = LoRAMapping(index_mapping,
                                    prompt_mapping,
                                    is_prefill=stage)
@@ -520,7 +533,7 @@ def test_lm_head_logits_processor(dist_init, num_loras, device, vocab_size,
             input_size=(1, 1024),
             input_range=(0, 1),
             input_type=torch.float16,
-        )
+            device=device)
         lora_mapping = LoRAMapping(index_mapping,
                                    prompt_mapping,
                                    is_prefill=stage)
@@ -554,6 +567,7 @@ def test_lm_head_logits_processor(dist_init, num_loras, device, vocab_size,
 @pytest.mark.parametrize("stage", STAGES)
 def test_linear_replicated(dist_init, num_loras, device, stage) -> None:

+    torch.cuda.set_device(device)
     torch.set_default_device(device)
     punica_wrapper = PunicaWrapper(8192, 256, device)
     max_loras = 8
@@ -592,7 +606,7 @@ def test_linear_replicated(dist_init, num_loras, device, stage) -> None:
             input_size=(1, 4096),
             input_range=(0, 1),
             input_type=torch.float16,
-        )
+            device=device)
         lora_mapping = LoRAMapping(index_mapping,
                                    prompt_mapping,
                                    is_prefill=stage)
@@ -631,7 +645,7 @@ def test_linear_replicated(dist_init, num_loras, device, stage) -> None:
             input_size=(1, 4096),
             input_range=(0, 1),
             input_type=torch.float16,
-        )
+            device=device)
         lora_mapping = LoRAMapping(index_mapping,
                                    prompt_mapping,
                                    is_prefill=stage)
@@ -658,6 +672,7 @@ def test_linear_replicated(dist_init, num_loras, device, stage) -> None:
 def test_linear_parallel(dist_init, num_loras, orientation, fully_shard,
                          device, stage) -> None:

+    torch.cuda.set_device(device)
     torch.set_default_device(device)
     punica_wrapper = PunicaWrapper(8192, 256, device)
     max_loras = 8
@@ -706,7 +721,7 @@ def test_linear_parallel(dist_init, num_loras, orientation, fully_shard,
             input_size=(1, 4096),
             input_range=(0, 1),
             input_type=torch.float16,
-        )
+            device=device)
         lora_mapping = LoRAMapping(index_mapping,
                                    prompt_mapping,
                                    is_prefill=stage)
@@ -745,7 +760,7 @@ def test_linear_parallel(dist_init, num_loras, orientation, fully_shard,
             input_size=(1, 4096),
             input_range=(0, 1),
             input_type=torch.float16,
-        )
+            device=device)
         lora_mapping = LoRAMapping(index_mapping,
                                    prompt_mapping,
                                    is_prefill=stage)
@@ -772,6 +787,7 @@ def test_linear_parallel(dist_init, num_loras, orientation, fully_shard,
 def test_column_parallel_packed(dist_init, num_loras, repeats, fully_shard,
                                 device, stage) -> None:

+    torch.cuda.set_device(device)
     torch.set_default_device(device)
     punica_wrapper = PunicaWrapper(8192, 256, device)
     max_loras = 8
@@ -842,7 +858,7 @@ def test_column_parallel_packed(dist_init, num_loras, repeats, fully_shard,
             input_size=(1, 4096),
             input_range=(0, 1),
             input_type=torch.float16,
-        )
+            device=device)
         lora_mapping = LoRAMapping(index_mapping,
                                    prompt_mapping,
                                    is_prefill=stage)
@@ -883,7 +899,7 @@ def test_column_parallel_packed(dist_init, num_loras, repeats, fully_shard,
             input_size=(1, 4096),
             input_range=(0, 1),
             input_type=torch.float16,
-        )
+            device=device)
         lora_mapping = LoRAMapping(index_mapping,
                                    prompt_mapping,
                                    is_prefill=stage)
@@ -962,7 +978,7 @@ def test_rotary_embedding_long_context(dist_init, num_loras, device,
         input_size=(1, max_position),
         input_range=(0, lora_config.lora_extra_vocab_size),
         input_type=torch.float16,
-    )
+        device=device)

     lora_mapping = LoRAMapping(index_mapping, prompt_mapping)
     long_lora_context = LongContextLoRAContext(list(scaling_factors),
@@ -25,8 +25,13 @@ EMBEDDING_MODULES = {

 EMBEDDING_PADDING_MODULES = ["lm_head"]

+CUDA_DEVICES = [
+    f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
+]

-def test_from_lora_tensors(sql_lora_files):
+
+@pytest.mark.parametrize("device", CUDA_DEVICES)
+def test_from_lora_tensors(sql_lora_files, device):
     tensors = load_file(
         os.path.join(sql_lora_files, "adapter_model.safetensors"))
     new_embeddings = load_file(
@@ -36,7 +41,7 @@ def test_from_lora_tensors(sql_lora_files):
         8,
         16,
         tensors,
-        "cuda",
+        device,
         embeddings=new_embeddings,
         embedding_modules=EMBEDDING_MODULES,
         embedding_padding_modules=EMBEDDING_PADDING_MODULES)
@@ -46,6 +51,8 @@ def test_from_lora_tensors(sql_lora_files):
         assert lora.lora_alpha == 16
         assert lora.lora_a is not None
         assert lora.lora_b is not None
+        assert lora.lora_a.device == torch.device(device)
+        assert lora.lora_b.device == torch.device(device)
         assert (lora.lora_a.shape[1] == lora.lora_b.shape[0]
                 ), f"{lora.lora_a.shape=}, {lora.lora_b.shape=}"
         assert lora.lora_a.shape[1] == 8
@@ -60,8 +67,8 @@ def test_from_lora_tensors(sql_lora_files):
             assert lora.embeddings_tensor is None


-def create_lora(lora_id: int, model: nn.Module,
-                sub_modules: List[str]) -> LoRAModel:
+def create_lora(lora_id: int, model: nn.Module, sub_modules: List[str],
+                device: torch.device) -> LoRAModel:
     loras: Dict[str, LoRALayerWeights] = {}
     for name in sub_modules:
         w = model.get_submodule(name).weight
@@ -69,8 +76,8 @@ def create_lora(lora_id: int, model: nn.Module,
             name,
             8,
             16,
-            torch.rand([w.shape[1], 8], device="cuda"),
-            torch.rand([8, w.shape[0]], device="cuda"),
+            torch.rand([w.shape[1], 8], device=device),
+            torch.rand([8, w.shape[0]], device=device),
         )
     return LoRAModel(lora_id, 8, loras)

@@ -80,6 +87,7 @@ def create_packed_lora(
     model: nn.Module,
     module_name,
     replaced_module_names,
+    device: torch.device,
     empty_replaced_module_name=None,
 ) -> LoRAModel:
     w = model.get_submodule(module_name).weight
@@ -91,9 +99,9 @@ def create_packed_lora(
             replaced_module_name,
             8,
             16,
-            torch.rand([w.shape[1], 8], device="cuda"),
+            torch.rand([w.shape[1], 8], device=device),
             torch.rand([8, w.shape[0] // len(replaced_module_names)],
-                       device="cuda"),
+                       device=device),
         )
     return LoRAModel(lora_id, 8, loras)

@@ -104,7 +112,8 @@ def test_replace_submodules(dist_init, dummy_model):
     model.packed_modules_mapping = {}
     manager = LoRAModelManager(
         model, 1, 1, 1,
-        LoRAConfig(max_lora_rank=8, max_cpu_loras=8, max_loras=8))
+        LoRAConfig(max_lora_rank=8, max_cpu_loras=8, max_loras=8),
+        torch.device("cuda"))
     model = manager.model

     assert isinstance(model.get_submodule("dense1"),
@@ -116,16 +125,28 @@ def test_replace_submodules(dist_init, dummy_model):
                       RowParallelLinearWithLoRA)


-def test_lora_model_manager(dist_init, dummy_model):
+@pytest.mark.parametrize("device", CUDA_DEVICES)
+def test_lora_model_manager(dist_init, dummy_model, device):
     model = dummy_model
     model.supported_lora_modules = ["dense1", "dense2", "lm_head"]
     model.packed_modules_mapping = {}
-    model_lora1 = create_lora(1, model, ["layer1.dense1", "dense2", "lm_head"])
-    model_lora2 = create_lora(2, model, ["dense1", "dense2", "lm_head"])
-    model_lora3 = create_lora(3, model, ["dense1", "dense2", "lm_head"])
-    manager = LoRAModelManager(
-        model, 2, 2, 2,
-        LoRAConfig(max_lora_rank=8, max_cpu_loras=3, max_loras=2))
+    model_lora1 = create_lora(1,
+                              model, ["layer1.dense1", "dense2", "lm_head"],
+                              device=device)
+    model_lora2 = create_lora(2,
+                              model, ["dense1", "dense2", "lm_head"],
+                              device=device)
+    model_lora3 = create_lora(3,
+                              model, ["dense1", "dense2", "lm_head"],
+                              device=device)
+    manager = LoRAModelManager(model,
+                               2,
+                               2,
+                               2,
+                               LoRAConfig(max_lora_rank=8,
+                                          max_cpu_loras=3,
+                                          max_loras=2),
+                               device=device)
     assert all(x is None for x in manager.lora_index_to_id)
     assert manager.add_adapter(model_lora1)
     assert manager.activate_adapter(1)
@@ -161,17 +182,32 @@ def test_lora_model_manager(dist_init, dummy_model):
     assert manager.lora_index_to_id[0] == 3
     assert manager.lora_index_to_id[1] == 2

+    assert manager.device == device
+    assert manager.punica_wrapper.device == device
+

-def test_lora_lru_cache_model_manager(dist_init, dummy_model):
+@pytest.mark.parametrize("device", CUDA_DEVICES)
+def test_lora_lru_cache_model_manager(dist_init, dummy_model, device):
     model = dummy_model
     model.supported_lora_modules = ["dense1", "dense2", "lm_head"]
     model.packed_modules_mapping = {}
-    model_lora1 = create_lora(1, model, ["layer1.dense1", "dense2", "lm_head"])
-    model_lora2 = create_lora(2, model, ["dense1", "dense2", "lm_head"])
-    model_lora3 = create_lora(3, model, ["dense1", "dense2", "lm_head"])
-    manager = LRUCacheLoRAModelManager(
-        model, 2, 2, 2,
-        LoRAConfig(max_lora_rank=8, max_cpu_loras=3, max_loras=2))
+    model_lora1 = create_lora(1,
+                              model, ["layer1.dense1", "dense2", "lm_head"],
+                              device=device)
+    model_lora2 = create_lora(2,
+                              model, ["dense1", "dense2", "lm_head"],
+                              device=device)
+    model_lora3 = create_lora(3,
+                              model, ["dense1", "dense2", "lm_head"],
+                              device=device)
+    manager = LRUCacheLoRAModelManager(model,
+                                       2,
+                                       2,
+                                       2,
+                                       LoRAConfig(max_lora_rank=8,
+                                                  max_cpu_loras=3,
+                                                  max_loras=2),
+                                       device=device)
     assert all(x is None for x in manager.lora_index_to_id)
     assert manager.add_adapter(model_lora1)
     assert manager.activate_adapter(1)
@@ -238,20 +274,37 @@ def test_lora_lru_cache_model_manager(dist_init, dummy_model):
     with pytest.raises(ValueError):
         assert manager.pin_adapter(3)

+    assert manager.punica_wrapper.device == device
+    assert manager.device == device
+

-def test_lru_lora_model_manager(dist_init, dummy_model):
+@pytest.mark.parametrize("device", CUDA_DEVICES)
+def test_lru_lora_model_manager(dist_init, dummy_model, device):
     # This tests just the LRU cache functionality, everything else is
     # tested in test_lora_model_manager
     model = dummy_model
     model.supported_lora_modules = ["dense1", "dense2", "lm_head"]
     model.packed_modules_mapping = {}
-    model_lora1 = create_lora(1, model, ["layer1.dense1", "dense2", "lm_head"])
-    model_lora2 = create_lora(2, model, ["dense1", "dense2", "lm_head"])
-    model_lora3 = create_lora(3, model, ["dense1", "dense2", "lm_head"])
-    model_lora4 = create_lora(4, model, ["dense1", "dense2", "lm_head"])
-    manager = LRUCacheLoRAModelManager(
-        model, 2, 2, 2,
-        LoRAConfig(max_lora_rank=8, max_cpu_loras=2, max_loras=2))
+    model_lora1 = create_lora(1,
+                              model, ["layer1.dense1", "dense2", "lm_head"],
+                              device=device)
+    model_lora2 = create_lora(2,
+                              model, ["dense1", "dense2", "lm_head"],
+                              device=device)
+    model_lora3 = create_lora(3,
+                              model, ["dense1", "dense2", "lm_head"],
+                              device=device)
+    model_lora4 = create_lora(4,
+                              model, ["dense1", "dense2", "lm_head"],
+                              device=device)
+    manager = LRUCacheLoRAModelManager(model,
+                                       2,
+                                       2,
+                                       2,
+                                       LoRAConfig(max_lora_rank=8,
+                                                  max_cpu_loras=2,
+                                                  max_loras=2),
+                                       device=device)

     assert all(x is None for x in manager.lora_index_to_id)

@@ -351,14 +404,17 @@ def test_lru_lora_model_manager(dist_init, dummy_model):
         assert manager.remove_oldest_adapter()

     assert set(manager.list_adapters()) == {1}
+    assert manager.punica_wrapper.device == device
+    assert manager.device == device


+@pytest.mark.parametrize("device", CUDA_DEVICES)
 def test_lru_cache_worker_adapter_manager(llama_2_7b_model_extra_embeddings,
-                                          sql_lora_files):
+                                          sql_lora_files, device):
     lora_config = LoRAConfig(max_lora_rank=8, max_cpu_loras=4, max_loras=4)
     worker_adapter_manager = LRUCacheWorkerLoRAManager(
         4, 2, llama_2_7b_model_extra_embeddings.unpadded_vocab_size -
-        lora_config.lora_extra_vocab_size, lora_config, torch.device("cuda"),
+        lora_config.lora_extra_vocab_size, lora_config, device,
         EMBEDDING_MODULES, EMBEDDING_PADDING_MODULES)
     worker_adapter_manager.create_lora_manager(
         llama_2_7b_model_extra_embeddings)
@@ -426,14 +482,19 @@ def test_lru_cache_worker_adapter_manager(llama_2_7b_model_extra_embeddings,
         LoRARequest("14", 14, sql_lora_files)
     ], mapping)

+    assert worker_adapter_manager.device == device
+    assert (worker_adapter_manager._adapter_manager.punica_wrapper.device ==
+            device)
+

+@pytest.mark.parametrize("device", CUDA_DEVICES)
 def test_worker_adapter_manager(llama_2_7b_model_extra_embeddings,
-                                sql_lora_files):
+                                sql_lora_files, device):
     # Should remove every LoRA not specified in the request.
     lora_config = LoRAConfig(max_lora_rank=8, max_cpu_loras=4, max_loras=4)
     worker_adapter_manager = WorkerLoRAManager(
         4, 2, llama_2_7b_model_extra_embeddings.unpadded_vocab_size -
-        lora_config.lora_extra_vocab_size, lora_config, torch.device("cuda"),
+        lora_config.lora_extra_vocab_size, lora_config, device,
         EMBEDDING_MODULES, EMBEDDING_PADDING_MODULES)
     worker_adapter_manager.create_lora_manager(
         llama_2_7b_model_extra_embeddings)
@@ -497,8 +558,13 @@ def test_worker_adapter_manager(llama_2_7b_model_extra_embeddings,
         LoRARequest("14", 14, sql_lora_files)
     ], mapping)

+    assert worker_adapter_manager.device == device
+    assert (worker_adapter_manager._adapter_manager.punica_wrapper.device ==
+            device)
+

-def test_packed_loras(dist_init, dummy_model_gate_up):
+@pytest.mark.parametrize("device", CUDA_DEVICES)
+def test_packed_loras(dist_init, dummy_model_gate_up, device):
     model = dummy_model_gate_up
     model.supported_lora_modules = ["gate_up_proj"]
     model.packed_modules_mapping = {
@@ -511,18 +577,25 @@ def test_packed_loras(dist_init, dummy_model_gate_up):
         1,
         model,
         module_name="gate_up_proj",
-        replaced_module_names=["gate_proj", "up_proj"])
+        replaced_module_names=["gate_proj", "up_proj"],
+        device=device)
     model_lora1 = create_packed_lora(
         2,
         model,
         module_name="gate_up_proj",
         replaced_module_names=["gate_proj", "up_proj"],
+        device=device,
         empty_replaced_module_name="gate_proj",
     )

-    manager = LoRAModelManager(
-        model, 2, 2, 2,
-        LoRAConfig(max_lora_rank=8, max_cpu_loras=2, max_loras=2))
+    manager = LoRAModelManager(model,
+                               2,
+                               2,
+                               2,
+                               LoRAConfig(max_lora_rank=8,
+                                          max_cpu_loras=2,
+                                          max_loras=2),
+                               device=device)
     model = manager.model

     assert isinstance(model.get_submodule("gate_up_proj"),
@@ -7,9 +7,10 @@ from vllm.lora.lora import LoRALayerWeights, PackedLoRALayerWeights

 class DummyLoRAManager:

-    def __init__(self):
+    def __init__(self, device: torch.device = "cuda:0"):
         super().__init__()
         self._loras: Dict[str, LoRALayerWeights] = {}
+        self._device = device

     def set_module_lora(self, module_name: str, lora: LoRALayerWeights):
         self._loras[module_name] = lora
@@ -28,16 +29,16 @@ class DummyLoRAManager:
             lora_alpha=1,
             lora_a=torch.rand([weight.shape[1], rank],
                               dtype=weight.dtype,
-                              device="cuda"),
+                              device=self._device),
             lora_b=torch.rand([rank, weight.shape[0]],
                               dtype=weight.dtype,
-                              device="cuda"),
+                              device=self._device),
         )
         if generate_embeddings_tensor:
             lora.embeddings_tensor = torch.rand(5,
                                                 generate_embeddings_tensor,
                                                 dtype=weight.dtype,
-                                                device="cuda")
+                                                device=self._device)
         self.set_module_lora(module_name, lora)

         return lora
@@ -301,6 +301,7 @@ class LoRAModelManager(AdapterModelManager):
         max_num_batched_tokens: int,
         vocab_size: int,
         lora_config: LoRAConfig,
+        device: torch.device,
     ):
         """Create a LoRAModelManager and adapter for a given model.

@@ -314,6 +315,7 @@ class LoRAModelManager(AdapterModelManager):
             lora_config: the LoRA configuration.
         """
         self.lora_config = lora_config
+        self.device = device
         self.max_num_seqs = max_num_seqs
         assert self.capacity >= self.lora_slots
         self.max_num_batched_tokens = math.ceil(max_num_batched_tokens / 8) * 8
@@ -322,7 +324,7 @@ class LoRAModelManager(AdapterModelManager):
         self.long_lora_context: Optional[LongContextLoRAContext] = None
         self.punica_wrapper = PunicaWrapper(max_num_batched_tokens,
                                             max_batches=self.max_num_seqs,
-                                            device="cuda")
+                                            device=self.device)
         # Scaling factor -> offset to the sin_cos_cache to it.
         # Used for long context lora.
         self.scaling_factor_to_offset: Dict[float, int] = {}
@@ -653,16 +655,11 @@ class LoRALRUCache(AdapterLRUCache[LoRAModel]):
 class LRUCacheLoRAModelManager(LoRAModelManager):
     """A model manager that manages multiple LoRAs with LRU cache."""

-    def __init__(
-        self,
-        model: nn.Module,
-        max_num_seqs: int,
-        max_num_batched_tokens: int,
-        vocab_size: int,
-        lora_config: LoRAConfig,
-    ):
+    def __init__(self, model: nn.Module, max_num_seqs: int,
+                 max_num_batched_tokens: int, vocab_size: int,
+                 lora_config: LoRAConfig, device: torch.device):
         super().__init__(model, max_num_seqs, max_num_batched_tokens,
-                         vocab_size, lora_config)
+                         vocab_size, lora_config, device)
         self._registered_adapters: LoRALRUCache = LoRALRUCache(
             self.capacity, self.deactivate_adapter)
         self._active_adapters: LoRALRUCache = LoRALRUCache(
@@ -732,6 +729,7 @@ def create_lora_manager(
         max_num_batched_tokens: int,
         vocab_size: int,
         lora_config: LoRAConfig,
+        device: torch.device,
         lora_manager_cls: Type[LoRAModelManager] = LoRAModelManager,
         **kwargs) -> LoRAModelManager:
     """Create a LoRA adapter for a given model."""
@@ -743,5 +741,6 @@ def create_lora_manager(
         max_num_batched_tokens=max_num_batched_tokens,
         vocab_size=vocab_size,
         lora_config=lora_config,
+        device=device,
         **kwargs)
     return lora_manager
@@ -62,6 +62,7 @@ def convert_mapping(
     max_loras: int,
     vocab_size: int,
     extra_vocab_size: int,
+    device: torch.device,
     long_lora_context: Optional["LongContextLoRAContext"] = None,
 ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor,
            Optional[torch.Tensor], List[int]]:
@@ -104,7 +105,7 @@ def convert_mapping(
     long_lora_offsets: Optional[torch.Tensor] = None
     if long_lora_context:
         long_lora_offsets = torch.zeros(len(index_mapping_indices),
-                                        device="cuda",
+                                        device=device,
                                         dtype=torch.long)
     prompt_mapping: List[int] = [
         lora_index_to_id.index(x) if x > 0 else -1
@@ -131,10 +132,10 @@ def convert_mapping(
     if long_lora_context:
         assert long_lora_offsets is not None
         indices_list.append(long_lora_offsets)
-    indices = torch.tensor(indices_list, dtype=torch.long, device="cuda")
+    indices = torch.tensor(indices_list, dtype=torch.long, device=device)
     prompt_mapping_tensor = torch.tensor(prompt_mapping,
-                                         device="cuda",
-                                         dtype=torch.long)
+                                         dtype=torch.long,
+                                         device=device)
     embeddings_indices = torch.stack([
         indices[2] * extra_vocab_size,
         indices[2] * (vocab_size + extra_vocab_size),
@@ -145,7 +146,7 @@ def convert_mapping(
     sampler_indices_padded = sampler_indices.clone()
     sampler_indices_padded[sampler_indices_padded == -1] = max_loras - 1
     sampler_indices_padded = torch.arange(
-        0, len(sampler_indices_padded), device="cuda", dtype=torch.long) + (
+        0, len(sampler_indices_padded), device=device, dtype=torch.long) + (
             sampler_indices_padded * len(sampler_indices_padded))
     long_lora_indices = None
     long_lora_indices_len: Optional[int] = None
@@ -183,7 +184,7 @@ class PunicaWrapper:
     """

     def __init__(self, max_num_batched_tokens: int, max_batches: int,
-                 device: str):
+                 device: Union[torch.device, str]):
         self._token_lora_indices = torch.empty(max_num_batched_tokens,
                                                dtype=torch.long,
                                                device=device)
@@ -215,6 +216,7 @@ class PunicaWrapper:
         self._lora_indices_per_batch = torch.empty(max_batches,
                                                    dtype=torch.long,
                                                    device=device)
+        self.device: torch.device = device
         self.max_length: int = 0
         self.token_nums: int = 0
         self.batch_size: int = -1
@@ -263,6 +265,7 @@ class PunicaWrapper:
             max_loras,
             vocab_size,
             extra_vocab_size,
+            self.device,
             long_lora_context,
         )
         self._token_lora_indices[:base_indices.shape[0]].copy_(base_indices)
@@ -73,6 +73,7 @@ class WorkerLoRAManager(AbstractWorkerManager):
             max_num_batched_tokens=self.max_num_batched_tokens,
             vocab_size=self.vocab_size,
             lora_config=self.lora_config,
+            device=self.device,
             lora_manager_cls=self._manager_cls,
         )
         self._adapter_manager = lora_manager
@@ -176,6 +177,7 @@ class LRUCacheWorkerLoRAManager(WorkerLoRAManager):
             max_num_seqs=self.max_num_seqs,
             vocab_size=self.vocab_size,
             lora_config=self.lora_config,
+            device=self.device,
             max_num_batched_tokens=self.max_num_batched_tokens,
         )
         self._adapter_manager = lora_manager
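The updated tests pick their target GPU by parametrizing over `CUDA_DEVICES` and pin it with `torch.cuda.set_device` before allocating tensors, since the Triton kernels used by the LoRA layers need the CUDA device set explicitly for multi-GPU runs (see the issue linked in the diff). A self-contained sketch of that pattern, with an illustrative test name and tensor shape:

    import pytest
    import torch

    CUDA_DEVICES = [
        f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
    ]


    @pytest.mark.parametrize("device", CUDA_DEVICES)
    def test_tensors_land_on_requested_device(device):
        # Pin the CUDA device first so later kernel launches use it as well.
        torch.cuda.set_device(device)
        x = torch.zeros(8, device=device)
        assert x.device == torch.device(device)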