Support Roberta embedding models (#9387)

Signed-off-by: Max de Bayser <mbayser@br.ibm.com>
Signed-off-by: Flavia Beo <flavia.beo@ibm.com>
Co-authored-by: Flavia Beo <flavia.beo@ibm.com>
This commit is contained in:
Maximilien de Bayser 2024-11-14 18:23:29 -03:00 committed by GitHub
parent 1dbae0329c
commit 4a18fd14ba
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
10 changed files with 202 additions and 14 deletions

View File

@ -98,6 +98,9 @@ void paged_attention_v1_launcher(
// NOTE(woosuk): To reduce the compilation time, we only compile for the
// head sizes that we use in the model. However, we can easily extend this
// to support any head size which is a multiple of 16.
case 32:
LAUNCH_PAGED_ATTENTION_V1(32);
break;
case 64:
LAUNCH_PAGED_ATTENTION_V1(64);
break;

View File

@ -104,6 +104,9 @@ void paged_attention_v2_launcher(
// NOTE(woosuk): To reduce the compilation time, we only compile for the
// head sizes that we use in the model. However, we can easily extend this
// to support any head size which is a multiple of 16.
case 32:
LAUNCH_PAGED_ATTENTION_V2(32);
break;
case 64:
LAUNCH_PAGED_ATTENTION_V2(64);
break;

View File

@ -385,6 +385,9 @@ void paged_attention_v1_impl_launcher(
int* seq_lens_ptr = seq_lens.data_ptr<int>();
switch (head_size) {
case 32:
LAUNCH_V1_ATTENTION_KERNEL(T, 32, BLOCK_SIZE);
break;
case 64:
LAUNCH_V1_ATTENTION_KERNEL(T, 64, BLOCK_SIZE);
break;
@ -702,6 +705,9 @@ void paged_attention_v2_impl_launcher(
int* seq_lens_ptr = seq_lens.data_ptr<int>();
switch (head_size) {
case 32:
LAUNCH_V2_ATTENTION_KERNEL(T, 32, BLOCK_SIZE);
break;
case 64:
LAUNCH_V2_ATTENTION_KERNEL(T, 64, BLOCK_SIZE);
break;

View File

@ -4,12 +4,17 @@ import pytest
from vllm.model_executor.layers.pooler import PoolingType
from vllm.model_executor.models.bert import BertEmbeddingModel
from vllm.model_executor.models.roberta import RobertaEmbeddingModel
from vllm.platforms import current_platform
MAX_MODEL_LEN = 128
MODEL_NAME = os.environ.get("MODEL_NAME", "BAAI/bge-base-en-v1.5")
REVISION = os.environ.get("REVISION", "main")
MODEL_NAME_ROBERTA = os.environ.get("MODEL_NAME",
"intfloat/multilingual-e5-large")
REVISION_ROBERTA = os.environ.get("REVISION", "main")
@pytest.mark.skipif(current_platform.is_rocm(),
reason="Xformers backend is not supported on ROCm.")
@ -48,3 +53,42 @@ def test_model_loading_with_params(vllm_runner):
assert model._pooler.normalize
# assert output
assert output
@pytest.mark.skipif(current_platform.is_rocm(),
reason="Xformers backend is not supported on ROCm.")
def test_roberta_model_loading_with_params(vllm_runner):
"""
Test parameter weight loading with tp>1.
"""
with vllm_runner(model_name=MODEL_NAME_ROBERTA,
revision=REVISION_ROBERTA,
dtype="float16",
max_model_len=MAX_MODEL_LEN) as model:
output = model.encode("Write a short story about a robot that"
" dreams for the first time.\n")
model_config = model.model.llm_engine.model_config
model_tokenizer = model.model.llm_engine.tokenizer
# asserts on the bert model config file
assert model_config.encoder_config["max_seq_length"] == 512
assert not model_config.encoder_config["do_lower_case"]
# asserts on the pooling config files
assert model_config.pooler_config.pooling_type == PoolingType.MEAN.name
assert model_config.pooler_config.pooling_norm
# asserts on the tokenizer loaded
assert model_tokenizer.tokenizer_id == "intfloat/multilingual-e5-large"
assert not model_tokenizer.tokenizer_config["do_lower_case"]
model = model.model.llm_engine.model_executor\
.driver_worker.model_runner.model
assert isinstance(model, RobertaEmbeddingModel)
assert model._pooler.pooling_type == PoolingType.MEAN
assert model._pooler.normalize
# assert output
assert output

View File

@ -13,10 +13,12 @@ MODELS = [
"intfloat/e5-mistral-7b-instruct",
"BAAI/bge-base-en-v1.5",
"BAAI/bge-multilingual-gemma2",
"intfloat/multilingual-e5-large",
]
ENCODER_ONLY = [
"BAAI/bge-base-en-v1.5",
"intfloat/multilingual-e5-large",
]

View File

@ -10,7 +10,7 @@ class PagedAttention:
@staticmethod
def get_supported_head_sizes() -> List[int]:
return [64, 80, 96, 112, 128, 256]
return [32, 64, 80, 96, 112, 128, 256]
@staticmethod
def get_kv_cache_shape(

View File

@ -34,7 +34,7 @@ class PagedAttention:
@staticmethod
def get_supported_head_sizes() -> List[int]:
return [64, 80, 96, 112, 120, 128, 192, 256]
return [32, 64, 80, 96, 112, 120, 128, 192, 256]
@staticmethod
def get_kv_cache_shape(

View File

@ -5,7 +5,7 @@ from torch import nn
from transformers import BertConfig
from vllm.attention import Attention, AttentionMetadata, AttentionType
from vllm.config import CacheConfig, VllmConfig
from vllm.config import CacheConfig, PoolerConfig, VllmConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
@ -305,14 +305,16 @@ class BertOutput(nn.Module):
class BertModel(nn.Module):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
def __init__(self,
*,
vllm_config: VllmConfig,
prefix: str = "",
embedding_class: type = BertEmbedding):
super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
self.embeddings = BertEmbedding(config)
self.embeddings = embedding_class(config)
self.encoder = BertEncoder(config,
cache_config,
quant_config,
@ -382,13 +384,9 @@ class BertEmbeddingModel(nn.Module):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
pooler_config = vllm_config.model_config.pooler_config
self.model = BertModel(vllm_config=vllm_config,
self.model = self._build_model(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model"))
self._pooler = Pooler.from_config_with_defaults(
pooler_config,
pooling_type=PoolingType.CLS,
normalize=True,
softmax=False)
self._pooler = self._build_pooler(pooler_config)
def forward(
self,
@ -415,3 +413,16 @@ class BertEmbeddingModel(nn.Module):
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
self.model.load_weights(weights)
def _build_model(self,
vllm_config: VllmConfig,
prefix: str = "") -> BertModel:
return BertModel(vllm_config=vllm_config,
prefix=prefix,
embedding_class=BertEmbedding)
def _build_pooler(self, pooler_config: PoolerConfig) -> Pooler:
return Pooler.from_config_with_defaults(pooler_config,
pooling_type=PoolingType.CLS,
normalize=True,
softmax=False)

View File

@ -94,6 +94,8 @@ _TEXT_GENERATION_MODELS = {
_EMBEDDING_MODELS = {
# [Text-only]
"BertModel": ("bert", "BertEmbeddingModel"),
"RobertaModel": ("roberta", "RobertaEmbeddingModel"),
"XLMRobertaModel": ("roberta", "RobertaEmbeddingModel"),
"DeciLMForCausalLM": ("decilm", "DeciLMForCausalLM"),
"Gemma2Model": ("gemma2", "Gemma2EmbeddingModel"),
"LlamaModel": ("llama", "LlamaEmbeddingModel"),

View File

@ -0,0 +1,117 @@
from typing import List, Optional
import torch
from torch import nn
from transformers import RobertaConfig
from vllm.attention import AttentionMetadata
from vllm.config import VllmConfig
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding)
from vllm.model_executor.models.bert import BertEmbeddingModel, BertModel
from vllm.sequence import IntermediateTensors
class RobertaEmbedding(nn.Module):
def __init__(self, config: RobertaConfig):
super().__init__()
self.size = config.hidden_size
self.word_embeddings = VocabParallelEmbedding(config.vocab_size,
config.hidden_size)
self.padding_idx = config.pad_token_id
self.position_embeddings = nn.Embedding(config.max_position_embeddings,
config.hidden_size,
padding_idx=self.padding_idx)
self.token_type_embeddings = nn.Embedding(config.type_vocab_size,
config.hidden_size)
self.LayerNorm = nn.LayerNorm(config.hidden_size,
eps=config.layer_norm_eps)
self.position_ids = nn.Parameter(
torch.empty((1, config.max_position_embeddings)), )
self.position_embedding_type = config.position_embedding_type
if self.position_embedding_type != "absolute":
raise ValueError("Only 'absolute' position_embedding_type" +
" is supported")
def forward(
self,
input_ids: torch.Tensor,
position_ids: Optional[torch.Tensor] = None,
) -> torch.Tensor:
input_shape = input_ids.size()
# Input embeddings.
inputs_embeds = self.word_embeddings(input_ids)
# TODO: figure out if there is a better way
# to make to make position ids start at padding_idx + 1
# References:
# - https://github.com/huggingface/transformers/blob/a3d69a8994d673899608a7c17fbf4f953f50474e/src/transformers/models/roberta/modeling_roberta.py#L133
# - https://github.com/huggingface/transformers/blob/a3d69a8994d673899608a7c17fbf4f953f50474e/src/transformers/models/roberta/modeling_roberta.py#L1669
position_ids += self.padding_idx + 1
# Position embeddings.
position_embeddings = self.position_embeddings(position_ids)
# Token type embeddings. (TODO: move off hotpath?)
token_type_embeddings = self.token_type_embeddings(
torch.zeros(input_shape,
dtype=torch.long,
device=inputs_embeds.device))
embeddings = inputs_embeds + token_type_embeddings + position_embeddings
embeddings = self.LayerNorm(embeddings)
return embeddings
class RobertaEmbeddingModel(BertEmbeddingModel):
"""A model that uses Roberta to provide embedding functionalities.
This class encapsulates the BertModel and provides an interface for
embedding operations and customized pooling functions.
Attributes:
model: An instance of BertModel used for forward operations.
_pooler: An instance of Pooler used for pooling operations.
"""
def _build_model(self,
vllm_config: VllmConfig,
prefix: str = "") -> BertModel:
return BertModel(vllm_config=vllm_config,
prefix=prefix,
embedding_class=RobertaEmbedding)
def forward(
self,
input_ids: Optional[torch.Tensor],
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
) -> torch.Tensor:
# Verify assumption that position are always a sequence from
# 0 to N. (Actually here we just check 0 and N to simplify).
# This is important to fix the position which are assumed to
# start from padding_idx + 1 instead of 0 in the Roberta models.
assert hasattr(attn_metadata, "seq_lens_tensor")
cumulative = attn_metadata.seq_lens_tensor.cumsum(dim=0)
start_pos = torch.cat(
(torch.tensor([0], device=attn_metadata.seq_lens_tensor.device),
cumulative[:-1]))
assert len(torch.nonzero(positions[start_pos])) == 0
end_pos = cumulative - 1
last_tokens = attn_metadata.seq_lens_tensor - 1
assert len(torch.nonzero(positions[end_pos] - last_tokens)) == 0
return super().forward(input_ids=input_ids,
positions=positions,
kv_caches=kv_caches,
attn_metadata=attn_metadata,
intermediate_tensors=intermediate_tensors,
inputs_embeds=inputs_embeds)