Fix Baichuan: use the correct position embedding for the 7b (RoPE) and 13b (ALiBi) models (#643)

Song 2023-08-02 13:22:51 +08:00 committed by GitHub
parent d4c7755ca8
commit 64f23c2900
3 changed files with 75 additions and 16 deletions
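With this change both Baichuan variants run through the same code path: the 7b checkpoint keeps rotary position embeddings while the 13b checkpoint uses ALiBi. A hedged usage sketch, assuming a vllm build that includes this commit, the Hugging Face repo id baichuan-inc/Baichuan-13B-Chat, and that the LLM entrypoint already accepts trust_remote_code at this point:

from vllm import LLM, SamplingParams

# Baichuan-13B resolves to BaichuanForCausalLM and therefore to the ALiBi path;
# swap in "baichuan-inc/baichuan-7B" to exercise the RoPE path instead.
llm = LLM(model="baichuan-inc/Baichuan-13B-Chat", trust_remote_code=True)
outputs = llm.generate(["The quick brown fox"],
                       SamplingParams(temperature=0.0, max_tokens=32))
print(outputs[0].outputs[0].text)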

vllm/model_executor/model_loader.py

@@ -11,7 +11,8 @@ from vllm.model_executor.weight_utils import initialize_dummy_weights
# TODO(woosuk): Lazy-load the model classes.
_MODEL_REGISTRY = {
"BaiChuanForCausalLM": BaiChuanForCausalLM,
"BaiChuanForCausalLM": BaiChuanForCausalLM, # baichuan-7b
"BaichuanForCausalLM": BaichuanForCausalLM, # baichuan-13b
"BloomForCausalLM": BloomForCausalLM,
"GPT2LMHeadModel": GPT2LMHeadModel,
"GPTBigCodeForCausalLM": GPTBigCodeForCausalLM,

vllm/model_executor/models/__init__.py

@@ -1,4 +1,4 @@
from vllm.model_executor.models.baichuan import BaiChuanForCausalLM
from vllm.model_executor.models.baichuan import BaiChuanForCausalLM, BaichuanForCausalLM
from vllm.model_executor.models.bloom import BloomForCausalLM
from vllm.model_executor.models.gpt2 import GPT2LMHeadModel
from vllm.model_executor.models.gpt_bigcode import GPTBigCodeForCausalLM
@@ -10,6 +10,7 @@ from vllm.model_executor.models.opt import OPTForCausalLM
__all__ = [
"BaiChuanForCausalLM",
"BaichuanForCausalLM",
"BloomForCausalLM",
"GPT2LMHeadModel",
"GPTBigCodeForCausalLM",

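The only difference between the two registered names is the lowercase "c" in the 13b class name; that architecture string, read from the checkpoint's config, is what ultimately decides which position embedding the attention layers are built with. A minimal stand-alone sketch of the idea (hypothetical helper, not vLLM's actual loader code):

# Hypothetical illustration of the dispatch this commit enables.
_POSITION_EMBEDDING = {
    "BaiChuanForCausalLM": "ROPE",   # baichuan-7b
    "BaichuanForCausalLM": "ALIBI",  # baichuan-13b
}

def position_embedding_for(architecture: str) -> str:
    try:
        return _POSITION_EMBEDDING[architecture]
    except KeyError as exc:
        raise ValueError(f"Unsupported architecture: {architecture}") from exc

print(position_embedding_for("BaichuanForCausalLM"))  # ALIBI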
vllm/model_executor/models/baichuan.py

@@ -22,6 +22,7 @@
The input of the model is flattened to a 1D tensor of tokens. The model uses
InputMetadata to extract the original 2D shape of the input.
"""
import math
from typing import Dict, List, Optional, Tuple
import torch
@@ -31,7 +32,7 @@ from vllm.sequence import SequenceOutputs
from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.attention import PagedAttentionWithRoPE
from vllm.model_executor.layers.attention import PagedAttentionWithRoPE, PagedAttentionWithALiBi
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.weight_utils import (hf_model_weights_iterator,
load_tensor_parallel_weights)
@@ -44,6 +45,31 @@ from vllm.transformers_utils.configs.baichuan import BaiChuanConfig
KVCache = Tuple[torch.Tensor, torch.Tensor]
def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor:
closest_power_of_2 = 2**math.floor(math.log2(total_num_heads))
base = torch.tensor(
2**(-(2**-(math.log2(closest_power_of_2) - 3))),
dtype=torch.float32,
)
powers = torch.arange(1, 1 + closest_power_of_2, dtype=torch.int32)
slopes = torch.pow(base, powers)
if closest_power_of_2 != total_num_heads:
extra_base = torch.tensor(
2**(-(2**-(math.log2(2 * closest_power_of_2) - 3))),
dtype=torch.float32,
)
num_remaining_heads = min(closest_power_of_2,
total_num_heads - closest_power_of_2)
extra_powers = torch.arange(start=1,
end=1 + 2 * num_remaining_heads,
step=2,
dtype=torch.int32)
slopes = torch.cat(
[slopes, torch.pow(extra_base, extra_powers)], dim=0)
return slopes
class BaiChuanMLP(nn.Module):
def __init__(
@@ -82,6 +108,7 @@ class BaiChuanAttention(nn.Module):
self,
hidden_size: int,
num_heads: int,
position_embedding: str,
):
super().__init__()
self.hidden_size = hidden_size
@@ -92,7 +119,7 @@
self.num_heads = (self.total_num_heads //
tensor_model_parallel_world_size)
self.head_dim = hidden_size // self.total_num_heads
self.scaling = self.head_dim**-0.5
self.position_embedding = position_embedding
# pylint: disable=invalid-name
self.W_pack = ColumnParallelLinear(
@@ -109,11 +136,23 @@
input_is_parallel=True,
perform_initialization=False,
)
# Create the alibi slopes and slice them.
if self.position_embedding == "ALIBI":
tp_rank = get_tensor_model_parallel_rank()
head_start = tp_rank * self.num_heads
head_end = (tp_rank + 1) * self.num_heads
alibi_slopes = _get_alibi_slopes(self.total_num_heads)
alibi_slopes = alibi_slopes[head_start:head_end].tolist()
self.attn = PagedAttentionWithRoPE(self.num_heads,
self.head_dim,
self.scaling,
rotary_dim=self.head_dim)
scaling = self.head_dim**-0.5
self.attn = PagedAttentionWithALiBi(self.num_heads, self.head_dim,
scaling, alibi_slopes)
else:
self.scaling = self.head_dim**-0.5
self.attn = PagedAttentionWithRoPE(self.num_heads,
self.head_dim,
self.scaling,
rotary_dim=self.head_dim)
def forward(
self,
@@ -126,20 +165,26 @@ class BaiChuanAttention(nn.Module):
qkv, _ = self.W_pack(hidden_states)
q, k, v = qkv.chunk(chunks=3, dim=-1)
k_cache, v_cache = kv_cache
attn_output = self.attn(positions, q, k, v, k_cache, v_cache,
input_metadata, cache_event)
if self.position_embedding == "ALIBI":
attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata,
cache_event)
else:
attn_output = self.attn(positions, q, k, v, k_cache, v_cache,
input_metadata, cache_event)
output, _ = self.o_proj(attn_output)
return output
class BaiChuanDecoderLayer(nn.Module):
def __init__(self, config: BaiChuanConfig):
def __init__(self, config: BaiChuanConfig, position_embedding: str):
super().__init__()
self.hidden_size = config.hidden_size
self.self_attn = BaiChuanAttention(
hidden_size=self.hidden_size,
num_heads=config.num_attention_heads,
position_embedding=position_embedding,
)
self.mlp = BaiChuanMLP(
hidden_size=self.hidden_size,
@@ -181,7 +226,7 @@ class BaiChuanDecoderLayer(nn.Module):
class BaiChuanModel(nn.Module):
def __init__(self, config: BaiChuanConfig):
def __init__(self, config: BaiChuanConfig, position_embedding: str):
super().__init__()
self.config = config
self.padding_idx = config.pad_token_id
@@ -192,7 +237,7 @@ class BaiChuanModel(nn.Module):
config.hidden_size,
perform_initialization=False)
self.layers = nn.ModuleList([
BaiChuanDecoderLayer(config)
BaiChuanDecoderLayer(config, position_embedding)
for _ in range(config.num_hidden_layers)
])
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
@@ -223,12 +268,12 @@ class BaiChuanModel(nn.Module):
return hidden_states
class BaiChuanForCausalLM(nn.Module):
class BaiChuanBaseForCausalLM(nn.Module):
def __init__(self, config):
def __init__(self, config, position_embedding: str):
super().__init__()
self.config = config
self.model = BaiChuanModel(config)
self.model = BaiChuanModel(config, position_embedding)
self.lm_head = ColumnParallelLinear(config.hidden_size,
config.vocab_size,
bias=False,
@@ -318,3 +363,15 @@ class BaiChuanForCausalLM(nn.Module):
self._row_parallel_weights,
tp_rank,
)
class BaichuanForCausalLM(BaiChuanBaseForCausalLM): # baichuan 13b
def __init__(self, config):
super().__init__(config, "ALIBI")
class BaiChuanForCausalLM(BaiChuanBaseForCausalLM): # baichuan 7b
def __init__(self, config):
super().__init__(config, "ROPE")
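As a sanity check on _get_alibi_slopes above: for a power-of-two head count n the slopes follow 2**(-8*i/n) for i = 1..n, and for other counts the closest power of two uses that schedule while the remaining heads are filled in from the extra_base branch. A small hedged test, assuming a checkout that includes this commit so the helper is importable:

import torch
from vllm.model_executor.models.baichuan import _get_alibi_slopes

# 8 heads: the slopes are exactly 2**-1 .. 2**-8.
assert torch.allclose(_get_alibi_slopes(8),
                      torch.tensor([2.0**-i for i in range(1, 9)]))

# Baichuan-13B uses 40 attention heads: the first 32 follow the power-of-two
# schedule and the remaining 8 come from the extra_base branch, one slope per head.
assert _get_alibi_slopes(40).shape == torch.Size([40])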