[Mypy] Typing lora folder (#4337)
parent f4bc4de1b1
commit b5b4a398a7

.github/workflows/mypy.yaml
@@ -33,8 +33,6 @@ jobs:
     - name: Mypy
       run: |
         mypy vllm/attention --config-file pyproject.toml
-        # TODO(sang): Fix nested dir
-        mypy vllm/core/*.py --follow-imports=skip --config-file pyproject.toml
         mypy vllm/distributed --config-file pyproject.toml
         mypy vllm/entrypoints --config-file pyproject.toml
         mypy vllm/executor --config-file pyproject.toml
@@ -44,8 +42,9 @@ jobs:
         mypy vllm/engine --config-file pyproject.toml
         mypy vllm/worker --config-file pyproject.toml
         mypy vllm/spec_decode --config-file pyproject.toml
+        mypy vllm/lora --config-file pyproject.toml
 
         # TODO(sang): Fix nested dir
         mypy vllm/model_executor/*.py --config-file pyproject.toml
         # TODO(sang): Fix nested dir
-        # mypy vllm/lora/*.py --config-file pyproject.toml
+        mypy vllm/core/*.py --follow-imports=skip --config-file pyproject.toml
@@ -106,7 +106,7 @@ mypy vllm/engine --config-file pyproject.toml
 mypy vllm/worker --config-file pyproject.toml
 mypy vllm/spec_decode --config-file pyproject.toml
 mypy vllm/model_executor/*.py --config-file pyproject.toml
-# mypy vllm/lora/*.py --config-file pyproject.toml
+mypy vllm/lora --config-file pyproject.toml
 
 
 CODESPELL_EXCLUDES=(
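The check that CI and the format script now run for the lora folder can also be reproduced from Python through mypy's public API. A minimal sketch, not part of this diff, assuming it is executed from the repository root:

from mypy import api

# Equivalent to the new line above:
#   mypy vllm/lora --config-file pyproject.toml
stdout, stderr, exit_status = api.run(
    ["vllm/lora", "--config-file", "pyproject.toml"])
print(stdout)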
@@ -176,6 +176,8 @@ class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA):
     def __init__(self, base_layer: VocabParallelEmbedding) -> None:
         super().__init__()
         self.base_layer = base_layer
+        self.embeddings_slice: Optional[Tuple[int, int]]
+        self.embeddings_weights: Optional[torch.Tensor]
 
     def create_lora_weights(
         self,
@@ -233,9 +235,10 @@ class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA):
             self.lora_a_stacked.shape[0] * self.lora_a_stacked.shape[1],
             self.lora_a_stacked.shape[2],
         )
-        self.indices: Optional[torch.Tensor] = None
-        self.indices_len: Optional[List[int]] = None
-        self.embeddings_indices = None
+        # Lazily initialized.
+        self.indices: torch.Tensor
+        self.indices_len: List[int]
+        self.embeddings_indices: torch.Tensor
 
     def reset_lora(self, index: int):
         self.lora_a_stacked[index] = 0
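The recurring change in these layer classes replaces "Optional[...] = None" placeholders with bare annotations on attributes that are only assigned later. A minimal sketch of the pattern (illustrative only; the class and method names below are made up, not vLLM code): once an attribute is declared but unassigned, mypy treats it as always having the annotated type after initialization, so use sites no longer need "is not None" narrowing. The tradeoff is a plain AttributeError rather than a None check if the attribute is read before it is set.

from typing import List, Optional

import torch


class LazyIndices:

    def __init__(self) -> None:
        # Old style: Optional forces narrowing at every use site.
        self.maybe_indices: Optional[torch.Tensor] = None
        # New style: annotated but unassigned, lazily initialized later.
        self.indices: torch.Tensor
        self.indices_len: List[int]

    def set_mapping(self, base_indices: torch.Tensor) -> None:
        self.indices = base_indices
        self.indices_len = [base_indices.shape[-1]]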
@@ -267,6 +270,7 @@ class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA):
             self.embeddings_tensors.shape[1],
             self.embeddings_tensors.shape[2]
         )[self.embeddings_slice[0]:self.embeddings_slice[1]]
+        assert self.embeddings_weights is not None
         self.embeddings_weights[:embeddings.shape[0]].copy_(embeddings)
 
     def set_mapping(
@@ -343,11 +347,12 @@ class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):
             dtype=lora_config.lora_dtype,
             device=self.device,
         )
 
-        self.indices: Optional[torch.Tensor] = None
-        self.indices_len: Optional[List[int]] = None
         self.output_dim = self.lora_b_stacked.shape[2]
 
+        # lazily initialized.
+        self.indices: torch.Tensor
+        self.indices_len: List[int]
 
     def reset_lora(self, index: int):
         self.lora_a_stacked[index] = 0
         self.lora_b_stacked[index] = 0
@@ -475,8 +480,9 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
             device=self.device,
         ) for _ in range(n_slices))
 
-        self.indices: Optional[torch.Tensor] = None
         self.output_dim = self.lora_b_stacked[0].shape[2]
+        # Lazily initialized.
+        self.indices: torch.Tensor
 
     def reset_lora(self, index: int):
         self.lora_a_stacked[0][index] = 0
@@ -690,7 +696,8 @@ class MergedQKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
                                   self.kv_proj_shard_size)
         self.packed_indices: Optional[torch.Tensor] = None
         self.standard_indices: Optional[torch.Tensor] = None
-        self.indices_len: Optional[List[int]] = None
+        # lazily initialized.
+        self.indices_len: List[int]
 
     def reset_lora(self, index: int):
         self.lora_a_stacked[0][index] = 0
@@ -814,8 +821,9 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
             dtype=lora_config.lora_dtype,
             device=self.device,
         )
-        self.indices: Optional[torch.Tensor] = None
-        self.indices_len: Optional[List[int]] = None
+        # Lazily initialized
+        self.indices: torch.Tensor
+        self.indices_len: List[int]
 
     def reset_lora(self, index: int):
         self.lora_a_stacked[index] = 0
@@ -991,9 +999,10 @@ class LogitsProcessorWithLoRA(BaseLayerWithLoRA):
             dtype=self.dtype,
             device=self.device,
         )
-        self.indices = None
-        self.indices_padded = None
-        self.indices_len = None
+        # Lazily initialized.
+        self.indices: torch.Tensor
+        self.indices_len: List[int]
+        self.indices_padded: torch.Tensor
 
     def reset_lora(self, index: int):
         self.lora_a_stacked[index] = 0
@@ -97,9 +97,9 @@ class PackedLoRALayerWeights(LoRALayerWeights):
         self,
         module_name: str,
         rank: int,
-        lora_alphas: List[int],
-        lora_a: List[torch.Tensor],
-        lora_b: List[torch.Tensor],
+        lora_alphas: List[Optional[int]],
+        lora_a: List[Optional[torch.Tensor]],
+        lora_b: List[Optional[torch.Tensor]],
         scaling: Optional[List[float]] = None,
     ) -> None:
         super().__init__(
@@ -108,17 +108,20 @@ class PackedLoRALayerWeights(LoRALayerWeights):
             lora_alpha=0,
             lora_a=lora_a,
             lora_b=lora_b,
-            scaling=scaling,
+            scaling=scaling,  # type: ignore
             embeddings_tensor=None,
         )
         self.lora_alphas = lora_alphas
         if scaling is None:
-            self.scaling = [
-                lora_alpha / self.rank for lora_alpha in self.lora_alphas
+            self.scaling = [  # type: ignore
+                lora_alpha / self.rank  # type: ignore # noqa
+                for lora_alpha in self.lora_alphas
             ]
 
     @classmethod
-    def pack(cls, loras: List["LoRALayerWeights"]) -> "PackedLoRALayerWeights":
+    def pack(
+        cls, loras: List[Optional["LoRALayerWeights"]]
+    ) -> "PackedLoRALayerWeights":
         """Pack a list of LoRAs into a single LoRA.
 
         If LoRA is None, it signifies that the submodule does not have a LoRA.
@@ -136,16 +139,19 @@ class PackedLoRALayerWeights(LoRALayerWeights):
             [lora.lora_alpha if lora is not None else None for lora in loras],
             [lora.lora_a if lora is not None else None for lora in loras],
             [lora.lora_b if lora is not None else None for lora in loras],
-            scaling=[1 if lora is not None else None for lora in loras])
+            scaling=[
+                1 if lora is not None else None  # type: ignore
+                for lora in loras
+            ])
         return obj
 
     def optimize(self) -> "PackedLoRALayerWeights":
         """Optimize the LoRA by merging the scaling into lora_b."""
         for i in range(len(self.lora_b)):
-            if self.scaling[i] == 1 or self.lora_b[i] is None:
+            if self.scaling[i] == 1 or self.lora_b[i] is None:  # type: ignore
                 continue
-            self.lora_b[i] *= self.scaling[i]
-            self.scaling[i] = 1
+            self.lora_b[i] *= self.scaling[i]  # type: ignore
+            self.scaling[i] = 1  # type: ignore
         return self
 
     @property
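These signature changes make the element types Optional because a packed module may be missing some of its sub-LoRAs; where mypy still cannot prove an element is non-None, the diff silences it with "# type: ignore". A minimal illustrative sketch (not vLLM code; the function name is made up) of the explicit-narrowing alternative:

from typing import List, Optional


def compute_scalings(lora_alphas: List[Optional[int]],
                     rank: int) -> List[Optional[float]]:
    # None marks a missing sub-LoRA; present alphas are scaled by the rank.
    return [
        alpha / rank if alpha is not None else None for alpha in lora_alphas
    ]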
@@ -3,7 +3,7 @@ import json
 import math
 import os
 import re
-from typing import Callable, Dict, Hashable, List, Optional, Tuple, Type
+from typing import Callable, Dict, List, Optional, Tuple, Type
 
 import safetensors.torch
 import torch
@@ -53,44 +53,46 @@ def convert_mapping(
             embeddings.
         indices_len: List of lengths of the above tensors.
     """
-    indices = list(mapping.index_mapping).copy()
-    embedding_indices = indices.copy()
-    lora_indices = indices.copy()
-    prompt_mapping = [
+    index_mapping_indices: List[int] = list(mapping.index_mapping).copy()
+    embedding_indices = index_mapping_indices.copy()
+    lora_indices = index_mapping_indices.copy()
+    prompt_mapping: List[int] = [
         lora_index_to_id.index(x) if x > 0 else -1
         for x in mapping.prompt_mapping
     ]
     lora_idx = None
-    for i in range(len(indices)):
+    for i in range(len(index_mapping_indices)):
         # TODO index can be slow. optimize
-        lora_idx = (lora_index_to_id.index(indices[i])
-                    if indices[i] > 0 else -1)
-        embedding_indices[i] = lora_idx if indices[i] > 0 else 0
-        indices[i] = i
+        lora_idx = (lora_index_to_id.index(index_mapping_indices[i])
+                    if index_mapping_indices[i] > 0 else -1)
+        embedding_indices[i] = lora_idx if index_mapping_indices[i] > 0 else 0
+        index_mapping_indices[i] = i
         lora_indices[i] = lora_idx
 
-    indices = torch.tensor([indices, lora_indices, embedding_indices],
-                           dtype=torch.long,
-                           device="cuda")
-    prompt_mapping = torch.tensor(prompt_mapping,
-                                  device="cuda",
-                                  dtype=torch.long)
+    indices = torch.tensor(
+        [index_mapping_indices, lora_indices, embedding_indices],
+        dtype=torch.long,
+        device="cuda")
+    prompt_mapping_tensor = torch.tensor(prompt_mapping,
+                                         device="cuda",
+                                         dtype=torch.long)
     embeddings_indices = torch.stack([
         indices[2] * extra_vocab_size,
         indices[2] * (vocab_size + extra_vocab_size)
     ])
     embeddings_indices[embeddings_indices == -1] = max_loras - 1
     base_indices = indices[1]
-    sampler_indices = prompt_mapping
+    sampler_indices = prompt_mapping_tensor
     sampler_indices_padded = sampler_indices.clone()
     sampler_indices_padded[sampler_indices_padded == -1] = max_loras - 1
     sampler_indices_padded = (
         torch.arange(
             0, len(sampler_indices_padded), device="cuda", dtype=torch.long) +
         (sampler_indices_padded * len(sampler_indices_padded)))
-    indices_len = (base_indices.shape[-1], sampler_indices.shape[-1],
-                   sampler_indices_padded.shape[-1],
-                   embeddings_indices.shape[-1])
+    indices_len = [
+        base_indices.shape[-1], sampler_indices.shape[-1],
+        sampler_indices_padded.shape[-1], embeddings_indices.shape[-1]
+    ]
 
     return (base_indices, sampler_indices, sampler_indices_padded,
             embeddings_indices, indices_len)
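The renames in convert_mapping exist because mypy infers one type per local name: the old code rebound "indices" from List[int] to torch.Tensor and "prompt_mapping" from a list to a tensor, which mypy reports as an error. A minimal illustrative sketch of the rule (not vLLM code; the function name is made up):

from typing import List

import torch


def to_long_tensor(prompt_mapping: List[int]) -> torch.Tensor:
    # Re-using the name "prompt_mapping" for the tensor would change the
    # variable's type mid-function; a distinct name keeps both types precise.
    prompt_mapping_tensor = torch.tensor(prompt_mapping, dtype=torch.long)
    return prompt_mapping_tensor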
@@ -149,6 +151,7 @@ class LoRAModel:
             if module_name not in loras:
                 lora_embeddings_tensor = None
                 if embeddings:
+                    assert embedding_modules is not None
                     embeddings_module = next(
                         (k for k in embedding_modules if k in module_name),
                         None)
@@ -171,6 +174,7 @@ class LoRAModel:
                 else:
                     loras[module_name].lora_b = tensor.to(device=device,
                                                           dtype=dtype).t()
+                    assert embedding_padding_modules is not None
                     if any(name in module_name
                            for name in embedding_padding_modules
                            ) and target_embedding_padding is not None:
@@ -295,11 +299,10 @@ class LoRAModelManager:
                                          self.max_num_batched_tokens,
                                          dtype=torch.long,
                                          device="cuda")
-        self.offsets = []
         # 4 is the number of indicies tensors defined above
         # base_indices, sampler_indices, sampler_indices_padded,
         # embeddings_indices
-        self.indices_len = [None] * 4
+        self.indices_len: List[Optional[int]] = [None] * 4
 
         self.model: nn.Module = model
         if hasattr(self.model, "supported_lora_modules"):
@@ -312,7 +315,7 @@ class LoRAModelManager:
         self._registered_loras: Dict[int, LoRAModel] = {}
         # Dict instead of a Set for compatibility with LRUCache.
         self._active_loras: Dict[int, None] = {}
-        self._last_mapping = None
+        self._last_mapping: Optional[LoRAMapping] = None
         self._create_lora_modules()
         self.model.lora_manager = self
 
@@ -370,7 +373,7 @@ class LoRAModelManager:
             return True
         return False
 
-    def _add_lora(self, lora: LoRAModel) -> bool:
+    def _add_lora(self, lora: LoRAModel):
         self._create_merged_loras_inplace(lora)
         self._registered_loras[lora.id] = lora
 
@@ -418,7 +421,7 @@ class LoRAModelManager:
     def get_lora(self, lora_id: int) -> Optional[LoRAModel]:
         return self._registered_loras.get(lora_id, None)
 
-    def remove_all_loras(self) -> bool:
+    def remove_all_loras(self):
         """Remove all LoRAModels from the manager."""
         self._registered_loras.clear()
         self.lora_index_to_id = [None] * self.lora_slots
@@ -467,6 +470,7 @@ class LoRAModelManager:
                 continue
             parts = module_name.split(".")
             if module_name not in self.packed_modules:
+                assert embedding_modules is not None
                 if parts[-1] in embedding_modules:
                     input_dim = (module.base_layer.org_vocab_size +
                                  self.lora_config.lora_extra_vocab_size if
@@ -500,7 +504,7 @@ class LoRAModelManager:
         else:
             parts = module_name.split(".")
             replacements = self.packed_modules_mapping[parts[-1]]
-            subloras = []
+            subloras: List[Optional["LoRALayerWeights"]] = []
             for i, r in enumerate(replacements):
                 lora = LoRALayerWeights.create_dummy_lora_weights(
                     module_name + "." + r,
@@ -538,7 +542,7 @@ class LoRAModelManager:
 
     def _create_merged_loras_inplace(self, lora_model: LoRAModel) -> None:
         for module_name, new_module_names in self.packed_modules.items():
-            replacement_loras = []
+            replacement_loras: List[Optional[LoRALayerWeights]] = []
            has_replacement = False
             for r in new_module_names:
                 lora = lora_model.get_lora(r)
@@ -557,12 +561,12 @@ class LoRAModelManager:
 
 class LoRALRUCache(LRUCache[LoRAModel]):
 
-    def __init__(self, capacity: int, deactivate_lora_fn: Callable[[Hashable],
-                                                                    None]):
+    def __init__(self, capacity: int, deactivate_lora_fn: Callable[[int],
+                                                                    bool]):
         super().__init__(capacity)
         self.deactivate_lora_fn = deactivate_lora_fn
 
-    def _on_remove(self, key: Hashable, value: LoRAModel):
+    def _on_remove(self, key: int, value: LoRAModel):
         logger.debug(f"Removing LoRA. int id: {key}")
         self.deactivate_lora_fn(key)
         return super()._on_remove(key, value)
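Narrowing the callback type from Callable[[Hashable], None] to Callable[[int], bool] matches how the cache actually invokes it, so mypy can now check both the key type and the function passed in. A minimal illustrative sketch (not vLLM code; the class name is made up):

from typing import Callable, Dict


class TinyLRUCache:

    def __init__(self, deactivate_fn: Callable[[int], bool]) -> None:
        self._items: Dict[int, object] = {}
        self._deactivate_fn = deactivate_fn

    def remove(self, key: int) -> None:
        if key in self._items:
            # mypy verifies that the callback accepts an int key here.
            self._deactivate_fn(key)
            del self._items[key]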
@@ -1,5 +1,5 @@
 from abc import ABC, abstractmethod, abstractproperty
-from typing import Any, Dict, List, Optional, Set, Type
+from typing import Any, Dict, List, Set, Type
 
 import torch
 
@@ -37,7 +37,7 @@ class AbstractWorkerLoRAManager(ABC):
         ...
 
     @abstractmethod
-    def set_active_loras(self, lora_requests: List[LoRARequest],
+    def set_active_loras(self, lora_requests: Set[LoRARequest],
                          lora_mapping: LoRAMapping) -> None:
         ...
 
@@ -54,7 +54,7 @@ class AbstractWorkerLoRAManager(ABC):
         ...
 
     @abstractmethod
-    def remove_all_loras(self) -> bool:
+    def remove_all_loras(self):
         ...
 
     @abstractmethod
@@ -81,10 +81,11 @@ class WorkerLoRAManager(AbstractWorkerLoRAManager):
         embedding_padding_modules: List[str],
         lora_model_cls: Type[LoRAModel] = LoRAModel,
     ):
-        self._lora_manager: Optional[LoRAModelManager] = None
         self._lora_model_cls = lora_model_cls
         self.embedding_modules = embedding_modules
         self.embedding_padding_modules = embedding_padding_modules
+        # Lazily initialized by create_lora_manager.
+        self._lora_manager: LoRAModelManager
         super().__init__(max_num_seqs, max_num_batched_tokens, vocab_size,
                          lora_config, device)
 
@@ -104,7 +105,7 @@ class WorkerLoRAManager(AbstractWorkerLoRAManager):
             lora_config=self.lora_config,
             lora_manager_cls=self._lora_manager_cls,
         )
-        self._lora_manager: LoRAModelManager = lora_manager
+        self._lora_manager = lora_manager
         return lora_manager.model
 
     def set_active_loras(self, lora_requests: Set[LoRARequest],
@@ -188,7 +189,7 @@ class WorkerLoRAManager(AbstractWorkerLoRAManager):
     def remove_lora(self, lora_id: int) -> bool:
         return self._lora_manager.remove_lora(lora_id)
 
-    def remove_all_loras(self) -> bool:
+    def remove_all_loras(self):
         self._lora_manager.remove_all_loras()
 
     def list_loras(self) -> Set[int]:
@@ -217,10 +218,10 @@ class LRUCacheWorkerLoRAManager(WorkerLoRAManager):
             lora_config=self.lora_config,
             max_num_batched_tokens=self.max_num_batched_tokens,
         )
-        self._lora_manager: LRUCacheLoRAModelManager = lora_manager
+        self._lora_manager = lora_manager
         return lora_manager.model
 
-    def _apply_loras(self, lora_requests: List[LoRARequest]) -> None:
+    def _apply_loras(self, lora_requests: Set[LoRARequest]) -> None:
         loras_map = {
             lora_request.lora_int_id: lora_request
             for lora_request in lora_requests if lora_request
@@ -237,12 +238,14 @@ class LRUCacheWorkerLoRAManager(WorkerLoRAManager):
         if lora_request.lora_int_id not in self.list_loras():
             # Remove before we load the new lora to save memory
             if len(self._lora_manager) + 1 > self._lora_manager.capacity:
+                assert isinstance(self._lora_manager, LRUCacheLoRAModelManager)
                 self._lora_manager.remove_oldest_lora()
             lora = self._load_lora(lora_request)
             loaded = self._lora_manager.add_lora(lora)
         else:
             # If the lora is already loaded, just touch it to
             # update its position in the caches
-            loaded = self._lora_manager.get_lora(lora_request.lora_int_id)
+            loaded = self._lora_manager.get_lora(
+                lora_request.lora_int_id) is not None
         self._lora_manager.activate_lora(lora_request.lora_int_id)
         return loaded
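Comparing the cache lookup against None keeps "loaded" a plain bool on both branches instead of an Optional model object on one of them. A minimal illustrative sketch (not vLLM code; the function name is made up):

from typing import Dict, Optional


def is_loaded(registry: Dict[int, object], lora_id: int) -> bool:
    maybe_lora: Optional[object] = registry.get(lora_id)
    loaded = maybe_lora is not None  # bool on both code paths
    return loaded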
@@ -928,10 +928,10 @@ class ModelRunner:
         torch.cuda.synchronize()
         return
 
-    def remove_all_loras(self) -> bool:
+    def remove_all_loras(self):
         if not self.lora_manager:
             raise RuntimeError("LoRA is not enabled.")
-        return self.lora_manager.remove_all_loras()
+        self.lora_manager.remove_all_loras()
 
     def set_active_loras(self, lora_requests: Set[LoRARequest],
                          lora_mapping: LoRAMapping) -> None: