[Mypy] Typing lora folder (#4337)
commit b5b4a398a7
parent f4bc4de1b1

.github/workflows/mypy.yaml
@@ -33,8 +33,6 @@ jobs:
     - name: Mypy
       run: |
         mypy vllm/attention --config-file pyproject.toml
-        # TODO(sang): Fix nested dir
-        mypy vllm/core/*.py --follow-imports=skip --config-file pyproject.toml
         mypy vllm/distributed --config-file pyproject.toml
         mypy vllm/entrypoints --config-file pyproject.toml
         mypy vllm/executor --config-file pyproject.toml
@@ -44,8 +42,9 @@ jobs:
         mypy vllm/engine --config-file pyproject.toml
         mypy vllm/worker --config-file pyproject.toml
         mypy vllm/spec_decode --config-file pyproject.toml
+        mypy vllm/lora --config-file pyproject.toml

         # TODO(sang): Fix nested dir
         mypy vllm/model_executor/*.py --config-file pyproject.toml
-        # TODO(sang): Fix nested dir
-        # mypy vllm/lora/*.py --config-file pyproject.toml
+        mypy vllm/core/*.py --follow-imports=skip --config-file pyproject.toml
@@ -106,7 +106,7 @@ mypy vllm/engine --config-file pyproject.toml
 mypy vllm/worker --config-file pyproject.toml
 mypy vllm/spec_decode --config-file pyproject.toml
 mypy vllm/model_executor/*.py --config-file pyproject.toml
-# mypy vllm/lora/*.py --config-file pyproject.toml
+mypy vllm/lora --config-file pyproject.toml


 CODESPELL_EXCLUDES=(
@@ -176,6 +176,8 @@ class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA):
     def __init__(self, base_layer: VocabParallelEmbedding) -> None:
         super().__init__()
         self.base_layer = base_layer
+        self.embeddings_slice: Optional[Tuple[int, int]]
+        self.embeddings_weights: Optional[torch.Tensor]

     def create_lora_weights(
         self,
@@ -233,9 +235,10 @@ class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA):
             self.lora_a_stacked.shape[0] * self.lora_a_stacked.shape[1],
             self.lora_a_stacked.shape[2],
         )
-        self.indices: Optional[torch.Tensor] = None
-        self.indices_len: Optional[List[int]] = None
-        self.embeddings_indices = None
+        # Lazily initialized.
+        self.indices: torch.Tensor
+        self.indices_len: List[int]
+        self.embeddings_indices: torch.Tensor

     def reset_lora(self, index: int):
         self.lora_a_stacked[index] = 0
@@ -267,6 +270,7 @@ class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA):
                 self.embeddings_tensors.shape[1],
                 self.embeddings_tensors.shape[2]
             )[self.embeddings_slice[0]:self.embeddings_slice[1]]
+            assert self.embeddings_weights is not None
             self.embeddings_weights[:embeddings.shape[0]].copy_(embeddings)

     def set_mapping(
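Note: the hunks above (apparently from vllm/lora/layers.py) replace `Optional[...] = None` attributes with bare annotations for values that are always assigned before use. Below is a minimal sketch of that "lazily initialized" pattern under assumed names — `_ToyLoRALayer` and its attributes are illustrative, not part of vLLM:

```python
from typing import List

import torch


class _ToyLoRALayer:
    """Illustrative sketch of the 'lazily initialized' attribute pattern."""

    def __init__(self) -> None:
        # A bare annotation declares the type for mypy but does not create
        # the attribute at runtime, so there is no `Optional[...] = None`
        # and no None check needed at every use site.
        self.indices: torch.Tensor
        self.indices_len: List[int]

    def set_mapping(self, base_indices: torch.Tensor) -> None:
        # The attributes are assigned before any code reads them.
        self.indices = base_indices
        self.indices_len = [base_indices.shape[-1]]


layer = _ToyLoRALayer()
layer.set_mapping(torch.zeros(8, dtype=torch.long))
print(layer.indices_len)  # [8]
```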
@@ -343,11 +347,12 @@ class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):
             dtype=lora_config.lora_dtype,
             device=self.device,
         )

-        self.indices: Optional[torch.Tensor] = None
-        self.indices_len: Optional[List[int]] = None
         self.output_dim = self.lora_b_stacked.shape[2]

+        # lazily initialized.
+        self.indices: torch.Tensor
+        self.indices_len: List[int]

     def reset_lora(self, index: int):
         self.lora_a_stacked[index] = 0
         self.lora_b_stacked[index] = 0
@@ -475,8 +480,9 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
             device=self.device,
         ) for _ in range(n_slices))

-        self.indices: Optional[torch.Tensor] = None
         self.output_dim = self.lora_b_stacked[0].shape[2]
+        # Lazily initialized.
+        self.indices: torch.Tensor

     def reset_lora(self, index: int):
         self.lora_a_stacked[0][index] = 0
@@ -690,7 +696,8 @@ class MergedQKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
                                   self.kv_proj_shard_size)
         self.packed_indices: Optional[torch.Tensor] = None
         self.standard_indices: Optional[torch.Tensor] = None
-        self.indices_len: Optional[List[int]] = None
+        # lazily initialized.
+        self.indices_len: List[int]

     def reset_lora(self, index: int):
         self.lora_a_stacked[0][index] = 0
@@ -814,8 +821,9 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
             dtype=lora_config.lora_dtype,
             device=self.device,
         )
-        self.indices: Optional[torch.Tensor] = None
-        self.indices_len: Optional[List[int]] = None
+        # Lazily initialized
+        self.indices: torch.Tensor
+        self.indices_len: List[int]

     def reset_lora(self, index: int):
         self.lora_a_stacked[index] = 0
@@ -991,9 +999,10 @@ class LogitsProcessorWithLoRA(BaseLayerWithLoRA):
             dtype=self.dtype,
             device=self.device,
         )
-        self.indices = None
-        self.indices_padded = None
-        self.indices_len = None
+        # Lazily initialized.
+        self.indices: torch.Tensor
+        self.indices_len: List[int]
+        self.indices_padded: torch.Tensor

     def reset_lora(self, index: int):
         self.lora_a_stacked[index] = 0
@@ -97,9 +97,9 @@ class PackedLoRALayerWeights(LoRALayerWeights):
         self,
         module_name: str,
         rank: int,
-        lora_alphas: List[int],
-        lora_a: List[torch.Tensor],
-        lora_b: List[torch.Tensor],
+        lora_alphas: List[Optional[int]],
+        lora_a: List[Optional[torch.Tensor]],
+        lora_b: List[Optional[torch.Tensor]],
         scaling: Optional[List[float]] = None,
     ) -> None:
         super().__init__(
@@ -108,17 +108,20 @@ class PackedLoRALayerWeights(LoRALayerWeights):
             lora_alpha=0,
             lora_a=lora_a,
             lora_b=lora_b,
-            scaling=scaling,
+            scaling=scaling,  # type: ignore
             embeddings_tensor=None,
         )
         self.lora_alphas = lora_alphas
         if scaling is None:
-            self.scaling = [
-                lora_alpha / self.rank for lora_alpha in self.lora_alphas
+            self.scaling = [  # type: ignore
+                lora_alpha / self.rank  # type: ignore # noqa
+                for lora_alpha in self.lora_alphas
             ]

     @classmethod
-    def pack(cls, loras: List["LoRALayerWeights"]) -> "PackedLoRALayerWeights":
+    def pack(
+        cls, loras: List[Optional["LoRALayerWeights"]]
+    ) -> "PackedLoRALayerWeights":
         """Pack a list of LoRAs into a single LoRA.

         If LoRA is None, it signifies that the submodule does not have a LoRA.
@@ -136,16 +139,19 @@ class PackedLoRALayerWeights(LoRALayerWeights):
             [lora.lora_alpha if lora is not None else None for lora in loras],
             [lora.lora_a if lora is not None else None for lora in loras],
             [lora.lora_b if lora is not None else None for lora in loras],
-            scaling=[1 if lora is not None else None for lora in loras])
+            scaling=[
+                1 if lora is not None else None  # type: ignore
+                for lora in loras
+            ])
         return obj

     def optimize(self) -> "PackedLoRALayerWeights":
         """Optimize the LoRA by merging the scaling into lora_b."""
         for i in range(len(self.lora_b)):
-            if self.scaling[i] == 1 or self.lora_b[i] is None:
+            if self.scaling[i] == 1 or self.lora_b[i] is None:  # type: ignore
                 continue
-            self.lora_b[i] *= self.scaling[i]
+            self.lora_b[i] *= self.scaling[i]  # type: ignore
-            self.scaling[i] = 1
+            self.scaling[i] = 1  # type: ignore
             return self

     @property
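Note: the `PackedLoRALayerWeights` hunks above (likely vllm/lora/lora.py) widen element types to `List[Optional[...]]`, since a packed module may have no LoRA in some slots, and silence the remaining arithmetic on possibly-None values with targeted `# type: ignore` comments. A small sketch of the same idea, using a hypothetical `scale_alphas` helper that handles the None slots explicitly rather than ignoring them:

```python
from typing import List, Optional


def scale_alphas(lora_alphas: List[Optional[int]],
                 rank: int) -> List[Optional[float]]:
    # None marks a packed slot whose submodule has no LoRA; those slots are
    # passed through unchanged instead of being divided.
    return [alpha / rank if alpha is not None else None
            for alpha in lora_alphas]


print(scale_alphas([16, None, 32], rank=8))  # [2.0, None, 4.0]
```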
@@ -3,7 +3,7 @@ import json
 import math
 import os
 import re
-from typing import Callable, Dict, Hashable, List, Optional, Tuple, Type
+from typing import Callable, Dict, List, Optional, Tuple, Type

 import safetensors.torch
 import torch
@@ -53,44 +53,46 @@ def convert_mapping(
             embeddings.
         indices_len: List of lengths of the above tensors.
     """
-    indices = list(mapping.index_mapping).copy()
-    embedding_indices = indices.copy()
-    lora_indices = indices.copy()
-    prompt_mapping = [
+    index_mapping_indices: List[int] = list(mapping.index_mapping).copy()
+    embedding_indices = index_mapping_indices.copy()
+    lora_indices = index_mapping_indices.copy()
+    prompt_mapping: List[int] = [
         lora_index_to_id.index(x) if x > 0 else -1
         for x in mapping.prompt_mapping
     ]
     lora_idx = None
-    for i in range(len(indices)):
+    for i in range(len(index_mapping_indices)):
         # TODO index can be slow. optimize
-        lora_idx = (lora_index_to_id.index(indices[i])
-                    if indices[i] > 0 else -1)
-        embedding_indices[i] = lora_idx if indices[i] > 0 else 0
-        indices[i] = i
+        lora_idx = (lora_index_to_id.index(index_mapping_indices[i])
+                    if index_mapping_indices[i] > 0 else -1)
+        embedding_indices[i] = lora_idx if index_mapping_indices[i] > 0 else 0
+        index_mapping_indices[i] = i
         lora_indices[i] = lora_idx

-    indices = torch.tensor([indices, lora_indices, embedding_indices],
-                           dtype=torch.long,
-                           device="cuda")
-    prompt_mapping = torch.tensor(prompt_mapping,
-                                  device="cuda",
-                                  dtype=torch.long)
+    indices = torch.tensor(
+        [index_mapping_indices, lora_indices, embedding_indices],
+        dtype=torch.long,
+        device="cuda")
+    prompt_mapping_tensor = torch.tensor(prompt_mapping,
+                                         device="cuda",
+                                         dtype=torch.long)
     embeddings_indices = torch.stack([
         indices[2] * extra_vocab_size,
         indices[2] * (vocab_size + extra_vocab_size)
     ])
     embeddings_indices[embeddings_indices == -1] = max_loras - 1
     base_indices = indices[1]
-    sampler_indices = prompt_mapping
+    sampler_indices = prompt_mapping_tensor
     sampler_indices_padded = sampler_indices.clone()
     sampler_indices_padded[sampler_indices_padded == -1] = max_loras - 1
     sampler_indices_padded = (
         torch.arange(
             0, len(sampler_indices_padded), device="cuda", dtype=torch.long) +
         (sampler_indices_padded * len(sampler_indices_padded)))
-    indices_len = (base_indices.shape[-1], sampler_indices.shape[-1],
-                   sampler_indices_padded.shape[-1],
-                   embeddings_indices.shape[-1])
+    indices_len = [
+        base_indices.shape[-1], sampler_indices.shape[-1],
+        sampler_indices_padded.shape[-1], embeddings_indices.shape[-1]
+    ]

     return (base_indices, sampler_indices, sampler_indices_padded,
             embeddings_indices, indices_len)
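Note: the `convert_mapping` hunk above mostly gives each name a single static type: the `List[int]` working buffer becomes `index_mapping_indices` instead of being overwritten by a `torch.tensor(...)` result, and the tensor built from `prompt_mapping` gets its own name. A minimal sketch of the rebinding problem, using a hypothetical `to_long_tensor` helper (CPU-only, unlike the original code):

```python
from typing import List

import torch


def to_long_tensor(prompt_mapping: List[int]) -> torch.Tensor:
    # Reusing the parameter name (`prompt_mapping = torch.tensor(...)`) would
    # rebind a List[int] name to a Tensor, which mypy flags; a distinct name
    # keeps both types stable.
    prompt_mapping_tensor = torch.tensor(prompt_mapping, dtype=torch.long)
    return prompt_mapping_tensor


print(to_long_tensor([0, 1, -1]).dtype)  # torch.int64
```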
@@ -149,6 +151,7 @@ class LoRAModel:
             if module_name not in loras:
                 lora_embeddings_tensor = None
                 if embeddings:
+                    assert embedding_modules is not None
                     embeddings_module = next(
                         (k for k in embedding_modules if k in module_name),
                         None)
@@ -171,6 +174,7 @@ class LoRAModel:
                 else:
                     loras[module_name].lora_b = tensor.to(device=device,
                                                           dtype=dtype).t()
+                    assert embedding_padding_modules is not None
                     if any(name in module_name
                            for name in embedding_padding_modules
                            ) and target_embedding_padding is not None:
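Note: the two `assert ... is not None` lines added above narrow Optional parameters on code paths that are only reached when the value was actually supplied. A small sketch of that narrowing, with a hypothetical `find_embedding_module` helper:

```python
from typing import Dict, Optional


def find_embedding_module(module_name: str,
                          embedding_modules: Optional[Dict[str, str]]
                          ) -> Optional[str]:
    # The parameter is Optional at the signature level, but this branch is
    # only taken when it was provided; the assert documents that assumption
    # and narrows the type so the iteration below type-checks.
    assert embedding_modules is not None
    return next((k for k in embedding_modules if k in module_name), None)


print(find_embedding_module("model.embed_tokens.weight",
                            {"embed_tokens": "input"}))  # embed_tokens
```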
@@ -295,11 +299,10 @@ class LoRAModelManager:
             self.max_num_batched_tokens,
             dtype=torch.long,
             device="cuda")
-        self.offsets = []
         # 4 is the number of indicies tensors defined above
         # base_indices, sampler_indices, sampler_indices_padded,
         # embeddings_indices
-        self.indices_len = [None] * 4
+        self.indices_len: List[Optional[int]] = [None] * 4

         self.model: nn.Module = model
         if hasattr(self.model, "supported_lora_modules"):
@@ -312,7 +315,7 @@ class LoRAModelManager:
         self._registered_loras: Dict[int, LoRAModel] = {}
         # Dict instead of a Set for compatibility with LRUCache.
         self._active_loras: Dict[int, None] = {}
-        self._last_mapping = None
+        self._last_mapping: Optional[LoRAMapping] = None
         self._create_lora_modules()
         self.model.lora_manager = self
@@ -370,7 +373,7 @@ class LoRAModelManager:
             return True
         return False

-    def _add_lora(self, lora: LoRAModel) -> bool:
+    def _add_lora(self, lora: LoRAModel):
         self._create_merged_loras_inplace(lora)
         self._registered_loras[lora.id] = lora

@@ -418,7 +421,7 @@ class LoRAModelManager:
     def get_lora(self, lora_id: int) -> Optional[LoRAModel]:
         return self._registered_loras.get(lora_id, None)

-    def remove_all_loras(self) -> bool:
+    def remove_all_loras(self):
         """Remove all LoRAModels from the manager."""
         self._registered_loras.clear()
         self.lora_index_to_id = [None] * self.lora_slots
@@ -467,6 +470,7 @@ class LoRAModelManager:
                 continue
             parts = module_name.split(".")
             if module_name not in self.packed_modules:
+                assert embedding_modules is not None
                 if parts[-1] in embedding_modules:
                     input_dim = (module.base_layer.org_vocab_size +
                                  self.lora_config.lora_extra_vocab_size if
@@ -500,7 +504,7 @@ class LoRAModelManager:
         else:
             parts = module_name.split(".")
             replacements = self.packed_modules_mapping[parts[-1]]
-            subloras = []
+            subloras: List[Optional["LoRALayerWeights"]] = []
             for i, r in enumerate(replacements):
                 lora = LoRALayerWeights.create_dummy_lora_weights(
                     module_name + "." + r,
@@ -538,7 +542,7 @@ class LoRAModelManager:

     def _create_merged_loras_inplace(self, lora_model: LoRAModel) -> None:
         for module_name, new_module_names in self.packed_modules.items():
-            replacement_loras = []
+            replacement_loras: List[Optional[LoRALayerWeights]] = []
             has_replacement = False
             for r in new_module_names:
                 lora = lora_model.get_lora(r)
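Note: `subloras` and `replacement_loras` above start out as empty lists, and mypy cannot infer an element type for a bare `[]`; annotating them where they are created also documents that `None` entries (missing sub-LoRAs) are allowed. A short sketch with a stand-in class:

```python
from typing import List, Optional


class LoRALayerWeightsLike:
    """Stand-in for the real weights class, for illustration only."""


# Without the annotation, mypy asks for one ("Need type annotation");
# with it, appending None is explicitly permitted.
replacement_loras: List[Optional[LoRALayerWeightsLike]] = []
replacement_loras.append(LoRALayerWeightsLike())
replacement_loras.append(None)  # a packed slot with no LoRA
print(len(replacement_loras))  # 2
```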
@@ -557,12 +561,12 @@ class LoRAModelManager:

 class LoRALRUCache(LRUCache[LoRAModel]):

-    def __init__(self, capacity: int, deactivate_lora_fn: Callable[[Hashable],
-                                                                   None]):
+    def __init__(self, capacity: int, deactivate_lora_fn: Callable[[int],
+                                                                   bool]):
         super().__init__(capacity)
         self.deactivate_lora_fn = deactivate_lora_fn

-    def _on_remove(self, key: Hashable, value: LoRAModel):
+    def _on_remove(self, key: int, value: LoRAModel):
         logger.debug(f"Removing LoRA. int id: {key}")
         self.deactivate_lora_fn(key)
         return super()._on_remove(key, value)
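Note: the `LoRALRUCache` hunk tightens the callback type from `Callable[[Hashable], None]` to `Callable[[int], bool]`, apparently to match the deactivate function that is actually passed in (LoRA ids are ints and deactivation reports success). A toy sketch of the same signature tightening — `_ToyCache` and `deactivate_lora` below are illustrative names, not vLLM's API:

```python
from typing import Callable


class _ToyCache:

    def __init__(self, deactivate_lora_fn: Callable[[int], bool]) -> None:
        self.deactivate_lora_fn = deactivate_lora_fn

    def _on_remove(self, key: int) -> None:
        # The precise signature lets mypy check this call site.
        self.deactivate_lora_fn(key)


def deactivate_lora(lora_id: int) -> bool:
    print(f"deactivating {lora_id}")
    return True


_ToyCache(deactivate_lora)._on_remove(7)
```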
@@ -1,5 +1,5 @@
 from abc import ABC, abstractmethod, abstractproperty
-from typing import Any, Dict, List, Optional, Set, Type
+from typing import Any, Dict, List, Set, Type

 import torch

@@ -37,7 +37,7 @@ class AbstractWorkerLoRAManager(ABC):
         ...

     @abstractmethod
-    def set_active_loras(self, lora_requests: List[LoRARequest],
+    def set_active_loras(self, lora_requests: Set[LoRARequest],
                          lora_mapping: LoRAMapping) -> None:
         ...

@@ -54,7 +54,7 @@ class AbstractWorkerLoRAManager(ABC):
         ...

     @abstractmethod
-    def remove_all_loras(self) -> bool:
+    def remove_all_loras(self):
         ...

     @abstractmethod
@@ -81,10 +81,11 @@ class WorkerLoRAManager(AbstractWorkerLoRAManager):
         embedding_padding_modules: List[str],
         lora_model_cls: Type[LoRAModel] = LoRAModel,
     ):
-        self._lora_manager: Optional[LoRAModelManager] = None
         self._lora_model_cls = lora_model_cls
         self.embedding_modules = embedding_modules
         self.embedding_padding_modules = embedding_padding_modules
+        # Lazily initialized by create_lora_manager.
+        self._lora_manager: LoRAModelManager
         super().__init__(max_num_seqs, max_num_batched_tokens, vocab_size,
                          lora_config, device)
@@ -104,7 +105,7 @@ class WorkerLoRAManager(AbstractWorkerLoRAManager):
             lora_config=self.lora_config,
             lora_manager_cls=self._lora_manager_cls,
         )
-        self._lora_manager: LoRAModelManager = lora_manager
+        self._lora_manager = lora_manager
         return lora_manager.model

     def set_active_loras(self, lora_requests: Set[LoRARequest],
@@ -188,7 +189,7 @@ class WorkerLoRAManager(AbstractWorkerLoRAManager):
     def remove_lora(self, lora_id: int) -> bool:
         return self._lora_manager.remove_lora(lora_id)

-    def remove_all_loras(self) -> bool:
+    def remove_all_loras(self):
         self._lora_manager.remove_all_loras()

     def list_loras(self) -> Set[int]:
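Note: several `remove_all_loras` definitions above drop the `-> bool` annotation because their bodies never return a value; declaring a non-None return type on a function that falls off the end is something mypy flags. A minimal sketch with a toy registry class (illustrative names only):

```python
from typing import Dict


class _ToyRegistry:

    def __init__(self) -> None:
        self._loras: Dict[int, object] = {}

    def remove_all_loras(self) -> None:
        # Annotated as `-> bool`, this body would be flagged by mypy for the
        # missing return; the implicit None return now matches the annotation.
        self._loras.clear()


registry = _ToyRegistry()
registry.remove_all_loras()
```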
@@ -217,10 +218,10 @@ class LRUCacheWorkerLoRAManager(WorkerLoRAManager):
             lora_config=self.lora_config,
             max_num_batched_tokens=self.max_num_batched_tokens,
         )
-        self._lora_manager: LRUCacheLoRAModelManager = lora_manager
+        self._lora_manager = lora_manager
         return lora_manager.model

-    def _apply_loras(self, lora_requests: List[LoRARequest]) -> None:
+    def _apply_loras(self, lora_requests: Set[LoRARequest]) -> None:
         loras_map = {
             lora_request.lora_int_id: lora_request
             for lora_request in lora_requests if lora_request
@@ -237,12 +238,14 @@ class LRUCacheWorkerLoRAManager(WorkerLoRAManager):
         if lora_request.lora_int_id not in self.list_loras():
             # Remove before we load the new lora to save memory
             if len(self._lora_manager) + 1 > self._lora_manager.capacity:
+                assert isinstance(self._lora_manager, LRUCacheLoRAModelManager)
                 self._lora_manager.remove_oldest_lora()
             lora = self._load_lora(lora_request)
             loaded = self._lora_manager.add_lora(lora)
         else:
             # If the lora is already loaded, just touch it to
             # update its position in the caches
-            loaded = self._lora_manager.get_lora(lora_request.lora_int_id)
+            loaded = self._lora_manager.get_lora(
+                lora_request.lora_int_id) is not None
         self._lora_manager.activate_lora(lora_request.lora_int_id)
         return loaded
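Note: in `LRUCacheWorkerLoRAManager` above, `self._lora_manager` is now declared with the base `LoRAModelManager` type, so the eviction path narrows it with `assert isinstance(...)` before calling the LRU-only `remove_oldest_lora()`, and the `get_lora(...)` result is converted to a bool with `is not None`. A toy sketch of the isinstance narrowing, under assumed class names:

```python
class _BaseManager:
    """Stand-in for an attribute typed as the base manager class."""


class _LRUManager(_BaseManager):

    def remove_oldest(self) -> None:
        print("evicted oldest entry")


def maybe_evict(manager: _BaseManager) -> None:
    # The attribute is annotated with the base type, so an isinstance assert
    # narrows it before calling the subclass-only eviction method.
    assert isinstance(manager, _LRUManager)
    manager.remove_oldest()


maybe_evict(_LRUManager())
```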
@@ -928,10 +928,10 @@ class ModelRunner:
             torch.cuda.synchronize()
         return

-    def remove_all_loras(self) -> bool:
+    def remove_all_loras(self):
         if not self.lora_manager:
             raise RuntimeError("LoRA is not enabled.")
-        return self.lora_manager.remove_all_loras()
+        self.lora_manager.remove_all_loras()

     def set_active_loras(self, lora_requests: Set[LoRARequest],
                          lora_mapping: LoRAMapping) -> None: