diff --git a/csrc/layernorm.cpp b/csrc/layernorm.cpp
index 749ca5f9..c341a709 100644
--- a/csrc/layernorm.cpp
+++ b/csrc/layernorm.cpp
@@ -6,9 +6,19 @@ void rms_norm(
   torch::Tensor& weight,
   float epsilon);
 
+void fused_add_rms_norm(
+  torch::Tensor& input,
+  torch::Tensor& residual,
+  torch::Tensor& weight,
+  float epsilon);
+
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   m.def(
     "rms_norm",
     &rms_norm,
     "Apply Root Mean Square (RMS) Normalization to the input tensor.");
+  m.def(
+    "fused_add_rms_norm",
+    &fused_add_rms_norm,
+    "In-place fused Add and RMS Normalization");
 }
diff --git a/csrc/layernorm_kernels.cu b/csrc/layernorm_kernels.cu
index fe07c272..7434f4fd 100644
--- a/csrc/layernorm_kernels.cu
+++ b/csrc/layernorm_kernels.cu
@@ -34,6 +34,36 @@ __global__ void rms_norm_kernel(
   }
 }
 
+// TODO: Further optimize this kernel.
+template<typename scalar_t>
+__global__ void fused_add_rms_norm_kernel(
+  scalar_t* __restrict__ input,           // [..., hidden_size]
+  scalar_t* __restrict__ residual,        // [..., hidden_size]
+  const scalar_t* __restrict__ weight,    // [hidden_size]
+  const float epsilon,
+  const int num_tokens,
+  const int hidden_size) {
+  __shared__ float s_variance;
+  float variance = 0.0f;
+
+  for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
+    float x = (float) input[blockIdx.x * hidden_size + idx];
+    x += (float) residual[blockIdx.x * hidden_size + idx];
+    variance += x * x;
+    residual[blockIdx.x * hidden_size + idx] = (scalar_t) x;
+  }
+  variance = blockReduceSum<float>(variance);
+  if (threadIdx.x == 0) {
+    s_variance = rsqrtf(variance / hidden_size + epsilon);
+  }
+  __syncthreads();
+
+  for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
+    float x = (float) residual[blockIdx.x * hidden_size + idx];
+    input[blockIdx.x * hidden_size + idx] = ((scalar_t) (x * s_variance)) * weight[idx];
+  }
+}
+
 } // namespace vllm
 
 void rms_norm(
@@ -60,3 +90,28 @@
         hidden_size);
     });
 }
+
+void fused_add_rms_norm(
+  torch::Tensor& input,    // [..., hidden_size]
+  torch::Tensor& residual, // [..., hidden_size]
+  torch::Tensor& weight,   // [hidden_size]
+  float epsilon) {
+  int hidden_size = input.size(-1);
+  int num_tokens = input.numel() / hidden_size;
+
+  dim3 grid(num_tokens);
+  dim3 block(std::min(hidden_size, 1024));
+  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+  VLLM_DISPATCH_FLOATING_TYPES(
+    input.scalar_type(),
+    "fused_add_rms_norm_kernel",
+    [&] {
+      vllm::fused_add_rms_norm_kernel<scalar_t><<<grid, block, 0, stream>>>(
+        input.data_ptr<scalar_t>(),
+        residual.data_ptr<scalar_t>(),
+        weight.data_ptr<scalar_t>(),
+        epsilon,
+        num_tokens,
+        hidden_size);
+    });
+}
diff --git a/vllm/model_executor/layers/layernorm.py b/vllm/model_executor/layers/layernorm.py
index 731bc7cb..275efa0b 100644
--- a/vllm/model_executor/layers/layernorm.py
+++ b/vllm/model_executor/layers/layernorm.py
@@ -1,4 +1,6 @@
 """Custom normalization layers."""
+from typing import Optional, Tuple, Union
+
 import torch
 import torch.nn as nn
 
@@ -21,7 +23,19 @@ class RMSNorm(nn.Module):
         self.weight = nn.Parameter(torch.ones(hidden_size))
         self.variance_epsilon = eps
 
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
+    def forward(
+        self,
+        x: torch.Tensor,
+        residual: Optional[torch.Tensor] = None,
+    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+        if residual is not None:
+            layernorm_ops.fused_add_rms_norm(
+                x,
+                residual,
+                self.weight.data,
+                self.variance_epsilon,
+            )
+            return x, residual
         out = torch.empty_like(x)
         layernorm_ops.rms_norm(
             out,
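For reference, here is a rough pure-PyTorch sketch of what the new op is expected to compute and how it mutates its arguments. The helper names `ref_fused_add_rms_norm` and `check_fused_add_rms_norm`, the shapes, and the tolerances are hypothetical and not part of this change; only `layernorm_ops.fused_add_rms_norm` comes from the diff above.

```python
# Hedged sketch (not part of the diff): fused_add_rms_norm overwrites
# `residual` with x + residual and overwrites `x` with
# RMSNorm(x + residual) * weight, both in place.
import torch


def ref_fused_add_rms_norm(x, residual, weight, eps):
    # Unfused equivalent, computed out of place for clarity.
    z = x.float() + residual.float()
    variance = z.pow(2).mean(dim=-1, keepdim=True)
    out = (z * torch.rsqrt(variance + eps)).to(x.dtype) * weight
    return out, z.to(x.dtype)


def check_fused_add_rms_norm(num_tokens=7, hidden_size=4096, eps=1e-6):
    from vllm import layernorm_ops  # the extension built from csrc/
    x = torch.randn(num_tokens, hidden_size, dtype=torch.float16, device="cuda")
    residual = torch.randn_like(x)
    weight = torch.randn(hidden_size, dtype=torch.float16, device="cuda")

    # Reference values from the original (unmodified) tensors.
    ref_out, ref_residual = ref_fused_add_rms_norm(x, residual, weight, eps)

    # The kernel works in place on both x and residual.
    layernorm_ops.fused_add_rms_norm(x, residual, weight, eps)
    torch.testing.assert_close(x, ref_out, atol=1e-2, rtol=1e-2)
    torch.testing.assert_close(residual, ref_residual, atol=1e-2, rtol=1e-2)
```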
diff --git a/vllm/model_executor/models/baichuan.py b/vllm/model_executor/models/baichuan.py
index 64bbd598..f86bd241 100644
--- a/vllm/model_executor/models/baichuan.py
+++ b/vllm/model_executor/models/baichuan.py
@@ -225,10 +225,15 @@ class BaiChuanDecoderLayer(nn.Module):
         kv_cache: KVCache,
         input_metadata: InputMetadata,
         cache_event: Optional[torch.cuda.Event],
-    ) -> torch.Tensor:
+        residual: Optional[torch.Tensor],
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
         # Self Attention
-        residual = hidden_states
-        hidden_states = self.input_layernorm(hidden_states)
+        if residual is None:
+            residual = hidden_states
+            hidden_states = self.input_layernorm(hidden_states)
+        else:
+            hidden_states, residual = self.input_layernorm(
+                hidden_states, residual)
         hidden_states = self.self_attn(
             positions=positions,
             hidden_states=hidden_states,
@@ -236,14 +241,12 @@
             input_metadata=input_metadata,
             cache_event=cache_event,
         )
-        hidden_states = residual + hidden_states
 
         # Fully Connected
-        residual = hidden_states
-        hidden_states = self.post_attention_layernorm(hidden_states)
+        hidden_states, residual = self.post_attention_layernorm(
+            hidden_states, residual)
         hidden_states = self.mlp(hidden_states)
-        hidden_states = residual + hidden_states
-        return hidden_states
+        return hidden_states, residual
 
 
 class BaiChuanModel(nn.Module):
@@ -276,20 +279,22 @@ class BaiChuanModel(nn.Module):
         cache_events: Optional[List[torch.cuda.Event]],
     ) -> torch.Tensor:
         hidden_states = self.embed_tokens(input_ids)
+        residual = None
         for i in range(len(self.layers)):
             if cache_events is None:
                 cache_event = None
             else:
                 cache_event = cache_events[i]
             layer = self.layers[i]
-            hidden_states = layer(
+            hidden_states, residual = layer(
                 positions,
                 hidden_states,
                 kv_caches[i],
                 input_metadata,
                 cache_event,
+                residual,
             )
-        hidden_states = self.norm(hidden_states)
+        hidden_states, _ = self.norm(hidden_states, residual)
         return hidden_states
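The remaining model files below (internlm, llama, mistral, qwen, yi) apply exactly the same transformation. As a standalone illustration of the pattern, here is a hypothetical, self-contained sketch: `RMSNormLike`, `SketchDecoderLayer`, and the linear stand-ins for attention and MLP are not the vLLM classes, and the norm here is the unfused PyTorch equivalent rather than the CUDA kernel.

```python
# Hypothetical illustration of the residual-threading pattern adopted below.
from typing import Optional, Tuple

import torch
import torch.nn as nn


class RMSNormLike(nn.Module):
    """Pure-PyTorch stand-in mirroring the updated RMSNorm forward signature."""

    def __init__(self, hidden_size: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.eps = eps

    def _norm(self, x: torch.Tensor) -> torch.Tensor:
        variance = x.pow(2).mean(-1, keepdim=True)
        return x * torch.rsqrt(variance + self.eps) * self.weight

    def forward(self, x: torch.Tensor, residual: Optional[torch.Tensor] = None):
        if residual is None:
            return self._norm(x)
        # Unfused equivalent of fused_add_rms_norm: add first, then normalize,
        # and hand the un-normalized sum back as the new residual.
        residual = x + residual
        return self._norm(residual), residual


class SketchDecoderLayer(nn.Module):
    def __init__(self, hidden_size: int):
        super().__init__()
        self.input_layernorm = RMSNormLike(hidden_size)
        self.post_attention_layernorm = RMSNormLike(hidden_size)
        self.self_attn = nn.Linear(hidden_size, hidden_size)  # stand-in
        self.mlp = nn.Linear(hidden_size, hidden_size)        # stand-in

    def forward(
        self,
        hidden_states: torch.Tensor,
        residual: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        if residual is None:
            # First layer: no pending residual yet, use the plain norm.
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        else:
            # Later layers: the pending residual add is folded into the norm.
            hidden_states, residual = self.input_layernorm(hidden_states, residual)
        hidden_states = self.self_attn(hidden_states)
        # The explicit "hidden_states = residual + hidden_states" is gone; the
        # add is deferred to the next norm call instead.
        hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
        hidden_states = self.mlp(hidden_states)
        return hidden_states, residual
```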
diff --git a/vllm/model_executor/models/internlm.py b/vllm/model_executor/models/internlm.py
index d90f8aae..4621be73 100644
--- a/vllm/model_executor/models/internlm.py
+++ b/vllm/model_executor/models/internlm.py
@@ -155,10 +155,15 @@ class InternLMDecoderLayer(nn.Module):
         kv_cache: KVCache,
         input_metadata: InputMetadata,
         cache_event: Optional[torch.cuda.Event],
-    ) -> torch.Tensor:
+        residual: Optional[torch.Tensor],
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
         # Self Attention
-        residual = hidden_states
-        hidden_states = self.input_layernorm(hidden_states)
+        if residual is None:
+            residual = hidden_states
+            hidden_states = self.input_layernorm(hidden_states)
+        else:
+            hidden_states, residual = self.input_layernorm(
+                hidden_states, residual)
         hidden_states = self.self_attn(
             positions=positions,
             hidden_states=hidden_states,
@@ -166,14 +171,12 @@
             input_metadata=input_metadata,
             cache_event=cache_event,
         )
-        hidden_states = residual + hidden_states
 
         # Fully Connected
-        residual = hidden_states
-        hidden_states = self.post_attention_layernorm(hidden_states)
+        hidden_states, residual = self.post_attention_layernorm(
+            hidden_states, residual)
         hidden_states = self.mlp(hidden_states)
-        hidden_states = residual + hidden_states
-        return hidden_states
+        return hidden_states, residual
 
 
 class InternLMModel(nn.Module):
@@ -208,20 +211,22 @@ class InternLMModel(nn.Module):
         cache_events: Optional[List[torch.cuda.Event]],
     ) -> torch.Tensor:
         hidden_states = self.embed_tokens(input_ids)
+        residual = None
         for i in range(len(self.layers)):
             if cache_events is None:
                 cache_event = None
             else:
                 cache_event = cache_events[i]
             layer = self.layers[i]
-            hidden_states = layer(
+            hidden_states, residual = layer(
                 positions,
                 hidden_states,
                 kv_caches[i],
                 input_metadata,
                 cache_event,
+                residual,
             )
-        hidden_states = self.norm(hidden_states)
+        hidden_states, _ = self.norm(hidden_states, residual)
         return hidden_states
diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py
index 9381a239..2e02ef15 100644
--- a/vllm/model_executor/models/llama.py
+++ b/vllm/model_executor/models/llama.py
@@ -197,10 +197,15 @@ class LlamaDecoderLayer(nn.Module):
         kv_cache: KVCache,
         input_metadata: InputMetadata,
         cache_event: Optional[torch.cuda.Event],
-    ) -> torch.Tensor:
+        residual: Optional[torch.Tensor],
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
         # Self Attention
-        residual = hidden_states
-        hidden_states = self.input_layernorm(hidden_states)
+        if residual is None:
+            residual = hidden_states
+            hidden_states = self.input_layernorm(hidden_states)
+        else:
+            hidden_states, residual = self.input_layernorm(
+                hidden_states, residual)
         hidden_states = self.self_attn(
             positions=positions,
             hidden_states=hidden_states,
@@ -208,14 +213,12 @@
             input_metadata=input_metadata,
             cache_event=cache_event,
         )
-        hidden_states = residual + hidden_states
 
         # Fully Connected
-        residual = hidden_states
-        hidden_states = self.post_attention_layernorm(hidden_states)
+        hidden_states, residual = self.post_attention_layernorm(
+            hidden_states, residual)
         hidden_states = self.mlp(hidden_states)
-        hidden_states = residual + hidden_states
-        return hidden_states
+        return hidden_states, residual
 
 
 class LlamaModel(nn.Module):
@@ -248,20 +251,22 @@ class LlamaModel(nn.Module):
         cache_events: Optional[List[torch.cuda.Event]],
     ) -> torch.Tensor:
         hidden_states = self.embed_tokens(input_ids)
+        residual = None
         for i in range(len(self.layers)):
             if cache_events is None:
                 cache_event = None
             else:
                 cache_event = cache_events[i]
             layer = self.layers[i]
-            hidden_states = layer(
+            hidden_states, residual = layer(
                 positions,
                 hidden_states,
                 kv_caches[i],
                 input_metadata,
                 cache_event,
+                residual,
             )
-        hidden_states = self.norm(hidden_states)
+        hidden_states, _ = self.norm(hidden_states, residual)
         return hidden_states
diff --git a/vllm/model_executor/models/mistral.py b/vllm/model_executor/models/mistral.py
index f9b9120a..af3199af 100644
--- a/vllm/model_executor/models/mistral.py
+++ b/vllm/model_executor/models/mistral.py
@@ -191,10 +191,15 @@ class MistralDecoderLayer(nn.Module):
         kv_cache: KVCache,
         input_metadata: InputMetadata,
         cache_event: Optional[torch.cuda.Event],
-    ) -> torch.Tensor:
+        residual: Optional[torch.Tensor],
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
         # Self Attention
-        residual = hidden_states
-        hidden_states = self.input_layernorm(hidden_states)
+        if residual is None:
+            residual = hidden_states
+            hidden_states = self.input_layernorm(hidden_states)
+        else:
+            hidden_states, residual = self.input_layernorm(
+                hidden_states, residual)
         hidden_states = self.self_attn(
             positions=positions,
             hidden_states=hidden_states,
@@ -202,14 +207,12 @@
             input_metadata=input_metadata,
             cache_event=cache_event,
         )
-        hidden_states = residual + hidden_states
 
         # Fully Connected
-        residual = hidden_states
-        hidden_states = self.post_attention_layernorm(hidden_states)
+        hidden_states, residual = self.post_attention_layernorm(
+            hidden_states, residual)
         hidden_states = self.mlp(hidden_states)
-        hidden_states = residual + hidden_states
-        return hidden_states
+        return hidden_states, residual
 
 
 class MistralModel(nn.Module):
@@ -243,20 +246,22 @@ class MistralModel(nn.Module):
         cache_events: Optional[List[torch.cuda.Event]],
     ) -> torch.Tensor:
         hidden_states = self.embed_tokens(input_ids)
+        residual = None
         for i in range(len(self.layers)):
             if cache_events is None:
                 cache_event = None
             else:
                 cache_event = cache_events[i]
             layer = self.layers[i]
-            hidden_states = layer(
+            hidden_states, residual = layer(
                 positions,
                 hidden_states,
                 kv_caches[i],
                 input_metadata,
                 cache_event,
+                residual,
             )
-        hidden_states = self.norm(hidden_states)
+        hidden_states, _ = self.norm(hidden_states, residual)
         return hidden_states
diff --git a/vllm/model_executor/models/qwen.py b/vllm/model_executor/models/qwen.py
index 45710edc..18e15036 100644
--- a/vllm/model_executor/models/qwen.py
+++ b/vllm/model_executor/models/qwen.py
@@ -159,10 +159,14 @@ class QWenBlock(nn.Module):
         kv_cache: KVCache,
         input_metadata: InputMetadata,
         cache_event: Optional[torch.cuda.Event],
-    ) -> torch.Tensor:
+        residual: Optional[torch.Tensor],
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
         # Self Attention
-        residual = hidden_states
-        hidden_states = self.ln_1(hidden_states)
+        if residual is None:
+            residual = hidden_states
+            hidden_states = self.ln_1(hidden_states)
+        else:
+            hidden_states, residual = self.ln_1(hidden_states, residual)
         hidden_states = self.attn(
             positions=positions,
             hidden_states=hidden_states,
@@ -170,14 +174,11 @@
             input_metadata=input_metadata,
             cache_event=cache_event,
         )
-        hidden_states = residual + hidden_states
 
         # Fully Connected
-        residual = hidden_states
-        hidden_states = self.ln_2(hidden_states)
+        hidden_states, residual = self.ln_2(hidden_states, residual)
         hidden_states = self.mlp(hidden_states)
-        hidden_states = residual + hidden_states
-        return hidden_states
+        return hidden_states, residual
 
 
 class QWenModel(nn.Module):
@@ -210,20 +211,22 @@ class QWenModel(nn.Module):
         cache_events: Optional[List[torch.cuda.Event]],
     ) -> torch.Tensor:
         hidden_states = self.wte(input_ids)
+        residual = None
         for i in range(len(self.h)):
             if cache_events is None:
                 cache_event = None
             else:
                 cache_event = cache_events[i]
             layer = self.h[i]
-            hidden_states = layer(
+            hidden_states, residual = layer(
                 positions,
                 hidden_states,
                 kv_caches[i],
                 input_metadata,
                 cache_event,
+                residual,
             )
-        hidden_states = self.ln_f(hidden_states)
+        hidden_states, _ = self.ln_f(hidden_states, residual)
         return hidden_states
diff --git a/vllm/model_executor/models/yi.py b/vllm/model_executor/models/yi.py
index 204c33ed..91e07377 100644
--- a/vllm/model_executor/models/yi.py
+++ b/vllm/model_executor/models/yi.py
@@ -195,10 +195,14 @@ class YiDecoderLayer(nn.Module):
         kv_cache: KVCache,
         input_metadata: InputMetadata,
         cache_event: Optional[torch.cuda.Event],
-    ) -> torch.Tensor:
+        residual: Optional[torch.Tensor],
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
         # Self Attention
-        residual = hidden_states
-        hidden_states = self.ln1(hidden_states)
+        if residual is None:
+            residual = hidden_states
+            hidden_states = self.ln1(hidden_states)
+        else:
+            hidden_states, residual = self.ln1(hidden_states, residual)
         hidden_states = self.self_attn(
             positions=positions,
             hidden_states=hidden_states,
@@ -206,14 +210,11 @@
             input_metadata=input_metadata,
             cache_event=cache_event,
         )
-        hidden_states = residual + hidden_states
 
         # Fully Connected
-        residual = hidden_states
-        hidden_states = self.ln2(hidden_states)
+        hidden_states, residual = self.ln2(hidden_states, residual)
         hidden_states = self.mlp(hidden_states)
-        hidden_states = residual + hidden_states
-        return hidden_states
+        return hidden_states, residual
 
 
 class YiModel(nn.Module):
@@ -246,20 +247,22 @@ class YiModel(nn.Module):
         cache_events: Optional[List[torch.cuda.Event]],
     ) -> torch.Tensor:
         hidden_states = self.embed_tokens(input_ids)
+        residual = None
         for i in range(len(self.layers)):
             if cache_events is None:
                 cache_event = None
             else:
                 cache_event = cache_events[i]
             layer = self.layers[i]
-            hidden_states = layer(
+            hidden_states, residual = layer(
                 positions,
                 hidden_states,
                 kv_caches[i],
                 input_metadata,
                 cache_event,
+                residual,
            )
-        hidden_states = self.norm(hidden_states)
+        hidden_states, _ = self.norm(hidden_states, residual)
         return hidden_states
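Taken together, the model-level change is identical everywhere: the residual starts as `None`, is threaded through every decoder layer, and the final norm consumes the last pending add while its second return value is discarded. A minimal hypothetical driver loop (reusing the stand-in classes from the earlier sketch, not the vLLM model classes) looks like this:

```python
# Hypothetical driver loop mirroring the repeated model-level change above.
def run_layers(layers, final_norm, hidden_states):
    residual = None
    for layer in layers:
        hidden_states, residual = layer(hidden_states, residual)
    # The final norm folds in the last residual add; the updated residual
    # it returns is no longer needed.
    hidden_states, _ = final_norm(hidden_states, residual)
    return hidden_states


# Example with the stand-in classes from the earlier sketch:
#   layers = [SketchDecoderLayer(64) for _ in range(4)]
#   out = run_layers(layers, RMSNormLike(64), torch.randn(2, 64))
```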