[Misc][LoRA] Replace hardcoded cuda device with configurable argument (#10223)

Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
This commit is contained in:
Jee Jee Li 2024-11-12 11:10:15 +08:00 committed by GitHub
parent eea55cca5b
commit 7f5edb5900
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 174 additions and 80 deletions

View File

@ -51,6 +51,7 @@ TOLERANCES = {
CUDA_DEVICES = [
f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
]
# We will launch different triton kernels between the prefill and decode
# stages, so we need to verify this. prefill stage(True) or decode stage(False)
STAGES = [True, False]
@ -120,7 +121,8 @@ def populate_loras(
subloras: List[LoRALayerWeights] = []
sublora_len = layer_weights.shape[0] // repeats
for i in range(repeats):
sublora = DummyLoRAManager().init_random_lora(
sublora = DummyLoRAManager(
layer_weights.device).init_random_lora(
module_name=f"fake_{i}",
weight=layer_weights,
generate_embeddings_tensor=generate_embeddings_tensor,
@ -152,6 +154,7 @@ def create_random_inputs(
input_size: Tuple[int, ...],
input_range: Tuple[float, float],
input_type: torch.dtype = torch.int,
device: torch.device = "cuda"
) -> Tuple[List[torch.Tensor], List[int], List[int]]:
"""Creates random inputs.
@ -173,10 +176,14 @@ def create_random_inputs(
for _ in range(num_inputs):
if input_type == torch.int:
inputs.append(
torch.randint(low=int(low), high=int(high), size=input_size))
torch.randint(low=int(low),
high=int(high),
size=input_size,
device=device))
else:
inputs.append(
torch.rand(size=input_size, dtype=input_type) * high + low)
torch.rand(size=input_size, dtype=input_type, device=device) *
high + low)
lora_id = random.choice(active_lora_ids)
index_mapping += [lora_id] * input_size[0]
@ -191,6 +198,10 @@ def create_random_inputs(
@pytest.mark.parametrize("vocab_size", [512, 32000, 64000, 128000])
@pytest.mark.parametrize("stage", STAGES)
def test_embeddings(dist_init, num_loras, device, vocab_size, stage) -> None:
# For multi-GPU testing of Triton kernel, we must explicitly set the CUDA
# device, see: https://github.com/triton-lang/triton/issues/2925
# Same below.
torch.cuda.set_device(device)
torch.set_default_device(device)
max_loras = 8
@ -225,7 +236,7 @@ def test_embeddings(dist_init, num_loras, device, vocab_size, stage) -> None:
num_inputs=num_loras * 3,
input_size=(200, ),
input_range=(1, vocab_size),
)
device=device)
lora_mapping = LoRAMapping(index_mapping,
prompt_mapping,
is_prefill=stage)
@ -263,7 +274,7 @@ def test_embeddings(dist_init, num_loras, device, vocab_size, stage) -> None:
num_inputs=num_loras * 3,
input_size=(200, ),
input_range=(1, vocab_size),
)
device=device)
lora_mapping = LoRAMapping(index_mapping,
prompt_mapping,
is_prefill=stage)
@ -291,6 +302,7 @@ def test_embeddings(dist_init, num_loras, device, vocab_size, stage) -> None:
def test_embeddings_with_new_embeddings(dist_init, num_loras, device,
vocab_size, stage) -> None:
torch.cuda.set_device(device)
torch.set_default_device(device)
max_loras = 8
punica_wrapper = PunicaWrapper(8192, 256, device)
@ -345,7 +357,7 @@ def test_embeddings_with_new_embeddings(dist_init, num_loras, device,
num_inputs=num_loras * 3,
input_size=(200, ),
input_range=(1, vocab_size),
)
device=device)
lora_mapping = LoRAMapping(index_mapping,
prompt_mapping,
is_prefill=stage)
@ -400,7 +412,7 @@ def test_embeddings_with_new_embeddings(dist_init, num_loras, device,
num_inputs=num_loras * 3,
input_size=(200, ),
input_range=(1, vocab_size),
)
device=device)
original_inputs = deepcopy(inputs)
lora_mapping = LoRAMapping(index_mapping,
prompt_mapping,
@ -426,6 +438,7 @@ def test_embeddings_with_new_embeddings(dist_init, num_loras, device,
def test_lm_head_logits_processor(dist_init, num_loras, device, vocab_size,
stage) -> None:
torch.cuda.set_device(device)
torch.set_default_device(device)
max_loras = 8
punica_wrapper = PunicaWrapper(8192, 256, device)
@ -471,7 +484,7 @@ def test_lm_head_logits_processor(dist_init, num_loras, device, vocab_size,
input_size=(1, 1024),
input_range=(0, 1),
input_type=torch.float16,
)
device=device)
lora_mapping = LoRAMapping(index_mapping,
prompt_mapping,
is_prefill=stage)
@ -520,7 +533,7 @@ def test_lm_head_logits_processor(dist_init, num_loras, device, vocab_size,
input_size=(1, 1024),
input_range=(0, 1),
input_type=torch.float16,
)
device=device)
lora_mapping = LoRAMapping(index_mapping,
prompt_mapping,
is_prefill=stage)
@ -554,6 +567,7 @@ def test_lm_head_logits_processor(dist_init, num_loras, device, vocab_size,
@pytest.mark.parametrize("stage", STAGES)
def test_linear_replicated(dist_init, num_loras, device, stage) -> None:
torch.cuda.set_device(device)
torch.set_default_device(device)
punica_wrapper = PunicaWrapper(8192, 256, device)
max_loras = 8
@ -592,7 +606,7 @@ def test_linear_replicated(dist_init, num_loras, device, stage) -> None:
input_size=(1, 4096),
input_range=(0, 1),
input_type=torch.float16,
)
device=device)
lora_mapping = LoRAMapping(index_mapping,
prompt_mapping,
is_prefill=stage)
@ -631,7 +645,7 @@ def test_linear_replicated(dist_init, num_loras, device, stage) -> None:
input_size=(1, 4096),
input_range=(0, 1),
input_type=torch.float16,
)
device=device)
lora_mapping = LoRAMapping(index_mapping,
prompt_mapping,
is_prefill=stage)
@ -658,6 +672,7 @@ def test_linear_replicated(dist_init, num_loras, device, stage) -> None:
def test_linear_parallel(dist_init, num_loras, orientation, fully_shard,
device, stage) -> None:
torch.cuda.set_device(device)
torch.set_default_device(device)
punica_wrapper = PunicaWrapper(8192, 256, device)
max_loras = 8
@ -706,7 +721,7 @@ def test_linear_parallel(dist_init, num_loras, orientation, fully_shard,
input_size=(1, 4096),
input_range=(0, 1),
input_type=torch.float16,
)
device=device)
lora_mapping = LoRAMapping(index_mapping,
prompt_mapping,
is_prefill=stage)
@ -745,7 +760,7 @@ def test_linear_parallel(dist_init, num_loras, orientation, fully_shard,
input_size=(1, 4096),
input_range=(0, 1),
input_type=torch.float16,
)
device=device)
lora_mapping = LoRAMapping(index_mapping,
prompt_mapping,
is_prefill=stage)
@ -772,6 +787,7 @@ def test_linear_parallel(dist_init, num_loras, orientation, fully_shard,
def test_column_parallel_packed(dist_init, num_loras, repeats, fully_shard,
device, stage) -> None:
torch.cuda.set_device(device)
torch.set_default_device(device)
punica_wrapper = PunicaWrapper(8192, 256, device)
max_loras = 8
@ -842,7 +858,7 @@ def test_column_parallel_packed(dist_init, num_loras, repeats, fully_shard,
input_size=(1, 4096),
input_range=(0, 1),
input_type=torch.float16,
)
device=device)
lora_mapping = LoRAMapping(index_mapping,
prompt_mapping,
is_prefill=stage)
@ -883,7 +899,7 @@ def test_column_parallel_packed(dist_init, num_loras, repeats, fully_shard,
input_size=(1, 4096),
input_range=(0, 1),
input_type=torch.float16,
)
device=device)
lora_mapping = LoRAMapping(index_mapping,
prompt_mapping,
is_prefill=stage)
@ -962,7 +978,7 @@ def test_rotary_embedding_long_context(dist_init, num_loras, device,
input_size=(1, max_position),
input_range=(0, lora_config.lora_extra_vocab_size),
input_type=torch.float16,
)
device=device)
lora_mapping = LoRAMapping(index_mapping, prompt_mapping)
long_lora_context = LongContextLoRAContext(list(scaling_factors),

View File

@ -25,8 +25,13 @@ EMBEDDING_MODULES = {
EMBEDDING_PADDING_MODULES = ["lm_head"]
CUDA_DEVICES = [
f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
]
def test_from_lora_tensors(sql_lora_files):
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_from_lora_tensors(sql_lora_files, device):
tensors = load_file(
os.path.join(sql_lora_files, "adapter_model.safetensors"))
new_embeddings = load_file(
@ -36,7 +41,7 @@ def test_from_lora_tensors(sql_lora_files):
8,
16,
tensors,
"cuda",
device,
embeddings=new_embeddings,
embedding_modules=EMBEDDING_MODULES,
embedding_padding_modules=EMBEDDING_PADDING_MODULES)
@ -46,6 +51,8 @@ def test_from_lora_tensors(sql_lora_files):
assert lora.lora_alpha == 16
assert lora.lora_a is not None
assert lora.lora_b is not None
assert lora.lora_a.device == torch.device(device)
assert lora.lora_b.device == torch.device(device)
assert (lora.lora_a.shape[1] == lora.lora_b.shape[0]
), f"{lora.lora_a.shape=}, {lora.lora_b.shape=}"
assert lora.lora_a.shape[1] == 8
@ -60,8 +67,8 @@ def test_from_lora_tensors(sql_lora_files):
assert lora.embeddings_tensor is None
def create_lora(lora_id: int, model: nn.Module,
sub_modules: List[str]) -> LoRAModel:
def create_lora(lora_id: int, model: nn.Module, sub_modules: List[str],
device: torch.device) -> LoRAModel:
loras: Dict[str, LoRALayerWeights] = {}
for name in sub_modules:
w = model.get_submodule(name).weight
@ -69,8 +76,8 @@ def create_lora(lora_id: int, model: nn.Module,
name,
8,
16,
torch.rand([w.shape[1], 8], device="cuda"),
torch.rand([8, w.shape[0]], device="cuda"),
torch.rand([w.shape[1], 8], device=device),
torch.rand([8, w.shape[0]], device=device),
)
return LoRAModel(lora_id, 8, loras)
@ -80,6 +87,7 @@ def create_packed_lora(
model: nn.Module,
module_name,
replaced_module_names,
device: torch.device,
empty_replaced_module_name=None,
) -> LoRAModel:
w = model.get_submodule(module_name).weight
@ -91,9 +99,9 @@ def create_packed_lora(
replaced_module_name,
8,
16,
torch.rand([w.shape[1], 8], device="cuda"),
torch.rand([w.shape[1], 8], device=device),
torch.rand([8, w.shape[0] // len(replaced_module_names)],
device="cuda"),
device=device),
)
return LoRAModel(lora_id, 8, loras)
@ -104,7 +112,8 @@ def test_replace_submodules(dist_init, dummy_model):
model.packed_modules_mapping = {}
manager = LoRAModelManager(
model, 1, 1, 1,
LoRAConfig(max_lora_rank=8, max_cpu_loras=8, max_loras=8))
LoRAConfig(max_lora_rank=8, max_cpu_loras=8, max_loras=8),
torch.device("cuda"))
model = manager.model
assert isinstance(model.get_submodule("dense1"),
@ -116,16 +125,28 @@ def test_replace_submodules(dist_init, dummy_model):
RowParallelLinearWithLoRA)
def test_lora_model_manager(dist_init, dummy_model):
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_lora_model_manager(dist_init, dummy_model, device):
model = dummy_model
model.supported_lora_modules = ["dense1", "dense2", "lm_head"]
model.packed_modules_mapping = {}
model_lora1 = create_lora(1, model, ["layer1.dense1", "dense2", "lm_head"])
model_lora2 = create_lora(2, model, ["dense1", "dense2", "lm_head"])
model_lora3 = create_lora(3, model, ["dense1", "dense2", "lm_head"])
manager = LoRAModelManager(
model, 2, 2, 2,
LoRAConfig(max_lora_rank=8, max_cpu_loras=3, max_loras=2))
model_lora1 = create_lora(1,
model, ["layer1.dense1", "dense2", "lm_head"],
device=device)
model_lora2 = create_lora(2,
model, ["dense1", "dense2", "lm_head"],
device=device)
model_lora3 = create_lora(3,
model, ["dense1", "dense2", "lm_head"],
device=device)
manager = LoRAModelManager(model,
2,
2,
2,
LoRAConfig(max_lora_rank=8,
max_cpu_loras=3,
max_loras=2),
device=device)
assert all(x is None for x in manager.lora_index_to_id)
assert manager.add_adapter(model_lora1)
assert manager.activate_adapter(1)
@ -161,17 +182,32 @@ def test_lora_model_manager(dist_init, dummy_model):
assert manager.lora_index_to_id[0] == 3
assert manager.lora_index_to_id[1] == 2
assert manager.device == device
assert manager.punica_wrapper.device == device
def test_lora_lru_cache_model_manager(dist_init, dummy_model):
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_lora_lru_cache_model_manager(dist_init, dummy_model, device):
model = dummy_model
model.supported_lora_modules = ["dense1", "dense2", "lm_head"]
model.packed_modules_mapping = {}
model_lora1 = create_lora(1, model, ["layer1.dense1", "dense2", "lm_head"])
model_lora2 = create_lora(2, model, ["dense1", "dense2", "lm_head"])
model_lora3 = create_lora(3, model, ["dense1", "dense2", "lm_head"])
manager = LRUCacheLoRAModelManager(
model, 2, 2, 2,
LoRAConfig(max_lora_rank=8, max_cpu_loras=3, max_loras=2))
model_lora1 = create_lora(1,
model, ["layer1.dense1", "dense2", "lm_head"],
device=device)
model_lora2 = create_lora(2,
model, ["dense1", "dense2", "lm_head"],
device=device)
model_lora3 = create_lora(3,
model, ["dense1", "dense2", "lm_head"],
device=device)
manager = LRUCacheLoRAModelManager(model,
2,
2,
2,
LoRAConfig(max_lora_rank=8,
max_cpu_loras=3,
max_loras=2),
device=device)
assert all(x is None for x in manager.lora_index_to_id)
assert manager.add_adapter(model_lora1)
assert manager.activate_adapter(1)
@ -238,20 +274,37 @@ def test_lora_lru_cache_model_manager(dist_init, dummy_model):
with pytest.raises(ValueError):
assert manager.pin_adapter(3)
assert manager.punica_wrapper.device == device
assert manager.device == device
def test_lru_lora_model_manager(dist_init, dummy_model):
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_lru_lora_model_manager(dist_init, dummy_model, device):
# This tests just the LRU cache functionality, everything else is
# tested in test_lora_model_manager
model = dummy_model
model.supported_lora_modules = ["dense1", "dense2", "lm_head"]
model.packed_modules_mapping = {}
model_lora1 = create_lora(1, model, ["layer1.dense1", "dense2", "lm_head"])
model_lora2 = create_lora(2, model, ["dense1", "dense2", "lm_head"])
model_lora3 = create_lora(3, model, ["dense1", "dense2", "lm_head"])
model_lora4 = create_lora(4, model, ["dense1", "dense2", "lm_head"])
manager = LRUCacheLoRAModelManager(
model, 2, 2, 2,
LoRAConfig(max_lora_rank=8, max_cpu_loras=2, max_loras=2))
model_lora1 = create_lora(1,
model, ["layer1.dense1", "dense2", "lm_head"],
device=device)
model_lora2 = create_lora(2,
model, ["dense1", "dense2", "lm_head"],
device=device)
model_lora3 = create_lora(3,
model, ["dense1", "dense2", "lm_head"],
device=device)
model_lora4 = create_lora(4,
model, ["dense1", "dense2", "lm_head"],
device=device)
manager = LRUCacheLoRAModelManager(model,
2,
2,
2,
LoRAConfig(max_lora_rank=8,
max_cpu_loras=2,
max_loras=2),
device=device)
assert all(x is None for x in manager.lora_index_to_id)
@ -351,14 +404,17 @@ def test_lru_lora_model_manager(dist_init, dummy_model):
assert manager.remove_oldest_adapter()
assert set(manager.list_adapters()) == {1}
assert manager.punica_wrapper.device == device
assert manager.device == device
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_lru_cache_worker_adapter_manager(llama_2_7b_model_extra_embeddings,
sql_lora_files):
sql_lora_files, device):
lora_config = LoRAConfig(max_lora_rank=8, max_cpu_loras=4, max_loras=4)
worker_adapter_manager = LRUCacheWorkerLoRAManager(
4, 2, llama_2_7b_model_extra_embeddings.unpadded_vocab_size -
lora_config.lora_extra_vocab_size, lora_config, torch.device("cuda"),
lora_config.lora_extra_vocab_size, lora_config, device,
EMBEDDING_MODULES, EMBEDDING_PADDING_MODULES)
worker_adapter_manager.create_lora_manager(
llama_2_7b_model_extra_embeddings)
@ -426,14 +482,19 @@ def test_lru_cache_worker_adapter_manager(llama_2_7b_model_extra_embeddings,
LoRARequest("14", 14, sql_lora_files)
], mapping)
assert worker_adapter_manager.device == device
assert (worker_adapter_manager._adapter_manager.punica_wrapper.device ==
device)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_worker_adapter_manager(llama_2_7b_model_extra_embeddings,
sql_lora_files):
sql_lora_files, device):
# Should remove every LoRA not specified in the request.
lora_config = LoRAConfig(max_lora_rank=8, max_cpu_loras=4, max_loras=4)
worker_adapter_manager = WorkerLoRAManager(
4, 2, llama_2_7b_model_extra_embeddings.unpadded_vocab_size -
lora_config.lora_extra_vocab_size, lora_config, torch.device("cuda"),
lora_config.lora_extra_vocab_size, lora_config, device,
EMBEDDING_MODULES, EMBEDDING_PADDING_MODULES)
worker_adapter_manager.create_lora_manager(
llama_2_7b_model_extra_embeddings)
@ -497,8 +558,13 @@ def test_worker_adapter_manager(llama_2_7b_model_extra_embeddings,
LoRARequest("14", 14, sql_lora_files)
], mapping)
assert worker_adapter_manager.device == device
assert (worker_adapter_manager._adapter_manager.punica_wrapper.device ==
device)
def test_packed_loras(dist_init, dummy_model_gate_up):
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_packed_loras(dist_init, dummy_model_gate_up, device):
model = dummy_model_gate_up
model.supported_lora_modules = ["gate_up_proj"]
model.packed_modules_mapping = {
@ -511,18 +577,25 @@ def test_packed_loras(dist_init, dummy_model_gate_up):
1,
model,
module_name="gate_up_proj",
replaced_module_names=["gate_proj", "up_proj"])
replaced_module_names=["gate_proj", "up_proj"],
device=device)
model_lora1 = create_packed_lora(
2,
model,
module_name="gate_up_proj",
replaced_module_names=["gate_proj", "up_proj"],
device=device,
empty_replaced_module_name="gate_proj",
)
manager = LoRAModelManager(
model, 2, 2, 2,
LoRAConfig(max_lora_rank=8, max_cpu_loras=2, max_loras=2))
manager = LoRAModelManager(model,
2,
2,
2,
LoRAConfig(max_lora_rank=8,
max_cpu_loras=2,
max_loras=2),
device=device)
model = manager.model
assert isinstance(model.get_submodule("gate_up_proj"),

View File

@ -7,9 +7,10 @@ from vllm.lora.lora import LoRALayerWeights, PackedLoRALayerWeights
class DummyLoRAManager:
def __init__(self):
def __init__(self, device: torch.device = "cuda:0"):
super().__init__()
self._loras: Dict[str, LoRALayerWeights] = {}
self._device = device
def set_module_lora(self, module_name: str, lora: LoRALayerWeights):
self._loras[module_name] = lora
@ -28,16 +29,16 @@ class DummyLoRAManager:
lora_alpha=1,
lora_a=torch.rand([weight.shape[1], rank],
dtype=weight.dtype,
device="cuda"),
device=self._device),
lora_b=torch.rand([rank, weight.shape[0]],
dtype=weight.dtype,
device="cuda"),
device=self._device),
)
if generate_embeddings_tensor:
lora.embeddings_tensor = torch.rand(5,
generate_embeddings_tensor,
dtype=weight.dtype,
device="cuda")
device=self._device)
self.set_module_lora(module_name, lora)
return lora

View File

@ -301,6 +301,7 @@ class LoRAModelManager(AdapterModelManager):
max_num_batched_tokens: int,
vocab_size: int,
lora_config: LoRAConfig,
device: torch.device,
):
"""Create a LoRAModelManager and adapter for a given model.
@ -314,6 +315,7 @@ class LoRAModelManager(AdapterModelManager):
lora_config: the LoRA configuration.
"""
self.lora_config = lora_config
self.device = device
self.max_num_seqs = max_num_seqs
assert self.capacity >= self.lora_slots
self.max_num_batched_tokens = math.ceil(max_num_batched_tokens / 8) * 8
@ -322,7 +324,7 @@ class LoRAModelManager(AdapterModelManager):
self.long_lora_context: Optional[LongContextLoRAContext] = None
self.punica_wrapper = PunicaWrapper(max_num_batched_tokens,
max_batches=self.max_num_seqs,
device="cuda")
device=self.device)
# Scaling factor -> offset to the sin_cos_cache to it.
# Used for long context lora.
self.scaling_factor_to_offset: Dict[float, int] = {}
@ -653,16 +655,11 @@ class LoRALRUCache(AdapterLRUCache[LoRAModel]):
class LRUCacheLoRAModelManager(LoRAModelManager):
"""A model manager that manages multiple LoRAs with LRU cache."""
def __init__(
self,
model: nn.Module,
max_num_seqs: int,
max_num_batched_tokens: int,
vocab_size: int,
lora_config: LoRAConfig,
):
def __init__(self, model: nn.Module, max_num_seqs: int,
max_num_batched_tokens: int, vocab_size: int,
lora_config: LoRAConfig, device: torch.device):
super().__init__(model, max_num_seqs, max_num_batched_tokens,
vocab_size, lora_config)
vocab_size, lora_config, device)
self._registered_adapters: LoRALRUCache = LoRALRUCache(
self.capacity, self.deactivate_adapter)
self._active_adapters: LoRALRUCache = LoRALRUCache(
@ -732,6 +729,7 @@ def create_lora_manager(
max_num_batched_tokens: int,
vocab_size: int,
lora_config: LoRAConfig,
device: torch.device,
lora_manager_cls: Type[LoRAModelManager] = LoRAModelManager,
**kwargs) -> LoRAModelManager:
"""Create a LoRA adapter for a given model."""
@ -743,5 +741,6 @@ def create_lora_manager(
max_num_batched_tokens=max_num_batched_tokens,
vocab_size=vocab_size,
lora_config=lora_config,
device=device,
**kwargs)
return lora_manager

View File

@ -62,6 +62,7 @@ def convert_mapping(
max_loras: int,
vocab_size: int,
extra_vocab_size: int,
device: torch.device,
long_lora_context: Optional["LongContextLoRAContext"] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor,
Optional[torch.Tensor], List[int]]:
@ -104,7 +105,7 @@ def convert_mapping(
long_lora_offsets: Optional[torch.Tensor] = None
if long_lora_context:
long_lora_offsets = torch.zeros(len(index_mapping_indices),
device="cuda",
device=device,
dtype=torch.long)
prompt_mapping: List[int] = [
lora_index_to_id.index(x) if x > 0 else -1
@ -131,10 +132,10 @@ def convert_mapping(
if long_lora_context:
assert long_lora_offsets is not None
indices_list.append(long_lora_offsets)
indices = torch.tensor(indices_list, dtype=torch.long, device="cuda")
indices = torch.tensor(indices_list, dtype=torch.long, device=device)
prompt_mapping_tensor = torch.tensor(prompt_mapping,
device="cuda",
dtype=torch.long)
dtype=torch.long,
device=device)
embeddings_indices = torch.stack([
indices[2] * extra_vocab_size,
indices[2] * (vocab_size + extra_vocab_size),
@ -145,7 +146,7 @@ def convert_mapping(
sampler_indices_padded = sampler_indices.clone()
sampler_indices_padded[sampler_indices_padded == -1] = max_loras - 1
sampler_indices_padded = torch.arange(
0, len(sampler_indices_padded), device="cuda", dtype=torch.long) + (
0, len(sampler_indices_padded), device=device, dtype=torch.long) + (
sampler_indices_padded * len(sampler_indices_padded))
long_lora_indices = None
long_lora_indices_len: Optional[int] = None
@ -183,7 +184,7 @@ class PunicaWrapper:
"""
def __init__(self, max_num_batched_tokens: int, max_batches: int,
device: str):
device: Union[torch.device, str]):
self._token_lora_indices = torch.empty(max_num_batched_tokens,
dtype=torch.long,
device=device)
@ -215,6 +216,7 @@ class PunicaWrapper:
self._lora_indices_per_batch = torch.empty(max_batches,
dtype=torch.long,
device=device)
self.device: torch.device = device
self.max_length: int = 0
self.token_nums: int = 0
self.batch_size: int = -1
@ -263,6 +265,7 @@ class PunicaWrapper:
max_loras,
vocab_size,
extra_vocab_size,
self.device,
long_lora_context,
)
self._token_lora_indices[:base_indices.shape[0]].copy_(base_indices)

View File

@ -73,6 +73,7 @@ class WorkerLoRAManager(AbstractWorkerManager):
max_num_batched_tokens=self.max_num_batched_tokens,
vocab_size=self.vocab_size,
lora_config=self.lora_config,
device=self.device,
lora_manager_cls=self._manager_cls,
)
self._adapter_manager = lora_manager
@ -176,6 +177,7 @@ class LRUCacheWorkerLoRAManager(WorkerLoRAManager):
max_num_seqs=self.max_num_seqs,
vocab_size=self.vocab_size,
lora_config=self.lora_config,
device=self.device,
max_num_batched_tokens=self.max_num_batched_tokens,
)
self._adapter_manager = lora_manager