Support Roberta embedding models (#9387)
Signed-off-by: Max de Bayser <mbayser@br.ibm.com> Signed-off-by: Flavia Beo <flavia.beo@ibm.com> Co-authored-by: Flavia Beo <flavia.beo@ibm.com>
This commit is contained in:
parent
1dbae0329c
commit
4a18fd14ba
@ -98,6 +98,9 @@ void paged_attention_v1_launcher(
|
||||
// NOTE(woosuk): To reduce the compilation time, we only compile for the
|
||||
// head sizes that we use in the model. However, we can easily extend this
|
||||
// to support any head size which is a multiple of 16.
|
||||
case 32:
|
||||
LAUNCH_PAGED_ATTENTION_V1(32);
|
||||
break;
|
||||
case 64:
|
||||
LAUNCH_PAGED_ATTENTION_V1(64);
|
||||
break;
|
||||
|
@ -104,6 +104,9 @@ void paged_attention_v2_launcher(
|
||||
// NOTE(woosuk): To reduce the compilation time, we only compile for the
|
||||
// head sizes that we use in the model. However, we can easily extend this
|
||||
// to support any head size which is a multiple of 16.
|
||||
case 32:
|
||||
LAUNCH_PAGED_ATTENTION_V2(32);
|
||||
break;
|
||||
case 64:
|
||||
LAUNCH_PAGED_ATTENTION_V2(64);
|
||||
break;
|
||||
|
@ -385,6 +385,9 @@ void paged_attention_v1_impl_launcher(
|
||||
int* seq_lens_ptr = seq_lens.data_ptr<int>();
|
||||
|
||||
switch (head_size) {
|
||||
case 32:
|
||||
LAUNCH_V1_ATTENTION_KERNEL(T, 32, BLOCK_SIZE);
|
||||
break;
|
||||
case 64:
|
||||
LAUNCH_V1_ATTENTION_KERNEL(T, 64, BLOCK_SIZE);
|
||||
break;
|
||||
@ -702,6 +705,9 @@ void paged_attention_v2_impl_launcher(
|
||||
int* seq_lens_ptr = seq_lens.data_ptr<int>();
|
||||
|
||||
switch (head_size) {
|
||||
case 32:
|
||||
LAUNCH_V2_ATTENTION_KERNEL(T, 32, BLOCK_SIZE);
|
||||
break;
|
||||
case 64:
|
||||
LAUNCH_V2_ATTENTION_KERNEL(T, 64, BLOCK_SIZE);
|
||||
break;
|
||||
|
@ -4,12 +4,17 @@ import pytest
|
||||
|
||||
from vllm.model_executor.layers.pooler import PoolingType
|
||||
from vllm.model_executor.models.bert import BertEmbeddingModel
|
||||
from vllm.model_executor.models.roberta import RobertaEmbeddingModel
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
MAX_MODEL_LEN = 128
|
||||
MODEL_NAME = os.environ.get("MODEL_NAME", "BAAI/bge-base-en-v1.5")
|
||||
REVISION = os.environ.get("REVISION", "main")
|
||||
|
||||
MODEL_NAME_ROBERTA = os.environ.get("MODEL_NAME",
|
||||
"intfloat/multilingual-e5-large")
|
||||
REVISION_ROBERTA = os.environ.get("REVISION", "main")
|
||||
|
||||
|
||||
@pytest.mark.skipif(current_platform.is_rocm(),
|
||||
reason="Xformers backend is not supported on ROCm.")
|
||||
@ -48,3 +53,42 @@ def test_model_loading_with_params(vllm_runner):
|
||||
assert model._pooler.normalize
|
||||
# assert output
|
||||
assert output
|
||||
|
||||
|
||||
@pytest.mark.skipif(current_platform.is_rocm(),
|
||||
reason="Xformers backend is not supported on ROCm.")
|
||||
def test_roberta_model_loading_with_params(vllm_runner):
|
||||
"""
|
||||
Test parameter weight loading with tp>1.
|
||||
"""
|
||||
with vllm_runner(model_name=MODEL_NAME_ROBERTA,
|
||||
revision=REVISION_ROBERTA,
|
||||
dtype="float16",
|
||||
max_model_len=MAX_MODEL_LEN) as model:
|
||||
output = model.encode("Write a short story about a robot that"
|
||||
" dreams for the first time.\n")
|
||||
|
||||
model_config = model.model.llm_engine.model_config
|
||||
|
||||
model_tokenizer = model.model.llm_engine.tokenizer
|
||||
|
||||
# asserts on the bert model config file
|
||||
assert model_config.encoder_config["max_seq_length"] == 512
|
||||
assert not model_config.encoder_config["do_lower_case"]
|
||||
|
||||
# asserts on the pooling config files
|
||||
assert model_config.pooler_config.pooling_type == PoolingType.MEAN.name
|
||||
assert model_config.pooler_config.pooling_norm
|
||||
|
||||
# asserts on the tokenizer loaded
|
||||
assert model_tokenizer.tokenizer_id == "intfloat/multilingual-e5-large"
|
||||
assert not model_tokenizer.tokenizer_config["do_lower_case"]
|
||||
|
||||
model = model.model.llm_engine.model_executor\
|
||||
.driver_worker.model_runner.model
|
||||
assert isinstance(model, RobertaEmbeddingModel)
|
||||
assert model._pooler.pooling_type == PoolingType.MEAN
|
||||
assert model._pooler.normalize
|
||||
|
||||
# assert output
|
||||
assert output
|
||||
|
@ -13,10 +13,12 @@ MODELS = [
|
||||
"intfloat/e5-mistral-7b-instruct",
|
||||
"BAAI/bge-base-en-v1.5",
|
||||
"BAAI/bge-multilingual-gemma2",
|
||||
"intfloat/multilingual-e5-large",
|
||||
]
|
||||
|
||||
ENCODER_ONLY = [
|
||||
"BAAI/bge-base-en-v1.5",
|
||||
"intfloat/multilingual-e5-large",
|
||||
]
|
||||
|
||||
|
||||
|
@ -10,7 +10,7 @@ class PagedAttention:
|
||||
|
||||
@staticmethod
|
||||
def get_supported_head_sizes() -> List[int]:
|
||||
return [64, 80, 96, 112, 128, 256]
|
||||
return [32, 64, 80, 96, 112, 128, 256]
|
||||
|
||||
@staticmethod
|
||||
def get_kv_cache_shape(
|
||||
|
@ -34,7 +34,7 @@ class PagedAttention:
|
||||
|
||||
@staticmethod
|
||||
def get_supported_head_sizes() -> List[int]:
|
||||
return [64, 80, 96, 112, 120, 128, 192, 256]
|
||||
return [32, 64, 80, 96, 112, 120, 128, 192, 256]
|
||||
|
||||
@staticmethod
|
||||
def get_kv_cache_shape(
|
||||
|
@ -5,7 +5,7 @@ from torch import nn
|
||||
from transformers import BertConfig
|
||||
|
||||
from vllm.attention import Attention, AttentionMetadata, AttentionType
|
||||
from vllm.config import CacheConfig, VllmConfig
|
||||
from vllm.config import CacheConfig, PoolerConfig, VllmConfig
|
||||
from vllm.distributed import get_tensor_model_parallel_world_size
|
||||
from vllm.model_executor.layers.activation import get_act_fn
|
||||
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||
@ -305,14 +305,16 @@ class BertOutput(nn.Module):
|
||||
|
||||
class BertModel(nn.Module):
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
def __init__(self,
|
||||
*,
|
||||
vllm_config: VllmConfig,
|
||||
prefix: str = "",
|
||||
embedding_class: type = BertEmbedding):
|
||||
super().__init__()
|
||||
|
||||
config = vllm_config.model_config.hf_config
|
||||
cache_config = vllm_config.cache_config
|
||||
quant_config = vllm_config.quant_config
|
||||
|
||||
self.embeddings = BertEmbedding(config)
|
||||
self.embeddings = embedding_class(config)
|
||||
self.encoder = BertEncoder(config,
|
||||
cache_config,
|
||||
quant_config,
|
||||
@ -382,13 +384,9 @@ class BertEmbeddingModel(nn.Module):
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__()
|
||||
pooler_config = vllm_config.model_config.pooler_config
|
||||
self.model = BertModel(vllm_config=vllm_config,
|
||||
self.model = self._build_model(vllm_config=vllm_config,
|
||||
prefix=maybe_prefix(prefix, "model"))
|
||||
self._pooler = Pooler.from_config_with_defaults(
|
||||
pooler_config,
|
||||
pooling_type=PoolingType.CLS,
|
||||
normalize=True,
|
||||
softmax=False)
|
||||
self._pooler = self._build_pooler(pooler_config)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@ -415,3 +413,16 @@ class BertEmbeddingModel(nn.Module):
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
self.model.load_weights(weights)
|
||||
|
||||
def _build_model(self,
|
||||
vllm_config: VllmConfig,
|
||||
prefix: str = "") -> BertModel:
|
||||
return BertModel(vllm_config=vllm_config,
|
||||
prefix=prefix,
|
||||
embedding_class=BertEmbedding)
|
||||
|
||||
def _build_pooler(self, pooler_config: PoolerConfig) -> Pooler:
|
||||
return Pooler.from_config_with_defaults(pooler_config,
|
||||
pooling_type=PoolingType.CLS,
|
||||
normalize=True,
|
||||
softmax=False)
|
||||
|
@ -94,6 +94,8 @@ _TEXT_GENERATION_MODELS = {
|
||||
_EMBEDDING_MODELS = {
|
||||
# [Text-only]
|
||||
"BertModel": ("bert", "BertEmbeddingModel"),
|
||||
"RobertaModel": ("roberta", "RobertaEmbeddingModel"),
|
||||
"XLMRobertaModel": ("roberta", "RobertaEmbeddingModel"),
|
||||
"DeciLMForCausalLM": ("decilm", "DeciLMForCausalLM"),
|
||||
"Gemma2Model": ("gemma2", "Gemma2EmbeddingModel"),
|
||||
"LlamaModel": ("llama", "LlamaEmbeddingModel"),
|
||||
|
117
vllm/model_executor/models/roberta.py
Normal file
117
vllm/model_executor/models/roberta.py
Normal file
@ -0,0 +1,117 @@
|
||||
from typing import List, Optional
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from transformers import RobertaConfig
|
||||
|
||||
from vllm.attention import AttentionMetadata
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
VocabParallelEmbedding)
|
||||
from vllm.model_executor.models.bert import BertEmbeddingModel, BertModel
|
||||
from vllm.sequence import IntermediateTensors
|
||||
|
||||
|
||||
class RobertaEmbedding(nn.Module):
|
||||
|
||||
def __init__(self, config: RobertaConfig):
|
||||
super().__init__()
|
||||
self.size = config.hidden_size
|
||||
self.word_embeddings = VocabParallelEmbedding(config.vocab_size,
|
||||
config.hidden_size)
|
||||
self.padding_idx = config.pad_token_id
|
||||
self.position_embeddings = nn.Embedding(config.max_position_embeddings,
|
||||
config.hidden_size,
|
||||
padding_idx=self.padding_idx)
|
||||
|
||||
self.token_type_embeddings = nn.Embedding(config.type_vocab_size,
|
||||
config.hidden_size)
|
||||
self.LayerNorm = nn.LayerNorm(config.hidden_size,
|
||||
eps=config.layer_norm_eps)
|
||||
self.position_ids = nn.Parameter(
|
||||
torch.empty((1, config.max_position_embeddings)), )
|
||||
|
||||
self.position_embedding_type = config.position_embedding_type
|
||||
if self.position_embedding_type != "absolute":
|
||||
raise ValueError("Only 'absolute' position_embedding_type" +
|
||||
" is supported")
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
position_ids: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
input_shape = input_ids.size()
|
||||
|
||||
# Input embeddings.
|
||||
inputs_embeds = self.word_embeddings(input_ids)
|
||||
|
||||
# TODO: figure out if there is a better way
|
||||
# to make to make position ids start at padding_idx + 1
|
||||
# References:
|
||||
# - https://github.com/huggingface/transformers/blob/a3d69a8994d673899608a7c17fbf4f953f50474e/src/transformers/models/roberta/modeling_roberta.py#L133
|
||||
# - https://github.com/huggingface/transformers/blob/a3d69a8994d673899608a7c17fbf4f953f50474e/src/transformers/models/roberta/modeling_roberta.py#L1669
|
||||
position_ids += self.padding_idx + 1
|
||||
|
||||
# Position embeddings.
|
||||
position_embeddings = self.position_embeddings(position_ids)
|
||||
|
||||
# Token type embeddings. (TODO: move off hotpath?)
|
||||
token_type_embeddings = self.token_type_embeddings(
|
||||
torch.zeros(input_shape,
|
||||
dtype=torch.long,
|
||||
device=inputs_embeds.device))
|
||||
|
||||
embeddings = inputs_embeds + token_type_embeddings + position_embeddings
|
||||
embeddings = self.LayerNorm(embeddings)
|
||||
return embeddings
|
||||
|
||||
|
||||
class RobertaEmbeddingModel(BertEmbeddingModel):
|
||||
"""A model that uses Roberta to provide embedding functionalities.
|
||||
|
||||
This class encapsulates the BertModel and provides an interface for
|
||||
embedding operations and customized pooling functions.
|
||||
|
||||
Attributes:
|
||||
model: An instance of BertModel used for forward operations.
|
||||
_pooler: An instance of Pooler used for pooling operations.
|
||||
"""
|
||||
|
||||
def _build_model(self,
|
||||
vllm_config: VllmConfig,
|
||||
prefix: str = "") -> BertModel:
|
||||
return BertModel(vllm_config=vllm_config,
|
||||
prefix=prefix,
|
||||
embedding_class=RobertaEmbedding)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: Optional[torch.Tensor],
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
|
||||
# Verify assumption that position are always a sequence from
|
||||
# 0 to N. (Actually here we just check 0 and N to simplify).
|
||||
# This is important to fix the position which are assumed to
|
||||
# start from padding_idx + 1 instead of 0 in the Roberta models.
|
||||
assert hasattr(attn_metadata, "seq_lens_tensor")
|
||||
cumulative = attn_metadata.seq_lens_tensor.cumsum(dim=0)
|
||||
start_pos = torch.cat(
|
||||
(torch.tensor([0], device=attn_metadata.seq_lens_tensor.device),
|
||||
cumulative[:-1]))
|
||||
assert len(torch.nonzero(positions[start_pos])) == 0
|
||||
end_pos = cumulative - 1
|
||||
last_tokens = attn_metadata.seq_lens_tensor - 1
|
||||
assert len(torch.nonzero(positions[end_pos] - last_tokens)) == 0
|
||||
|
||||
return super().forward(input_ids=input_ids,
|
||||
positions=positions,
|
||||
kv_caches=kv_caches,
|
||||
attn_metadata=attn_metadata,
|
||||
intermediate_tensors=intermediate_tensors,
|
||||
inputs_embeds=inputs_embeds)
|
Loading…
x
Reference in New Issue
Block a user