Swapnil Parekh 4d6ada947c
[CORE] Adding support for insertion of soft-tuned prompts (#4645)
Co-authored-by: Swapnil Parekh <swapnilp@ibm.com>
Co-authored-by: Joe G <joseph.granados@h2o.ai>
Co-authored-by: Antoni Baum <antoni.baum@protonmail.com>
2024-07-09 13:26:36 -07:00

80 lines
2.6 KiB
Python

from dataclasses import dataclass
from typing import Optional
import torch
from torch import nn
from vllm.adapter_commons.layers import AdapterMapping
from vllm.config import PromptAdapterConfig
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding)
@dataclass
class PromptAdapterMapping(AdapterMapping):
    """Prompt-adapter flavour of ``AdapterMapping``.

    Declares no fields or methods of its own; everything is inherited from
    ``AdapterMapping``. It exists so prompt-adapter code has a distinct
    mapping type to refer to.
    """
class VocabParallelEmbeddingWithPromptAdapter(nn.Module):
    """Embedding wrapper that overwrites selected rows with soft-prompt rows.

    The wrapped ``base_layer`` produces the regular token embeddings. On top
    of that, this module keeps a bank of prompt-adapter embeddings (one slot
    per adapter) and, in ``forward``, replaces the rows selected by the
    current mapping with rows gathered from that bank.

    Typical call order, as implied by the attributes each method touches:
    ``create_prompt_adapter_weights`` -> ``set_prompt_adapter`` ->
    ``set_mapping`` -> ``forward``.
    """

    def __init__(self, base_layer: VocabParallelEmbedding) -> None:
        """Wrap ``base_layer``.

        Args:
            base_layer: Layer that is called to produce the regular
                embeddings. If its class name contains ``'LoRA'`` it is
                treated as a wrapper and ``base_layer.base_layer`` is used
                as the source of ``weight`` / ``embedding_dim`` metadata.
        """
        super().__init__()
        self.base_layer = base_layer
        # ``emb_layer`` is only used for dtype/device/embedding_dim lookups;
        # the forward pass still goes through ``base_layer``.
        self.emb_layer = self.base_layer
        if 'LoRA' in base_layer.__class__.__name__:
            self.emb_layer = self.base_layer.base_layer

    def create_prompt_adapter_weights(
            self, prompt_adapter_config: PromptAdapterConfig) -> None:
        """Allocate the zero-initialised adapter bank and per-slot lengths.

        Args:
            prompt_adapter_config: Supplies ``max_prompt_adapters`` (number
                of slots) and ``max_prompt_adapter_token`` (rows per slot).
        """
        # Shape: (slot, token position, embedding_dim), matching the dtype
        # and device of the underlying embedding weight.
        self.embeddings_tensors = torch.zeros(
            (
                prompt_adapter_config.max_prompt_adapters,
                prompt_adapter_config.max_prompt_adapter_token,
                self.emb_layer.embedding_dim,
            ),
            dtype=self.emb_layer.weight.dtype,
            device=self.emb_layer.weight.device,
        )
        # Number of rows written into each slot by ``set_prompt_adapter``.
        self.adapter_lengths = torch.zeros(
            prompt_adapter_config.max_prompt_adapters,
            dtype=torch.long,
            device=self.emb_layer.weight.device)

        # Declared here for type checkers only; both are assigned by
        # ``set_mapping`` and do not exist as attributes until then.
        self.indices_gpu: torch.Tensor
        self.embedding_indices_gpu: torch.Tensor

    def reset_prompt_adapter(self, index: int) -> None:
        """Clear slot ``index``: zero its embeddings and its stored length."""
        self.embeddings_tensors[index] = 0
        # Keep the length in sync with the (now empty) slot; otherwise the
        # previous adapter's token count would survive the reset.
        self.adapter_lengths[index] = 0

    def set_prompt_adapter(
        self,
        index: int,
        adapter_model: Optional[torch.Tensor],
    ) -> None:
        """Load ``adapter_model`` into slot ``index``, or just clear the slot.

        Args:
            index: Slot in the adapter bank to overwrite.
            adapter_model: Tensor whose first dimension is the number of
                rows to store; ``None`` leaves the slot empty. Presumably
                shaped ``(num_tokens, embedding_dim)`` -- confirm with the
                caller.
        """
        # Always start from a clean slot so shorter adapters do not inherit
        # trailing rows from whatever was stored before.
        self.reset_prompt_adapter(index)
        if adapter_model is not None:
            length = adapter_model.shape[0]
            self.embeddings_tensors[index, :length] = adapter_model
            self.adapter_lengths[index] = length

    def set_mapping(
        self,
        prompt_indices: torch.Tensor,
        prompt_embedding_indices: torch.Tensor,
    ) -> None:
        """Store the mapping consumed by the next ``forward`` call(s).

        Both tensors are moved to the device of the embedding weight.

        Args:
            prompt_indices: One entry per row of the forward output; rows
                whose entry is ``-1`` are left untouched by ``forward``.
            prompt_embedding_indices: When it has more than one dimension,
                column 0 selects the slot and column 1 the row within that
                slot of the adapter bank.
        """
        self.indices_gpu = prompt_indices.to(
            device=self.emb_layer.weight.device)
        self.embedding_indices_gpu = prompt_embedding_indices.to(
            device=self.emb_layer.weight.device)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Embed ``x`` and overwrite the mapped rows with adapter rows.

        Args:
            x: Input passed unchanged to ``base_layer``.

        Returns:
            The output of ``base_layer(x)``, modified in place wherever
            ``indices_gpu != -1`` if the current embedding-index tensor has
            more than one dimension; otherwise the unmodified output.
        """
        hidden_states = self.base_layer(x)
        # A tensor with a single dimension carries no (slot, row) pairs, so
        # there is nothing to splice in.
        if self.embedding_indices_gpu.ndim > 1:
            valid_mask = self.indices_gpu != -1
            gathered_embeddings = self.embeddings_tensors[
                self.embedding_indices_gpu[:, 0],
                self.embedding_indices_gpu[:, 1]]

            # Update hidden states
            hidden_states[valid_mask] = gathered_embeddings
        return hidden_states