
from dataclasses import dataclass
from typing import Optional

import torch
from torch import nn

from vllm.adapter_commons.layers import AdapterMapping
from vllm.config import PromptAdapterConfig
from vllm.model_executor.layers.vocab_parallel_embedding import (
    VocabParallelEmbedding)


@dataclass
class PromptAdapterMapping(AdapterMapping):
    pass


class VocabParallelEmbeddingWithPromptAdapter(nn.Module):
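    """VocabParallelEmbedding wrapper that adds prompt-adapter (soft prompt)
    support: adapter embedding tables live in a preallocated buffer and, at
    forward time, overwrite the base embeddings at the token positions
    selected via set_mapping().
    """
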
    def __init__(self, base_layer: VocabParallelEmbedding) -> None:
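        """Wrap `base_layer`, unwrapping a LoRA wrapper if present so that
        dtype, device, and embedding_dim are read from the real embedding."""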
        super().__init__()
        self.base_layer = base_layer
        self.emb_layer = self.base_layer
        if 'LoRA' in base_layer.__class__.__name__:
            self.emb_layer = self.base_layer.base_layer

    def create_prompt_adapter_weights(
            self, prompt_adapter_config: PromptAdapterConfig):
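        """Preallocate zeroed storage for up to `max_prompt_adapters` adapters,
        each holding at most `max_prompt_adapter_token` token embeddings."""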
        self.embeddings_tensors = torch.zeros(
            (
                prompt_adapter_config.max_prompt_adapters,
                prompt_adapter_config.max_prompt_adapter_token,
                self.emb_layer.embedding_dim,
            ),
            dtype=self.emb_layer.weight.dtype,
            device=self.emb_layer.weight.device,
        )
        self.adapter_lengths = torch.zeros(
            prompt_adapter_config.max_prompt_adapters,
            dtype=torch.long,
            device=self.emb_layer.weight.device)

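        # Batch mapping tensors; populated each step via set_mapping().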
        self.indices_gpu: torch.Tensor
        self.embedding_indices_gpu: torch.Tensor

    def reset_prompt_adapter(self, index: int):
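        """Zero out the adapter embedding slot at `index`."""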
        self.embeddings_tensors[index] = 0

    def set_prompt_adapter(
        self,
        index: int,
        adapter_model: Optional[torch.Tensor],
    ):
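        """Load `adapter_model` ([num_virtual_tokens, embedding_dim]) into
        slot `index`; passing None just clears the slot."""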
        self.reset_prompt_adapter(index)
        if adapter_model is not None:
            length = adapter_model.shape[0]
            self.embeddings_tensors[index, :length] = adapter_model
            self.adapter_lengths[index] = length

    def set_mapping(
        self,
        prompt_indices: torch.Tensor,
        prompt_embedding_indices: torch.Tensor,
    ):
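        """Move the per-token prompt indices and the (adapter, token)
        embedding indices for the current batch onto the embedding device."""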
        self.indices_gpu = prompt_indices.to(
            device=self.emb_layer.weight.device)
        self.embedding_indices_gpu = prompt_embedding_indices.to(
            device=self.emb_layer.weight.device)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
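        """Embed `x` with the base layer, then overwrite positions mapped to
        a prompt adapter (index != -1) with the stored adapter embeddings."""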
        hidden_states = self.base_layer(x)
        if self.embedding_indices_gpu.ndim > 1:
            valid_mask = self.indices_gpu != -1
            gathered_embeddings = self.embeddings_tensors[
                self.embedding_indices_gpu[:, 0],
                self.embedding_indices_gpu[:, 1]]

            # Replace the base embeddings at adapter positions with the
            # gathered soft-prompt rows.
            hidden_states[valid_mask] = gathered_embeddings
        return hidden_states
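
# Illustrative usage sketch (not part of the upstream module): a rough outline
# of how the engine might drive this layer. `base_emb`, `config`,
# `soft_prompt`, `prompt_indices`, `prompt_embedding_indices`, and `token_ids`
# are hypothetical placeholders rather than names defined in vLLM.
#
#   layer = VocabParallelEmbeddingWithPromptAdapter(base_emb)
#   layer.create_prompt_adapter_weights(config)
#   layer.set_prompt_adapter(0, soft_prompt)  # [num_virtual_tokens, emb_dim]
#   layer.set_mapping(prompt_indices, prompt_embedding_indices)
#   hidden_states = layer(token_ids)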