From 8ce9c50d4034de3c557b520935fac1d6dac585a0 Mon Sep 17 00:00:00 2001
From: Woosuk Kwon
Date: Sat, 2 Sep 2023 14:59:47 +0900
Subject: [PATCH] Avoid compiling kernels for double data type (#933)

---
 csrc/activation_kernels.cu   | 10 ++++------
 csrc/cache_kernels.cu        | 14 +++++---------
 csrc/dispatch_utils.h        | 14 ++++++++++++++
 csrc/layernorm_kernels.cu    |  5 ++---
 csrc/pos_encoding_kernels.cu |  6 +++---
 5 files changed, 28 insertions(+), 21 deletions(-)
 create mode 100644 csrc/dispatch_utils.h

diff --git a/csrc/activation_kernels.cu b/csrc/activation_kernels.cu
index fc1f086f..c6ae5db8 100644
--- a/csrc/activation_kernels.cu
+++ b/csrc/activation_kernels.cu
@@ -1,6 +1,8 @@
 #include <torch/extension.h>
 #include <ATen/cuda/CUDAContext.h>
 
+#include "dispatch_utils.h"
+
 namespace vllm {
 
 template<typename T>
@@ -34,9 +36,7 @@ void silu_and_mul(
   dim3 grid(num_tokens);
   dim3 block(std::min(d, 1024));
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
-  AT_DISPATCH_FLOATING_TYPES_AND2(
-    at::ScalarType::Half,
-    at::ScalarType::BFloat16,
+  VLLM_DISPATCH_FLOATING_TYPES(
     input.scalar_type(),
     "silu_and_mul_kernel",
     [&] {
@@ -71,9 +71,7 @@ __global__ void activation_kernel(
   dim3 grid(num_tokens);                                          \
   dim3 block(std::min(d, 1024));                                  \
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();   \
-  AT_DISPATCH_FLOATING_TYPES_AND2(                                \
-    at::ScalarType::Half,                                         \
-    at::ScalarType::BFloat16,                                     \
+  VLLM_DISPATCH_FLOATING_TYPES(                                   \
     input.scalar_type(),                                          \
     "activation_kernel",                                          \
     [&] {                                                         \
diff --git a/csrc/cache_kernels.cu b/csrc/cache_kernels.cu
index 5e7b6be4..ddad2b5a 100644
--- a/csrc/cache_kernels.cu
+++ b/csrc/cache_kernels.cu
@@ -1,6 +1,8 @@
 #include <torch/extension.h>
 #include <ATen/cuda/CUDAContext.h>
 
+#include "dispatch_utils.h"
+
 #include <algorithm>
 #include <cassert>
 #include <map>
@@ -125,9 +127,7 @@ void copy_blocks(
   dim3 grid(num_layers, num_pairs);
   dim3 block(std::min(1024, numel_per_block));
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
-  AT_DISPATCH_FLOATING_TYPES_AND2(
-    at::ScalarType::Half,
-    at::ScalarType::BFloat16,
+  VLLM_DISPATCH_FLOATING_TYPES(
     key_caches[0].scalar_type(), "copy_blocks_kernel", ([&] {
       vllm::copy_blocks_kernel<scalar_t><<<grid, block, 0, stream>>>(
         key_cache_ptrs_tensor.data_ptr<int64_t>(),
         value_cache_ptrs_tensor.data_ptr<int64_t>(),
@@ -202,9 +202,7 @@ void reshape_and_cache(
   dim3 grid(num_tokens);
   dim3 block(std::min(num_heads * head_size, 512));
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
-  AT_DISPATCH_FLOATING_TYPES_AND2(
-    at::ScalarType::Half,
-    at::ScalarType::BFloat16,
+  VLLM_DISPATCH_FLOATING_TYPES(
     key.scalar_type(),
     "reshape_and_cache_kernel",
     [&] {
@@ -364,9 +362,7 @@ void gather_cached_kv(
   dim3 grid(num_tokens);
   dim3 block(std::min(num_heads * head_size, 512));
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
-  AT_DISPATCH_FLOATING_TYPES_AND2(
-    at::ScalarType::Half,
-    at::ScalarType::BFloat16,
+  VLLM_DISPATCH_FLOATING_TYPES(
     key.scalar_type(),
     "gather_cached_kv_kernel_optimized",
     [&] {
diff --git a/csrc/dispatch_utils.h b/csrc/dispatch_utils.h
new file mode 100644
index 00000000..7c0c49d3
--- /dev/null
+++ b/csrc/dispatch_utils.h
@@ -0,0 +1,14 @@
+/*
+ * Adapted from
+ * https://github.com/pytorch/pytorch/blob/v2.0.1/aten/src/ATen/Dispatch.h
+ */
+#include <torch/extension.h>
+
+#define VLLM_DISPATCH_CASE_FLOATING_TYPES(...)              \
+  AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__)      \
+  AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__)       \
+  AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__)
+
+#define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...)             \
+  AT_DISPATCH_SWITCH(                                             \
+    TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__))
diff --git a/csrc/layernorm_kernels.cu b/csrc/layernorm_kernels.cu
index 73503c55..f932b9e2 100644
--- a/csrc/layernorm_kernels.cu
+++ b/csrc/layernorm_kernels.cu
@@ -1,6 +1,7 @@
 #include <torch/extension.h>
 #include <ATen/cuda/CUDAContext.h>
 
+#include "dispatch_utils.h"
 #include "reduction_utils.cuh"
 
 namespace vllm {
@@ -46,9 +47,7 @@ void rms_norm(
   dim3 grid(num_tokens);
   dim3 block(std::min(hidden_size, 1024));
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
-  AT_DISPATCH_FLOATING_TYPES_AND2(
-    at::ScalarType::Half,
-    at::ScalarType::BFloat16,
+  VLLM_DISPATCH_FLOATING_TYPES(
     input.scalar_type(),
     "rms_norm_kernel",
     [&] {
diff --git a/csrc/pos_encoding_kernels.cu b/csrc/pos_encoding_kernels.cu
index 98939fc7..ced26ecb 100644
--- a/csrc/pos_encoding_kernels.cu
+++ b/csrc/pos_encoding_kernels.cu
@@ -1,6 +1,8 @@
 #include <torch/extension.h>
 #include <ATen/cuda/CUDAContext.h>
 
+#include "dispatch_utils.h"
+
 namespace vllm {
 
 template<typename scalar_t>
@@ -83,9 +85,7 @@ void rotary_embedding_neox(
   dim3 grid(num_tokens);
   dim3 block(std::min(num_heads * rot_dim / 2, 512));
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
-  AT_DISPATCH_FLOATING_TYPES_AND2(
-    at::ScalarType::Half,
-    at::ScalarType::BFloat16,
+  VLLM_DISPATCH_FLOATING_TYPES(
     query.scalar_type(),
     "rotary_embedding_neox",
     [&] {
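
For reference, a minimal usage sketch of the new dispatch macro. The kernel, host
function, and launch configuration below (scale_kernel, scale, the 256-thread block)
are hypothetical and invented for illustration; only VLLM_DISPATCH_FLOATING_TYPES and
PyTorch's AT_DISPATCH_SWITCH / AT_DISPATCH_CASE machinery come from this patch. Because
the case list covers only Float, Half, and BFloat16, no double-precision instantiation
of the kernel is compiled, while call sites keep the same shape as the old
AT_DISPATCH_FLOATING_TYPES_AND2 calls minus the two at::ScalarType arguments.

  #include <torch/extension.h>
  #include <ATen/cuda/CUDAContext.h>
  #include "dispatch_utils.h"

  namespace vllm {

  // Hypothetical element-wise kernel used only to illustrate the macro.
  template<typename scalar_t>
  __global__ void scale_kernel(
    scalar_t* __restrict__ out,
    const scalar_t* __restrict__ in,
    float factor,
    int n) {
    const int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx < n) {
      // Compute in float, then convert back to the dispatched scalar type.
      out[idx] = (scalar_t) ((float) in[idx] * factor);
    }
  }

  } // namespace vllm

  void scale(torch::Tensor& out, torch::Tensor& in, float factor) {
    const int n = in.numel();
    dim3 grid((n + 255) / 256);
    dim3 block(256);
    const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
    // Expands to an AT_DISPATCH_SWITCH over Float, Half, and BFloat16 only;
    // scalar_t is defined inside the lambda by the dispatch case.
    VLLM_DISPATCH_FLOATING_TYPES(
      in.scalar_type(),
      "scale_kernel",
      [&] {
        vllm::scale_kernel<scalar_t><<<grid, block, 0, stream>>>(
          out.data_ptr<scalar_t>(),
          in.data_ptr<scalar_t>(),
          factor,
          n);
      });
  }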