[Mypy] Typing lora folder (#4337)

This commit is contained in:
SangBin Cho 2024-04-26 04:13:50 +09:00 committed by GitHub
parent f4bc4de1b1
commit b5b4a398a7
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 91 additions and 70 deletions

View File

@ -33,8 +33,6 @@ jobs:
- name: Mypy
run: |
mypy vllm/attention --config-file pyproject.toml
# TODO(sang): Fix nested dir
mypy vllm/core/*.py --follow-imports=skip --config-file pyproject.toml
mypy vllm/distributed --config-file pyproject.toml
mypy vllm/entrypoints --config-file pyproject.toml
mypy vllm/executor --config-file pyproject.toml
@ -44,8 +42,9 @@ jobs:
mypy vllm/engine --config-file pyproject.toml
mypy vllm/worker --config-file pyproject.toml
mypy vllm/spec_decode --config-file pyproject.toml
mypy vllm/lora --config-file pyproject.toml
# TODO(sang): Fix nested dir
mypy vllm/model_executor/*.py --config-file pyproject.toml
# TODO(sang): Fix nested dir
# mypy vllm/lora/*.py --config-file pyproject.toml
mypy vllm/core/*.py --follow-imports=skip --config-file pyproject.toml

View File

@ -106,7 +106,7 @@ mypy vllm/engine --config-file pyproject.toml
mypy vllm/worker --config-file pyproject.toml
mypy vllm/spec_decode --config-file pyproject.toml
mypy vllm/model_executor/*.py --config-file pyproject.toml
# mypy vllm/lora/*.py --config-file pyproject.toml
mypy vllm/lora --config-file pyproject.toml
CODESPELL_EXCLUDES=(

View File

@ -176,6 +176,8 @@ class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA):
def __init__(self, base_layer: VocabParallelEmbedding) -> None:
super().__init__()
self.base_layer = base_layer
self.embeddings_slice: Optional[Tuple[int, int]]
self.embeddings_weights: Optional[torch.Tensor]
def create_lora_weights(
self,
@ -233,9 +235,10 @@ class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA):
self.lora_a_stacked.shape[0] * self.lora_a_stacked.shape[1],
self.lora_a_stacked.shape[2],
)
self.indices: Optional[torch.Tensor] = None
self.indices_len: Optional[List[int]] = None
self.embeddings_indices = None
# Lazily initialized.
self.indices: torch.Tensor
self.indices_len: List[int]
self.embeddings_indices: torch.Tensor
def reset_lora(self, index: int):
self.lora_a_stacked[index] = 0
@ -267,6 +270,7 @@ class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA):
self.embeddings_tensors.shape[1],
self.embeddings_tensors.shape[2]
)[self.embeddings_slice[0]:self.embeddings_slice[1]]
assert self.embeddings_weights is not None
self.embeddings_weights[:embeddings.shape[0]].copy_(embeddings)
def set_mapping(
@ -343,11 +347,12 @@ class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):
dtype=lora_config.lora_dtype,
device=self.device,
)
self.indices: Optional[torch.Tensor] = None
self.indices_len: Optional[List[int]] = None
self.output_dim = self.lora_b_stacked.shape[2]
# lazily initialized.
self.indices: torch.Tensor
self.indices_len: List[int]
def reset_lora(self, index: int):
self.lora_a_stacked[index] = 0
self.lora_b_stacked[index] = 0
@ -475,8 +480,9 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
device=self.device,
) for _ in range(n_slices))
self.indices: Optional[torch.Tensor] = None
self.output_dim = self.lora_b_stacked[0].shape[2]
# Lazily initialized.
self.indices: torch.Tensor
def reset_lora(self, index: int):
self.lora_a_stacked[0][index] = 0
@ -690,7 +696,8 @@ class MergedQKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
self.kv_proj_shard_size)
self.packed_indices: Optional[torch.Tensor] = None
self.standard_indices: Optional[torch.Tensor] = None
self.indices_len: Optional[List[int]] = None
# lazily initialized.
self.indices_len: List[int]
def reset_lora(self, index: int):
self.lora_a_stacked[0][index] = 0
@ -814,8 +821,9 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
dtype=lora_config.lora_dtype,
device=self.device,
)
self.indices: Optional[torch.Tensor] = None
self.indices_len: Optional[List[int]] = None
# Lazily initialized
self.indices: torch.Tensor
self.indices_len: List[int]
def reset_lora(self, index: int):
self.lora_a_stacked[index] = 0
@ -991,9 +999,10 @@ class LogitsProcessorWithLoRA(BaseLayerWithLoRA):
dtype=self.dtype,
device=self.device,
)
self.indices = None
self.indices_padded = None
self.indices_len = None
# Lazily initialized.
self.indices: torch.Tensor
self.indices_len: List[int]
self.indices_padded: torch.Tensor
def reset_lora(self, index: int):
self.lora_a_stacked[index] = 0

View File

@ -97,9 +97,9 @@ class PackedLoRALayerWeights(LoRALayerWeights):
self,
module_name: str,
rank: int,
lora_alphas: List[int],
lora_a: List[torch.Tensor],
lora_b: List[torch.Tensor],
lora_alphas: List[Optional[int]],
lora_a: List[Optional[torch.Tensor]],
lora_b: List[Optional[torch.Tensor]],
scaling: Optional[List[float]] = None,
) -> None:
super().__init__(
@ -108,17 +108,20 @@ class PackedLoRALayerWeights(LoRALayerWeights):
lora_alpha=0,
lora_a=lora_a,
lora_b=lora_b,
scaling=scaling,
scaling=scaling, # type: ignore
embeddings_tensor=None,
)
self.lora_alphas = lora_alphas
if scaling is None:
self.scaling = [
lora_alpha / self.rank for lora_alpha in self.lora_alphas
self.scaling = [ # type: ignore
lora_alpha / self.rank # type: ignore # noqa
for lora_alpha in self.lora_alphas
]
@classmethod
def pack(cls, loras: List["LoRALayerWeights"]) -> "PackedLoRALayerWeights":
def pack(
cls, loras: List[Optional["LoRALayerWeights"]]
) -> "PackedLoRALayerWeights":
"""Pack a list of LoRAs into a single LoRA.
If LoRA is None, it signifies that the submodule does not have a LoRA.
@ -136,16 +139,19 @@ class PackedLoRALayerWeights(LoRALayerWeights):
[lora.lora_alpha if lora is not None else None for lora in loras],
[lora.lora_a if lora is not None else None for lora in loras],
[lora.lora_b if lora is not None else None for lora in loras],
scaling=[1 if lora is not None else None for lora in loras])
scaling=[
1 if lora is not None else None # type: ignore
for lora in loras
])
return obj
def optimize(self) -> "PackedLoRALayerWeights":
"""Optimize the LoRA by merging the scaling into lora_b."""
for i in range(len(self.lora_b)):
if self.scaling[i] == 1 or self.lora_b[i] is None:
if self.scaling[i] == 1 or self.lora_b[i] is None: # type: ignore
continue
self.lora_b[i] *= self.scaling[i]
self.scaling[i] = 1
self.lora_b[i] *= self.scaling[i] # type: ignore
self.scaling[i] = 1 # type: ignore
return self
@property

View File

@ -3,7 +3,7 @@ import json
import math
import os
import re
from typing import Callable, Dict, Hashable, List, Optional, Tuple, Type
from typing import Callable, Dict, List, Optional, Tuple, Type
import safetensors.torch
import torch
@ -53,44 +53,46 @@ def convert_mapping(
embeddings.
indices_len: List of lengths of the above tensors.
"""
indices = list(mapping.index_mapping).copy()
embedding_indices = indices.copy()
lora_indices = indices.copy()
prompt_mapping = [
index_mapping_indices: List[int] = list(mapping.index_mapping).copy()
embedding_indices = index_mapping_indices.copy()
lora_indices = index_mapping_indices.copy()
prompt_mapping: List[int] = [
lora_index_to_id.index(x) if x > 0 else -1
for x in mapping.prompt_mapping
]
lora_idx = None
for i in range(len(indices)):
for i in range(len(index_mapping_indices)):
# TODO index can be slow. optimize
lora_idx = (lora_index_to_id.index(indices[i])
if indices[i] > 0 else -1)
embedding_indices[i] = lora_idx if indices[i] > 0 else 0
indices[i] = i
lora_idx = (lora_index_to_id.index(index_mapping_indices[i])
if index_mapping_indices[i] > 0 else -1)
embedding_indices[i] = lora_idx if index_mapping_indices[i] > 0 else 0
index_mapping_indices[i] = i
lora_indices[i] = lora_idx
indices = torch.tensor([indices, lora_indices, embedding_indices],
dtype=torch.long,
device="cuda")
prompt_mapping = torch.tensor(prompt_mapping,
device="cuda",
dtype=torch.long)
indices = torch.tensor(
[index_mapping_indices, lora_indices, embedding_indices],
dtype=torch.long,
device="cuda")
prompt_mapping_tensor = torch.tensor(prompt_mapping,
device="cuda",
dtype=torch.long)
embeddings_indices = torch.stack([
indices[2] * extra_vocab_size,
indices[2] * (vocab_size + extra_vocab_size)
])
embeddings_indices[embeddings_indices == -1] = max_loras - 1
base_indices = indices[1]
sampler_indices = prompt_mapping
sampler_indices = prompt_mapping_tensor
sampler_indices_padded = sampler_indices.clone()
sampler_indices_padded[sampler_indices_padded == -1] = max_loras - 1
sampler_indices_padded = (
torch.arange(
0, len(sampler_indices_padded), device="cuda", dtype=torch.long) +
(sampler_indices_padded * len(sampler_indices_padded)))
indices_len = (base_indices.shape[-1], sampler_indices.shape[-1],
sampler_indices_padded.shape[-1],
embeddings_indices.shape[-1])
indices_len = [
base_indices.shape[-1], sampler_indices.shape[-1],
sampler_indices_padded.shape[-1], embeddings_indices.shape[-1]
]
return (base_indices, sampler_indices, sampler_indices_padded,
embeddings_indices, indices_len)
@ -149,6 +151,7 @@ class LoRAModel:
if module_name not in loras:
lora_embeddings_tensor = None
if embeddings:
assert embedding_modules is not None
embeddings_module = next(
(k for k in embedding_modules if k in module_name),
None)
@ -171,6 +174,7 @@ class LoRAModel:
else:
loras[module_name].lora_b = tensor.to(device=device,
dtype=dtype).t()
assert embedding_padding_modules is not None
if any(name in module_name
for name in embedding_padding_modules
) and target_embedding_padding is not None:
@ -295,11 +299,10 @@ class LoRAModelManager:
self.max_num_batched_tokens,
dtype=torch.long,
device="cuda")
self.offsets = []
# 4 is the number of indicies tensors defined above
# base_indices, sampler_indices, sampler_indices_padded,
# embeddings_indices
self.indices_len = [None] * 4
self.indices_len: List[Optional[int]] = [None] * 4
self.model: nn.Module = model
if hasattr(self.model, "supported_lora_modules"):
@ -312,7 +315,7 @@ class LoRAModelManager:
self._registered_loras: Dict[int, LoRAModel] = {}
# Dict instead of a Set for compatibility with LRUCache.
self._active_loras: Dict[int, None] = {}
self._last_mapping = None
self._last_mapping: Optional[LoRAMapping] = None
self._create_lora_modules()
self.model.lora_manager = self
@ -370,7 +373,7 @@ class LoRAModelManager:
return True
return False
def _add_lora(self, lora: LoRAModel) -> bool:
def _add_lora(self, lora: LoRAModel):
self._create_merged_loras_inplace(lora)
self._registered_loras[lora.id] = lora
@ -418,7 +421,7 @@ class LoRAModelManager:
def get_lora(self, lora_id: int) -> Optional[LoRAModel]:
return self._registered_loras.get(lora_id, None)
def remove_all_loras(self) -> bool:
def remove_all_loras(self):
"""Remove all LoRAModels from the manager."""
self._registered_loras.clear()
self.lora_index_to_id = [None] * self.lora_slots
@ -467,6 +470,7 @@ class LoRAModelManager:
continue
parts = module_name.split(".")
if module_name not in self.packed_modules:
assert embedding_modules is not None
if parts[-1] in embedding_modules:
input_dim = (module.base_layer.org_vocab_size +
self.lora_config.lora_extra_vocab_size if
@ -500,7 +504,7 @@ class LoRAModelManager:
else:
parts = module_name.split(".")
replacements = self.packed_modules_mapping[parts[-1]]
subloras = []
subloras: List[Optional["LoRALayerWeights"]] = []
for i, r in enumerate(replacements):
lora = LoRALayerWeights.create_dummy_lora_weights(
module_name + "." + r,
@ -538,7 +542,7 @@ class LoRAModelManager:
def _create_merged_loras_inplace(self, lora_model: LoRAModel) -> None:
for module_name, new_module_names in self.packed_modules.items():
replacement_loras = []
replacement_loras: List[Optional[LoRALayerWeights]] = []
has_replacement = False
for r in new_module_names:
lora = lora_model.get_lora(r)
@ -557,12 +561,12 @@ class LoRAModelManager:
class LoRALRUCache(LRUCache[LoRAModel]):
def __init__(self, capacity: int, deactivate_lora_fn: Callable[[Hashable],
None]):
def __init__(self, capacity: int, deactivate_lora_fn: Callable[[int],
bool]):
super().__init__(capacity)
self.deactivate_lora_fn = deactivate_lora_fn
def _on_remove(self, key: Hashable, value: LoRAModel):
def _on_remove(self, key: int, value: LoRAModel):
logger.debug(f"Removing LoRA. int id: {key}")
self.deactivate_lora_fn(key)
return super()._on_remove(key, value)

View File

@ -1,5 +1,5 @@
from abc import ABC, abstractmethod, abstractproperty
from typing import Any, Dict, List, Optional, Set, Type
from typing import Any, Dict, List, Set, Type
import torch
@ -37,7 +37,7 @@ class AbstractWorkerLoRAManager(ABC):
...
@abstractmethod
def set_active_loras(self, lora_requests: List[LoRARequest],
def set_active_loras(self, lora_requests: Set[LoRARequest],
lora_mapping: LoRAMapping) -> None:
...
@ -54,7 +54,7 @@ class AbstractWorkerLoRAManager(ABC):
...
@abstractmethod
def remove_all_loras(self) -> bool:
def remove_all_loras(self):
...
@abstractmethod
@ -81,10 +81,11 @@ class WorkerLoRAManager(AbstractWorkerLoRAManager):
embedding_padding_modules: List[str],
lora_model_cls: Type[LoRAModel] = LoRAModel,
):
self._lora_manager: Optional[LoRAModelManager] = None
self._lora_model_cls = lora_model_cls
self.embedding_modules = embedding_modules
self.embedding_padding_modules = embedding_padding_modules
# Lazily initialized by create_lora_manager.
self._lora_manager: LoRAModelManager
super().__init__(max_num_seqs, max_num_batched_tokens, vocab_size,
lora_config, device)
@ -104,7 +105,7 @@ class WorkerLoRAManager(AbstractWorkerLoRAManager):
lora_config=self.lora_config,
lora_manager_cls=self._lora_manager_cls,
)
self._lora_manager: LoRAModelManager = lora_manager
self._lora_manager = lora_manager
return lora_manager.model
def set_active_loras(self, lora_requests: Set[LoRARequest],
@ -188,7 +189,7 @@ class WorkerLoRAManager(AbstractWorkerLoRAManager):
def remove_lora(self, lora_id: int) -> bool:
return self._lora_manager.remove_lora(lora_id)
def remove_all_loras(self) -> bool:
def remove_all_loras(self):
self._lora_manager.remove_all_loras()
def list_loras(self) -> Set[int]:
@ -217,10 +218,10 @@ class LRUCacheWorkerLoRAManager(WorkerLoRAManager):
lora_config=self.lora_config,
max_num_batched_tokens=self.max_num_batched_tokens,
)
self._lora_manager: LRUCacheLoRAModelManager = lora_manager
self._lora_manager = lora_manager
return lora_manager.model
def _apply_loras(self, lora_requests: List[LoRARequest]) -> None:
def _apply_loras(self, lora_requests: Set[LoRARequest]) -> None:
loras_map = {
lora_request.lora_int_id: lora_request
for lora_request in lora_requests if lora_request
@ -237,12 +238,14 @@ class LRUCacheWorkerLoRAManager(WorkerLoRAManager):
if lora_request.lora_int_id not in self.list_loras():
# Remove before we load the new lora to save memory
if len(self._lora_manager) + 1 > self._lora_manager.capacity:
assert isinstance(self._lora_manager, LRUCacheLoRAModelManager)
self._lora_manager.remove_oldest_lora()
lora = self._load_lora(lora_request)
loaded = self._lora_manager.add_lora(lora)
else:
# If the lora is already loaded, just touch it to
# update its position in the caches
loaded = self._lora_manager.get_lora(lora_request.lora_int_id)
loaded = self._lora_manager.get_lora(
lora_request.lora_int_id) is not None
self._lora_manager.activate_lora(lora_request.lora_int_id)
return loaded

View File

@ -928,10 +928,10 @@ class ModelRunner:
torch.cuda.synchronize()
return
def remove_all_loras(self) -> bool:
def remove_all_loras(self):
if not self.lora_manager:
raise RuntimeError("LoRA is not enabled.")
return self.lora_manager.remove_all_loras()
self.lora_manager.remove_all_loras()
def set_active_loras(self, lora_requests: Set[LoRARequest],
lora_mapping: LoRAMapping) -> None: