Support bfloat16 data type (#54)

This commit is contained in:
Woosuk Kwon 2023-05-03 14:09:44 -07:00 committed by GitHub
parent 436e523bf1
commit e070829ae8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
12 changed files with 455 additions and 53 deletions

View File

@ -213,8 +213,8 @@ def add_server_arguments(parser: argparse.ArgumentParser):
parser.add_argument('--use-np-cache', action='store_true',
help='save a numpy copy of model weights for faster loading')
parser.add_argument('--use-dummy-weights', action='store_true', help='use dummy values for model weights')
# NOTE(woosuk): If FlashAttention is used, the float data type is not supported.
parser.add_argument('--dtype', type=str, default='half', choices=['half'], help='data type')
# NOTE(woosuk): FlashAttention does not support float32.
parser.add_argument('--dtype', type=str, default='half', choices=['half', 'bfloat16'], help='data type')
# Parallel arguments
parser.add_argument('--use-ray', action='store_true', help='use Ray for distributed serving, will be automatically set when using more than 1 GPU')
parser.add_argument('--pipeline-parallel-size', '-pp', type=int, default=1, help='number of pipeline stages')

View File

@ -17,6 +17,7 @@ _STR_DTYPE_TO_TORCH_DTYPE = {
'float': torch.float,
'float16': torch.float16,
'float32': torch.float32,
'bfloat16': torch.bfloat16,
}

View File

@ -34,7 +34,9 @@ void silu_and_mul(
dim3 grid(num_tokens);
dim3 block(std::min(d, 1024));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
AT_DISPATCH_FLOATING_TYPES_AND2(
at::ScalarType::Half,
at::ScalarType::BFloat16,
input.scalar_type(),
"silu_and_mul_kernel",
[&] {

View File

@ -3,3 +3,7 @@
#include "attention_generic.cuh"
#include "dtype_float16.cuh"
#include "dtype_float32.cuh"
#ifdef ENABLE_BF16
#include "dtype_bfloat16.cuh"
#endif // ENABLE_BF16

View File

@ -1,7 +1,7 @@
#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>
#include "attention_dtypes.cuh"
#include "attention_dtypes.h"
#include "attention_utils.cuh"
#include <algorithm>
@ -438,9 +438,13 @@ void single_query_cached_kv_attention(
torch::Tensor& context_lens, // [num_seqs]
int block_size,
int max_context_len) {
// TODO(woosuk): Support FP32 and BF16.
// TODO(woosuk): Support FP32.
if (query.dtype() == at::ScalarType::Half) {
CALL_KERNEL_LAUNCHER_BLOCK_SIZE(uint16_t);
#ifdef ENABLE_BF16
} else if (query.dtype() == at::ScalarType::BFloat16) {
CALL_KERNEL_LAUNCHER_BLOCK_SIZE(__nv_bfloat16);
#endif
} else {
TORCH_CHECK(false, "Unsupported data type: ", query.dtype());
}

View File

@ -1,6 +1,6 @@
#pragma once
#include "attention_dtypes.cuh"
#include "attention_dtypes.h"
#include <float.h>
#include <type_traits>

View File

@ -0,0 +1,361 @@
#pragma once
#include "attention_generic.cuh"
#include "dtype_float32.cuh"
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <stdint.h>
namespace cacheflow {
// Define custom BF16 vector data types.
struct bf16_4_t {
__nv_bfloat162 x;
__nv_bfloat162 y;
};
struct bf16_8_t {
__nv_bfloat162 x;
__nv_bfloat162 y;
__nv_bfloat162 z;
__nv_bfloat162 w;
};
// BF16 vector types for Q, K, V.
template<>
struct Vec<__nv_bfloat16, 1> {
using Type = __nv_bfloat16;
};
template<>
struct Vec<__nv_bfloat16, 2> {
using Type = __nv_bfloat162;
};
template<>
struct Vec<__nv_bfloat16, 4> {
using Type = bf16_4_t;
};
template<>
struct Vec<__nv_bfloat16, 8> {
using Type = bf16_8_t;
};
// FP32 accumulator vector types corresponding to Vec.
template<>
struct FloatVec<__nv_bfloat16> {
using Type = float;
};
template<>
struct FloatVec<__nv_bfloat162> {
using Type = float2;
};
template<>
struct FloatVec<bf16_4_t> {
using Type = Float4_;
};
template<>
struct FloatVec<bf16_8_t> {
using Type = Float8_;
};
// Utility functions for type conversions.
inline __device__ float2 bf1622float2(const __nv_bfloat162 val) {
return __bfloat1622float2(val);
}
inline __device__ __nv_bfloat162 bf162bf162(const __nv_bfloat16 val) {
return __bfloat162bfloat162(val);
}
// Vector addition.
inline __device__ __nv_bfloat16 add(__nv_bfloat16 a, __nv_bfloat16 b) {
return a + b;
}
inline __device__ __nv_bfloat162 add(__nv_bfloat162 a, __nv_bfloat162 b) {
return __hadd2(a, b);
}
inline __device__ bf16_4_t add(bf16_4_t a, bf16_4_t b) {
bf16_4_t c;
c.x = add(a.x, b.x);
c.y = add(a.y, b.y);
return c;
}
inline __device__ bf16_8_t add(bf16_8_t a, bf16_8_t b) {
bf16_8_t c;
c.x = add(a.x, b.x);
c.y = add(a.y, b.y);
c.z = add(a.z, b.z);
c.w = add(a.w, b.w);
return c;
}
inline __device__ float2 add(__nv_bfloat162 a, float2 fb) {
float2 fa = bf1622float2(a);
return add(fa, fb);
}
inline __device__ Float4_ add(bf16_4_t a, Float4_ fb) {
Float4_ fc;
fc.x = add(a.x, fb.x);
fc.y = add(a.y, fb.y);
return fc;
}
inline __device__ Float8_ add(bf16_8_t a, Float8_ fb) {
Float8_ fc;
fc.x = add(a.x, fb.x);
fc.y = add(a.y, fb.y);
fc.z = add(a.z, fb.z);
fc.w = add(a.w, fb.w);
return fc;
}
// Vector multiplication.
template<>
inline __device__ __nv_bfloat16 mul(__nv_bfloat16 a, __nv_bfloat16 b) {
return __hmul(a, b);
}
template<>
inline __device__ __nv_bfloat162 mul(__nv_bfloat162 a, __nv_bfloat162 b) {
return __hmul2(a, b);
}
template<>
inline __device__ __nv_bfloat162 mul(__nv_bfloat16 a, __nv_bfloat162 b) {
return mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(bf162bf162(a), b);
}
template<>
inline __device__ bf16_4_t mul(bf16_4_t a, bf16_4_t b) {
bf16_4_t c;
c.x = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.x, b.x);
c.y = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.y, b.y);
return c;
}
template<>
inline __device__ bf16_4_t mul(__nv_bfloat16 a, bf16_4_t b) {
__nv_bfloat162 s = bf162bf162(a);
bf16_4_t c;
c.x = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.x);
c.y = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.y);
return c;
}
template<>
inline __device__ bf16_8_t mul(bf16_8_t a, bf16_8_t b) {
bf16_8_t c;
c.x = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.x, b.x);
c.y = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.y, b.y);
c.z = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.z, b.z);
c.w = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.w, b.w);
return c;
}
template<>
inline __device__ bf16_8_t mul(__nv_bfloat16 a, bf16_8_t b) {
__nv_bfloat162 s = bf162bf162(a);
bf16_8_t c;
c.x = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.x);
c.y = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.y);
c.z = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.z);
c.w = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.w);
return c;
}
template<>
inline __device__ float mul(__nv_bfloat16 a, __nv_bfloat16 b) {
float fa = __bfloat162float(a);
float fb = __bfloat162float(b);
return fa * fb;
}
template<>
inline __device__ float2 mul(__nv_bfloat162 a, __nv_bfloat162 b) {
float2 fa = bf1622float2(a);
float2 fb = bf1622float2(b);
return mul<float2, float2, float2>(fa, fb);
}
template<>
inline __device__ float2 mul(__nv_bfloat16 a, __nv_bfloat162 b) {
return mul<float2, __nv_bfloat162, __nv_bfloat162>(bf162bf162(a), b);
}
template<>
inline __device__ Float4_ mul(bf16_4_t a, bf16_4_t b) {
Float4_ fc;
fc.x = mul<float2, __nv_bfloat162, __nv_bfloat162>(a.x, b.x);
fc.y = mul<float2, __nv_bfloat162, __nv_bfloat162>(a.y, b.y);
return fc;
}
template<>
inline __device__ Float4_ mul(__nv_bfloat16 a, bf16_4_t b) {
__nv_bfloat162 s = bf162bf162(a);
Float4_ fc;
fc.x = mul<float2, __nv_bfloat162, __nv_bfloat162>(s, b.x);
fc.y = mul<float2, __nv_bfloat162, __nv_bfloat162>(s, b.y);
return fc;
}
template<>
inline __device__ Float8_ mul(bf16_8_t a, bf16_8_t b) {
Float8_ fc;
fc.x = mul<float2, __nv_bfloat162, __nv_bfloat162>(a.x, b.x);
fc.y = mul<float2, __nv_bfloat162, __nv_bfloat162>(a.y, b.y);
fc.z = mul<float2, __nv_bfloat162, __nv_bfloat162>(a.z, b.z);
fc.w = mul<float2, __nv_bfloat162, __nv_bfloat162>(a.w, b.w);
return fc;
}
template<>
inline __device__ Float8_ mul(__nv_bfloat16 a, bf16_8_t b) {
__nv_bfloat162 s = bf162bf162(a);
Float8_ fc;
fc.x = mul<float2, __nv_bfloat162, __nv_bfloat162>(s, b.x);
fc.y = mul<float2, __nv_bfloat162, __nv_bfloat162>(s, b.y);
fc.z = mul<float2, __nv_bfloat162, __nv_bfloat162>(s, b.z);
fc.w = mul<float2, __nv_bfloat162, __nv_bfloat162>(s, b.w);
return fc;
}
// Vector fused multiply-add.
inline __device__ __nv_bfloat162 fma(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bfloat162 c) {
return __hfma2(a, b, c);
}
inline __device__ __nv_bfloat162 fma(__nv_bfloat16 a, __nv_bfloat162 b, __nv_bfloat162 c) {
return __hfma2(bf162bf162(a), b, c);
}
inline __device__ bf16_4_t fma(bf16_4_t a, bf16_4_t b, bf16_4_t c) {
bf16_4_t d;
d.x = fma(a.x, b.x, c.x);
d.y = fma(a.y, b.y, c.y);
return d;
}
inline __device__ bf16_4_t fma(__nv_bfloat16 a, bf16_4_t b, bf16_4_t c) {
__nv_bfloat162 s = bf162bf162(a);
bf16_4_t d;
d.x = fma(s, b.x, c.x);
d.y = fma(s, b.y, c.y);
return d;
}
inline __device__ bf16_8_t fma(bf16_8_t a, bf16_8_t b, bf16_8_t c) {
bf16_8_t d;
d.x = fma(a.x, b.x, c.x);
d.y = fma(a.y, b.y, c.y);
d.z = fma(a.z, b.z, c.z);
d.w = fma(a.w, b.w, c.w);
return d;
}
inline __device__ bf16_8_t fma(__nv_bfloat16 a, bf16_8_t b, bf16_8_t c) {
__nv_bfloat162 s = bf162bf162(a);
bf16_8_t d;
d.x = fma(s, b.x, c.x);
d.y = fma(s, b.y, c.y);
d.z = fma(s, b.z, c.z);
d.w = fma(s, b.w, c.w);
return d;
}
inline __device__ float fma(__nv_bfloat16 a, __nv_bfloat16 b, float fc) {
return __bfloat162float(a) * __bfloat162float(b) + fc;
}
inline __device__ float2 fma(__nv_bfloat162 a, __nv_bfloat162 b, float2 fc) {
float2 fa = bf1622float2(a);
float2 fb = bf1622float2(b);
return fma(fa, fb, fc);
}
inline __device__ float2 fma(__nv_bfloat16 a, __nv_bfloat162 b, float2 fc) {
return fma(bf162bf162(a), b, fc);
}
inline __device__ Float4_ fma(bf16_4_t a, bf16_4_t b, Float4_ fc) {
Float4_ fd;
fd.x = fma(a.x, b.x, fc.x);
fd.y = fma(a.y, b.y, fc.y);
return fd;
}
inline __device__ Float4_ fma(__nv_bfloat16 a, bf16_4_t b, Float4_ fc) {
__nv_bfloat162 s = bf162bf162(a);
Float4_ fd;
fd.x = fma(s, b.x, fc.x);
fd.y = fma(s, b.y, fc.y);
return fd;
}
inline __device__ Float8_ fma(bf16_8_t a, bf16_8_t b, Float8_ fc) {
Float8_ fd;
fd.x = fma(a.x, b.x, fc.x);
fd.y = fma(a.y, b.y, fc.y);
fd.z = fma(a.z, b.z, fc.z);
fd.w = fma(a.w, b.w, fc.w);
return fd;
}
inline __device__ Float8_ fma(__nv_bfloat16 a, bf16_8_t b, Float8_ fc) {
__nv_bfloat162 s = bf162bf162(a);
Float8_ fd;
fd.x = fma(s, b.x, fc.x);
fd.y = fma(s, b.y, fc.y);
fd.z = fma(s, b.z, fc.z);
fd.w = fma(s, b.w, fc.w);
return fd;
}
// Vector sum.
template<>
inline __device__ float sum(__nv_bfloat16 v) {
return __bfloat162float(v);
}
template<>
inline __device__ float sum(__nv_bfloat162 v) {
float2 vf = bf1622float2(v);
return vf.x + vf.y;
}
template<>
inline __device__ float sum(bf16_4_t v) {
return sum(v.x) + sum(v.y);
}
template<>
inline __device__ float sum(bf16_8_t v) {
return sum(v.x) + sum(v.y) + sum(v.z) + sum(v.w);
}
// From float32 to bfloat16.
inline __device__ void from_float(__nv_bfloat16& dst, float src) {
dst = __float2bfloat16(src);
}
inline __device__ void from_float(__nv_bfloat162& dst, float2 src) {
dst = __float22bfloat162_rn(src);
}
inline __device__ void from_float(bf16_4_t& dst, Float4_ src) {
dst.x = __float22bfloat162_rn(src.x);
dst.y = __float22bfloat162_rn(src.y);
}
inline __device__ void from_float(bf16_8_t& dst, Float8_ src) {
dst.x = __float22bfloat162_rn(src.x);
dst.y = __float22bfloat162_rn(src.y);
dst.z = __float22bfloat162_rn(src.z);
dst.w = __float22bfloat162_rn(src.w);
}
} // namespace cacheflow

View File

@ -6,7 +6,7 @@
namespace cacheflow {
// Define FP32 vector data types.
// Define custom FP32 vector data types.
struct Float4_ {
float2 x;
float2 y;

View File

@ -14,14 +14,16 @@ void swap_blocks(
torch::Device dst_device = dst.device();
cudaMemcpyKind memcpy_type;
if (src_device.is_cuda() && dst_device.is_cuda()) {
assert(src_device.index() == dst_device.index());
TORCH_CHECK(
src_device.index() == dst_device.index(),
"src and dst must be on the same GPU");
memcpy_type = cudaMemcpyDeviceToDevice;
} else if (src_device.is_cuda() && dst_device.is_cpu()) {
memcpy_type = cudaMemcpyDeviceToHost;
} else if (src_device.is_cpu() && dst_device.is_cuda()) {
memcpy_type = cudaMemcpyHostToDevice;
} else {
assert(false);
TORCH_CHECK(false, "Invalid device combination");
}
void *src_ptr = src.data_ptr();
@ -29,6 +31,7 @@ void swap_blocks(
const int64_t block_size_in_bytes = src.element_size() * src[0].numel();
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
// NOTE(woosuk): This can be slow if the number of blocks is large.
for (const auto& pair : block_mapping) {
int64_t src_block_number = pair.first;
int64_t dst_block_number = pair.second;
@ -122,7 +125,9 @@ void copy_blocks(
dim3 grid(num_layers, num_pairs);
dim3 block(std::min(1024, numel_per_block));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
AT_DISPATCH_FLOATING_TYPES_AND2(
at::ScalarType::Half,
at::ScalarType::BFloat16,
key_caches[0].scalar_type(), "copy_blocks_kernel", ([&] {
cacheflow::copy_blocks_kernel<scalar_t><<<grid, block, 0, stream>>>(
key_cache_ptrs_tensor.data_ptr<int64_t>(),
@ -176,6 +181,50 @@ __global__ void reshape_and_cache_kernel(
}
}
} // namespace cacheflow
void reshape_and_cache(
torch::Tensor& key, // [num_tokens, num_heads, head_size]
torch::Tensor& value, // [num_tokens, num_heads, head_size]
torch::Tensor& key_cache, // [num_blocks, num_heads, head_size/x, block_size, x]
torch::Tensor& value_cache, // [num_blocks, num_heads, head_size, block_size]
torch::Tensor& slot_mapping) // [num_tokens]
{
int num_tokens = key.size(0);
int num_heads = key.size(1);
int head_size = key.size(2);
int block_size = key_cache.size(3);
int x = key_cache.size(4);
int key_stride = key.stride(0);
int value_stride = value.stride(0);
dim3 grid(num_tokens);
dim3 block(std::min(num_heads * head_size, 512));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
AT_DISPATCH_FLOATING_TYPES_AND2(
at::ScalarType::Half,
at::ScalarType::BFloat16,
key.scalar_type(),
"reshape_and_cache_kernel",
[&] {
cacheflow::reshape_and_cache_kernel<scalar_t><<<grid, block, 0, stream>>>(
key.data_ptr<scalar_t>(),
value.data_ptr<scalar_t>(),
key_cache.data_ptr<scalar_t>(),
value_cache.data_ptr<scalar_t>(),
slot_mapping.data_ptr<int>(),
key_stride,
value_stride,
num_heads,
head_size,
block_size,
x);
});
}
namespace cacheflow {
// Grid: (num_blocks, block_size).
template<typename scalar_t>
__global__ void gather_cached_kv_kernel(
@ -296,45 +345,6 @@ __global__ void gather_cached_kv_kernel_optimized(
} // namespace cacheflow
void reshape_and_cache(
torch::Tensor& key, // [num_tokens, num_heads, head_size]
torch::Tensor& value, // [num_tokens, num_heads, head_size]
torch::Tensor& key_cache, // [num_blocks, num_heads, head_size/x, block_size, x]
torch::Tensor& value_cache, // [num_blocks, num_heads, head_size, block_size]
torch::Tensor& slot_mapping) // [num_tokens]
{
int num_tokens = key.size(0);
int num_heads = key.size(1);
int head_size = key.size(2);
int block_size = key_cache.size(3);
int x = key_cache.size(4);
int key_stride = key.stride(0);
int value_stride = value.stride(0);
dim3 grid(num_tokens);
dim3 block(std::min(num_heads * head_size, 512));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
key.scalar_type(),
"reshape_and_cache_kernel",
[&] {
cacheflow::reshape_and_cache_kernel<scalar_t><<<grid, block, 0, stream>>>(
key.data_ptr<scalar_t>(),
value.data_ptr<scalar_t>(),
key_cache.data_ptr<scalar_t>(),
value_cache.data_ptr<scalar_t>(),
slot_mapping.data_ptr<int>(),
key_stride,
value_stride,
num_heads,
head_size,
block_size,
x);
});
}
void gather_cached_kv(
torch::Tensor& key, // [out] [num_tokens, num_heads, head_size]
torch::Tensor& value, // [out] [num_tokens, num_heads, head_size]
@ -354,7 +364,9 @@ void gather_cached_kv(
dim3 grid(num_tokens);
dim3 block(std::min(num_heads * head_size, 512));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
AT_DISPATCH_FLOATING_TYPES_AND2(
at::ScalarType::Half,
at::ScalarType::BFloat16,
key.scalar_type(),
"gather_cached_kv_kernel_optimized",
[&] {

View File

@ -46,7 +46,9 @@ void rms_norm(
dim3 grid(num_tokens);
dim3 block(std::min(hidden_size, 1024));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
AT_DISPATCH_FLOATING_TYPES_AND2(
at::ScalarType::Half,
at::ScalarType::BFloat16,
input.scalar_type(),
"rms_norm_kernel",
[&] {

View File

@ -64,7 +64,9 @@ void rotary_embedding_neox(
dim3 grid(num_tokens);
dim3 block(std::min(num_heads * rot_dim / 2, 512));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
AT_DISPATCH_FLOATING_TYPES_AND2(
at::ScalarType::Half,
at::ScalarType::BFloat16,
query.scalar_type(),
"rotary_embedding_neox",
[&] {

View File

@ -1,9 +1,22 @@
import setuptools
import torch
from torch.utils import cpp_extension
CXX_FLAGS = ['-g']
NVCC_FLAGS = ['-O2']
if not torch.cuda.is_available():
raise RuntimeError(
f'Cannot find CUDA at CUDA_HOME: {cpp_extension.CUDA_HOME}. '
'CUDA must be available in order to build the package.')
# FIXME(woosuk): Consider the case where the machine has multiple GPUs with
# different compute capabilities.
compute_capability = torch.cuda.get_device_capability()
major, minor = compute_capability
# Enable bfloat16 support if the compute capability is >= 8.0.
if major >= 8:
NVCC_FLAGS.append('-DENABLE_BF16')
ext_modules = []
@ -23,7 +36,7 @@ attention_extension = cpp_extension.CUDAExtension(
)
ext_modules.append(attention_extension)
# Positional encodings.
# Positional encoding kernels.
positional_encoding_extension = cpp_extension.CUDAExtension(
name='cacheflow.pos_encoding_ops',
sources=['csrc/pos_encoding.cpp', 'csrc/pos_encoding_kernels.cu'],
@ -39,6 +52,7 @@ layernorm_extension = cpp_extension.CUDAExtension(
)
ext_modules.append(layernorm_extension)
# Activation kernels.
activation_extension = cpp_extension.CUDAExtension(
name='cacheflow.activation_ops',
sources=['csrc/activation.cpp', 'csrc/activation_kernels.cu'],