fix baichuan for different position embedding for 7b and 13b models (#643)
This commit is contained in:
parent
d4c7755ca8
commit
64f23c2900
@ -11,7 +11,8 @@ from vllm.model_executor.weight_utils import initialize_dummy_weights
|
||||
|
||||
# TODO(woosuk): Lazy-load the model classes.
|
||||
_MODEL_REGISTRY = {
|
||||
"BaiChuanForCausalLM": BaiChuanForCausalLM,
|
||||
"BaiChuanForCausalLM": BaiChuanForCausalLM, # baichuan-7b
|
||||
"BaichuanForCausalLM": BaichuanForCausalLM, # baichuan-13b
|
||||
"BloomForCausalLM": BloomForCausalLM,
|
||||
"GPT2LMHeadModel": GPT2LMHeadModel,
|
||||
"GPTBigCodeForCausalLM": GPTBigCodeForCausalLM,
|
||||
|
@ -1,4 +1,4 @@
|
||||
from vllm.model_executor.models.baichuan import BaiChuanForCausalLM
|
||||
from vllm.model_executor.models.baichuan import BaiChuanForCausalLM, BaichuanForCausalLM
|
||||
from vllm.model_executor.models.bloom import BloomForCausalLM
|
||||
from vllm.model_executor.models.gpt2 import GPT2LMHeadModel
|
||||
from vllm.model_executor.models.gpt_bigcode import GPTBigCodeForCausalLM
|
||||
@ -10,6 +10,7 @@ from vllm.model_executor.models.opt import OPTForCausalLM
|
||||
|
||||
__all__ = [
|
||||
"BaiChuanForCausalLM",
|
||||
"BaichuanForCausalLM",
|
||||
"BloomForCausalLM",
|
||||
"GPT2LMHeadModel",
|
||||
"GPTBigCodeForCausalLM",
|
||||
|
@ -22,6 +22,7 @@
|
||||
The input of the model is flattened to a 1D tensor of tokens. The model uses
|
||||
InputMetadata to extract the original 2D shape of the input.
|
||||
"""
|
||||
import math
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
@ -31,7 +32,7 @@ from vllm.sequence import SequenceOutputs
|
||||
from vllm.model_executor.input_metadata import InputMetadata
|
||||
from vllm.model_executor.layers.activation import SiluAndMul
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.attention import PagedAttentionWithRoPE
|
||||
from vllm.model_executor.layers.attention import PagedAttentionWithRoPE, PagedAttentionWithALiBi
|
||||
from vllm.model_executor.layers.sampler import Sampler
|
||||
from vllm.model_executor.weight_utils import (hf_model_weights_iterator,
|
||||
load_tensor_parallel_weights)
|
||||
@ -44,6 +45,31 @@ from vllm.transformers_utils.configs.baichuan import BaiChuanConfig
|
||||
KVCache = Tuple[torch.Tensor, torch.Tensor]
|
||||
|
||||
|
||||
def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor:
|
||||
closest_power_of_2 = 2**math.floor(math.log2(total_num_heads))
|
||||
base = torch.tensor(
|
||||
2**(-(2**-(math.log2(closest_power_of_2) - 3))),
|
||||
dtype=torch.float32,
|
||||
)
|
||||
powers = torch.arange(1, 1 + closest_power_of_2, dtype=torch.int32)
|
||||
slopes = torch.pow(base, powers)
|
||||
|
||||
if closest_power_of_2 != total_num_heads:
|
||||
extra_base = torch.tensor(
|
||||
2**(-(2**-(math.log2(2 * closest_power_of_2) - 3))),
|
||||
dtype=torch.float32,
|
||||
)
|
||||
num_remaining_heads = min(closest_power_of_2,
|
||||
total_num_heads - closest_power_of_2)
|
||||
extra_powers = torch.arange(start=1,
|
||||
end=1 + 2 * num_remaining_heads,
|
||||
step=2,
|
||||
dtype=torch.int32)
|
||||
slopes = torch.cat(
|
||||
[slopes, torch.pow(extra_base, extra_powers)], dim=0)
|
||||
return slopes
|
||||
|
||||
|
||||
class BaiChuanMLP(nn.Module):
|
||||
|
||||
def __init__(
|
||||
@ -82,6 +108,7 @@ class BaiChuanAttention(nn.Module):
|
||||
self,
|
||||
hidden_size: int,
|
||||
num_heads: int,
|
||||
position_embedding: str,
|
||||
):
|
||||
super().__init__()
|
||||
self.hidden_size = hidden_size
|
||||
@ -92,7 +119,7 @@ class BaiChuanAttention(nn.Module):
|
||||
self.num_heads = (self.total_num_heads //
|
||||
tensor_model_parallel_world_size)
|
||||
self.head_dim = hidden_size // self.total_num_heads
|
||||
self.scaling = self.head_dim**-0.5
|
||||
self.postion_embedding = position_embedding
|
||||
|
||||
# pylint: disable=invalid-name
|
||||
self.W_pack = ColumnParallelLinear(
|
||||
@ -109,11 +136,23 @@ class BaiChuanAttention(nn.Module):
|
||||
input_is_parallel=True,
|
||||
perform_initialization=False,
|
||||
)
|
||||
# Create the alibi slopes and slice them.
|
||||
if self.postion_embedding == "ALIBI":
|
||||
tp_rank = get_tensor_model_parallel_rank()
|
||||
head_start = tp_rank * self.num_heads
|
||||
head_end = (tp_rank + 1) * self.num_heads
|
||||
alibi_slopes = _get_alibi_slopes(self.total_num_heads)
|
||||
alibi_slopes = alibi_slopes[head_start:head_end].tolist()
|
||||
|
||||
self.attn = PagedAttentionWithRoPE(self.num_heads,
|
||||
self.head_dim,
|
||||
self.scaling,
|
||||
rotary_dim=self.head_dim)
|
||||
scaling = self.head_dim**-0.5
|
||||
self.attn = PagedAttentionWithALiBi(self.num_heads, self.head_dim,
|
||||
scaling, alibi_slopes)
|
||||
else:
|
||||
self.scaling = self.head_dim**-0.5
|
||||
self.attn = PagedAttentionWithRoPE(self.num_heads,
|
||||
self.head_dim,
|
||||
self.scaling,
|
||||
rotary_dim=self.head_dim)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@ -126,20 +165,26 @@ class BaiChuanAttention(nn.Module):
|
||||
qkv, _ = self.W_pack(hidden_states)
|
||||
q, k, v = qkv.chunk(chunks=3, dim=-1)
|
||||
k_cache, v_cache = kv_cache
|
||||
attn_output = self.attn(positions, q, k, v, k_cache, v_cache,
|
||||
input_metadata, cache_event)
|
||||
if self.postion_embedding == "ALIBI":
|
||||
attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata,
|
||||
cache_event)
|
||||
else:
|
||||
attn_output = self.attn(positions, q, k, v, k_cache, v_cache,
|
||||
input_metadata, cache_event)
|
||||
|
||||
output, _ = self.o_proj(attn_output)
|
||||
return output
|
||||
|
||||
|
||||
class BaiChuanDecoderLayer(nn.Module):
|
||||
|
||||
def __init__(self, config: BaiChuanConfig):
|
||||
def __init__(self, config: BaiChuanConfig, position_embedding: str):
|
||||
super().__init__()
|
||||
self.hidden_size = config.hidden_size
|
||||
self.self_attn = BaiChuanAttention(
|
||||
hidden_size=self.hidden_size,
|
||||
num_heads=config.num_attention_heads,
|
||||
position_embedding=position_embedding,
|
||||
)
|
||||
self.mlp = BaiChuanMLP(
|
||||
hidden_size=self.hidden_size,
|
||||
@ -181,7 +226,7 @@ class BaiChuanDecoderLayer(nn.Module):
|
||||
|
||||
class BaiChuanModel(nn.Module):
|
||||
|
||||
def __init__(self, config: BaiChuanConfig):
|
||||
def __init__(self, config: BaiChuanConfig, position_embedding: str):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.padding_idx = config.pad_token_id
|
||||
@ -192,7 +237,7 @@ class BaiChuanModel(nn.Module):
|
||||
config.hidden_size,
|
||||
perform_initialization=False)
|
||||
self.layers = nn.ModuleList([
|
||||
BaiChuanDecoderLayer(config)
|
||||
BaiChuanDecoderLayer(config, position_embedding)
|
||||
for _ in range(config.num_hidden_layers)
|
||||
])
|
||||
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
@ -223,12 +268,12 @@ class BaiChuanModel(nn.Module):
|
||||
return hidden_states
|
||||
|
||||
|
||||
class BaiChuanForCausalLM(nn.Module):
|
||||
class BaiChuanBaseForCausalLM(nn.Module):
|
||||
|
||||
def __init__(self, config):
|
||||
def __init__(self, config, position_embedding: str):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.model = BaiChuanModel(config)
|
||||
self.model = BaiChuanModel(config, position_embedding)
|
||||
self.lm_head = ColumnParallelLinear(config.hidden_size,
|
||||
config.vocab_size,
|
||||
bias=False,
|
||||
@ -318,3 +363,15 @@ class BaiChuanForCausalLM(nn.Module):
|
||||
self._row_parallel_weights,
|
||||
tp_rank,
|
||||
)
|
||||
|
||||
|
||||
class BaichuanForCausalLM(BaiChuanBaseForCausalLM): # baichuan 13b
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__(config, "ALIBI")
|
||||
|
||||
|
||||
class BaiChuanForCausalLM(BaiChuanBaseForCausalLM): # baichuan 7b
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__(config, "ROPE")
|
||||
|
Loading…
x
Reference in New Issue
Block a user