Support bfloat16 data type (#54)

parent 436e523bf1
commit e070829ae8
@@ -213,8 +213,8 @@ def add_server_arguments(parser: argparse.ArgumentParser):
    parser.add_argument('--use-np-cache', action='store_true',
                        help='save a numpy copy of model weights for faster loading')
    parser.add_argument('--use-dummy-weights', action='store_true', help='use dummy values for model weights')
-   # NOTE(woosuk): If FlashAttention is used, the float data type is not supported.
-   parser.add_argument('--dtype', type=str, default='half', choices=['half'], help='data type')
+   # NOTE(woosuk): FlashAttention does not support float32.
+   parser.add_argument('--dtype', type=str, default='half', choices=['half', 'bfloat16'], help='data type')
    # Parallel arguments
    parser.add_argument('--use-ray', action='store_true', help='use Ray for distributed serving, will be automatically set when using more than 1 GPU')
    parser.add_argument('--pipeline-parallel-size', '-pp', type=int, default=1, help='number of pipeline stages')
@@ -17,6 +17,7 @@ _STR_DTYPE_TO_TORCH_DTYPE = {
    'float': torch.float,
    'float16': torch.float16,
    'float32': torch.float32,
+   'bfloat16': torch.bfloat16,
}
@@ -34,7 +34,9 @@ void silu_and_mul(
  dim3 grid(num_tokens);
  dim3 block(std::min(d, 1024));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
-  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+  AT_DISPATCH_FLOATING_TYPES_AND2(
+    at::ScalarType::Half,
+    at::ScalarType::BFloat16,
    input.scalar_type(),
    "silu_and_mul_kernel",
    [&] {
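For reference: AT_DISPATCH_FLOATING_TYPES_AND2 behaves like AT_DISPATCH_FLOATING_TYPES_AND_HALF but accepts two extra scalar types, so the same lambda is instantiated for float, double, Half, and BFloat16, with scalar_t bound to the matching C++ type each time. A minimal sketch of the pattern the kernels in this commit switch to; the scale kernel and wrapper below are illustrative only, not code from this repository:

// Hypothetical example of the AT_DISPATCH_FLOATING_TYPES_AND2 pattern.
#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>

template<typename scalar_t>
__global__ void scale_kernel(scalar_t* __restrict__ out,
                             const scalar_t* __restrict__ in,
                             float factor,
                             int n) {
  int idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (idx < n) {
    // Compute through float so the same body works for half and bfloat16.
    out[idx] = static_cast<scalar_t>(static_cast<float>(in[idx]) * factor);
  }
}

void scale(torch::Tensor& out, const torch::Tensor& in, float factor) {
  int n = in.numel();
  dim3 grid((n + 255) / 256);
  dim3 block(256);
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  AT_DISPATCH_FLOATING_TYPES_AND2(
    at::ScalarType::Half,
    at::ScalarType::BFloat16,
    in.scalar_type(),
    "scale_kernel",
    [&] {
      // scalar_t is float, double, c10::Half, or c10::BFloat16 here.
      scale_kernel<scalar_t><<<grid, block, 0, stream>>>(
        out.data_ptr<scalar_t>(), in.data_ptr<scalar_t>(), factor, n);
    });
}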
@@ -3,3 +3,7 @@
#include "attention_generic.cuh"
#include "dtype_float16.cuh"
#include "dtype_float32.cuh"
+
+#ifdef ENABLE_BF16
+#include "dtype_bfloat16.cuh"
+#endif // ENABLE_BF16
@@ -1,7 +1,7 @@
#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>

-#include "attention_dtypes.cuh"
+#include "attention_dtypes.h"
#include "attention_utils.cuh"

#include <algorithm>
@@ -438,9 +438,13 @@ void single_query_cached_kv_attention(
  torch::Tensor& context_lens,    // [num_seqs]
  int block_size,
  int max_context_len) {
-  // TODO(woosuk): Support FP32 and BF16.
+  // TODO(woosuk): Support FP32.
  if (query.dtype() == at::ScalarType::Half) {
    CALL_KERNEL_LAUNCHER_BLOCK_SIZE(uint16_t);
+#ifdef ENABLE_BF16
+  } else if (query.dtype() == at::ScalarType::BFloat16) {
+    CALL_KERNEL_LAUNCHER_BLOCK_SIZE(__nv_bfloat16);
+#endif
  } else {
    TORCH_CHECK(false, "Unsupported data type: ", query.dtype());
  }
@@ -1,6 +1,6 @@
#pragma once

-#include "attention_dtypes.cuh"
+#include "attention_dtypes.h"

#include <float.h>
#include <type_traits>
csrc/attention/dtype_bfloat16.cuh (new file, 361 lines)
@@ -0,0 +1,361 @@
#pragma once

#include "attention_generic.cuh"
#include "dtype_float32.cuh"

#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <stdint.h>

namespace cacheflow {

// Define custom BF16 vector data types.
struct bf16_4_t {
  __nv_bfloat162 x;
  __nv_bfloat162 y;
};

struct bf16_8_t {
  __nv_bfloat162 x;
  __nv_bfloat162 y;
  __nv_bfloat162 z;
  __nv_bfloat162 w;
};

// BF16 vector types for Q, K, V.
template<>
struct Vec<__nv_bfloat16, 1> {
  using Type = __nv_bfloat16;
};
template<>
struct Vec<__nv_bfloat16, 2> {
  using Type = __nv_bfloat162;
};
template<>
struct Vec<__nv_bfloat16, 4> {
  using Type = bf16_4_t;
};
template<>
struct Vec<__nv_bfloat16, 8> {
  using Type = bf16_8_t;
};

// FP32 accumulator vector types corresponding to Vec.
template<>
struct FloatVec<__nv_bfloat16> {
  using Type = float;
};
template<>
struct FloatVec<__nv_bfloat162> {
  using Type = float2;
};
template<>
struct FloatVec<bf16_4_t> {
  using Type = Float4_;
};
template<>
struct FloatVec<bf16_8_t> {
  using Type = Float8_;
};

// Utility functions for type conversions.
inline __device__ float2 bf1622float2(const __nv_bfloat162 val) {
  return __bfloat1622float2(val);
}

inline __device__ __nv_bfloat162 bf162bf162(const __nv_bfloat16 val) {
  return __bfloat162bfloat162(val);
}

// Vector addition.
inline __device__ __nv_bfloat16 add(__nv_bfloat16 a, __nv_bfloat16 b) {
  return a + b;
}

inline __device__ __nv_bfloat162 add(__nv_bfloat162 a, __nv_bfloat162 b) {
  return __hadd2(a, b);
}

inline __device__ bf16_4_t add(bf16_4_t a, bf16_4_t b) {
  bf16_4_t c;
  c.x = add(a.x, b.x);
  c.y = add(a.y, b.y);
  return c;
}

inline __device__ bf16_8_t add(bf16_8_t a, bf16_8_t b) {
  bf16_8_t c;
  c.x = add(a.x, b.x);
  c.y = add(a.y, b.y);
  c.z = add(a.z, b.z);
  c.w = add(a.w, b.w);
  return c;
}

inline __device__ float2 add(__nv_bfloat162 a, float2 fb) {
  float2 fa = bf1622float2(a);
  return add(fa, fb);
}

inline __device__ Float4_ add(bf16_4_t a, Float4_ fb) {
  Float4_ fc;
  fc.x = add(a.x, fb.x);
  fc.y = add(a.y, fb.y);
  return fc;
}

inline __device__ Float8_ add(bf16_8_t a, Float8_ fb) {
  Float8_ fc;
  fc.x = add(a.x, fb.x);
  fc.y = add(a.y, fb.y);
  fc.z = add(a.z, fb.z);
  fc.w = add(a.w, fb.w);
  return fc;
}

// Vector multiplication.
template<>
inline __device__ __nv_bfloat16 mul(__nv_bfloat16 a, __nv_bfloat16 b) {
  return __hmul(a, b);
}

template<>
inline __device__ __nv_bfloat162 mul(__nv_bfloat162 a, __nv_bfloat162 b) {
  return __hmul2(a, b);
}

template<>
inline __device__ __nv_bfloat162 mul(__nv_bfloat16 a, __nv_bfloat162 b) {
  return mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(bf162bf162(a), b);
}

template<>
inline __device__ bf16_4_t mul(bf16_4_t a, bf16_4_t b) {
  bf16_4_t c;
  c.x = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.x, b.x);
  c.y = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.y, b.y);
  return c;
}

template<>
inline __device__ bf16_4_t mul(__nv_bfloat16 a, bf16_4_t b) {
  __nv_bfloat162 s = bf162bf162(a);
  bf16_4_t c;
  c.x = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.x);
  c.y = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.y);
  return c;
}

template<>
inline __device__ bf16_8_t mul(bf16_8_t a, bf16_8_t b) {
  bf16_8_t c;
  c.x = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.x, b.x);
  c.y = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.y, b.y);
  c.z = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.z, b.z);
  c.w = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.w, b.w);
  return c;
}

template<>
inline __device__ bf16_8_t mul(__nv_bfloat16 a, bf16_8_t b) {
  __nv_bfloat162 s = bf162bf162(a);
  bf16_8_t c;
  c.x = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.x);
  c.y = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.y);
  c.z = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.z);
  c.w = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.w);
  return c;
}

template<>
inline __device__ float mul(__nv_bfloat16 a, __nv_bfloat16 b) {
  float fa = __bfloat162float(a);
  float fb = __bfloat162float(b);
  return fa * fb;
}

template<>
inline __device__ float2 mul(__nv_bfloat162 a, __nv_bfloat162 b) {
  float2 fa = bf1622float2(a);
  float2 fb = bf1622float2(b);
  return mul<float2, float2, float2>(fa, fb);
}

template<>
inline __device__ float2 mul(__nv_bfloat16 a, __nv_bfloat162 b) {
  return mul<float2, __nv_bfloat162, __nv_bfloat162>(bf162bf162(a), b);
}

template<>
inline __device__ Float4_ mul(bf16_4_t a, bf16_4_t b) {
  Float4_ fc;
  fc.x = mul<float2, __nv_bfloat162, __nv_bfloat162>(a.x, b.x);
  fc.y = mul<float2, __nv_bfloat162, __nv_bfloat162>(a.y, b.y);
  return fc;
}

template<>
inline __device__ Float4_ mul(__nv_bfloat16 a, bf16_4_t b) {
  __nv_bfloat162 s = bf162bf162(a);
  Float4_ fc;
  fc.x = mul<float2, __nv_bfloat162, __nv_bfloat162>(s, b.x);
  fc.y = mul<float2, __nv_bfloat162, __nv_bfloat162>(s, b.y);
  return fc;
}

template<>
inline __device__ Float8_ mul(bf16_8_t a, bf16_8_t b) {
  Float8_ fc;
  fc.x = mul<float2, __nv_bfloat162, __nv_bfloat162>(a.x, b.x);
  fc.y = mul<float2, __nv_bfloat162, __nv_bfloat162>(a.y, b.y);
  fc.z = mul<float2, __nv_bfloat162, __nv_bfloat162>(a.z, b.z);
  fc.w = mul<float2, __nv_bfloat162, __nv_bfloat162>(a.w, b.w);
  return fc;
}

template<>
inline __device__ Float8_ mul(__nv_bfloat16 a, bf16_8_t b) {
  __nv_bfloat162 s = bf162bf162(a);
  Float8_ fc;
  fc.x = mul<float2, __nv_bfloat162, __nv_bfloat162>(s, b.x);
  fc.y = mul<float2, __nv_bfloat162, __nv_bfloat162>(s, b.y);
  fc.z = mul<float2, __nv_bfloat162, __nv_bfloat162>(s, b.z);
  fc.w = mul<float2, __nv_bfloat162, __nv_bfloat162>(s, b.w);
  return fc;
}

// Vector fused multiply-add.
inline __device__ __nv_bfloat162 fma(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bfloat162 c) {
  return __hfma2(a, b, c);
}

inline __device__ __nv_bfloat162 fma(__nv_bfloat16 a, __nv_bfloat162 b, __nv_bfloat162 c) {
  return __hfma2(bf162bf162(a), b, c);
}

inline __device__ bf16_4_t fma(bf16_4_t a, bf16_4_t b, bf16_4_t c) {
  bf16_4_t d;
  d.x = fma(a.x, b.x, c.x);
  d.y = fma(a.y, b.y, c.y);
  return d;
}

inline __device__ bf16_4_t fma(__nv_bfloat16 a, bf16_4_t b, bf16_4_t c) {
  __nv_bfloat162 s = bf162bf162(a);
  bf16_4_t d;
  d.x = fma(s, b.x, c.x);
  d.y = fma(s, b.y, c.y);
  return d;
}

inline __device__ bf16_8_t fma(bf16_8_t a, bf16_8_t b, bf16_8_t c) {
  bf16_8_t d;
  d.x = fma(a.x, b.x, c.x);
  d.y = fma(a.y, b.y, c.y);
  d.z = fma(a.z, b.z, c.z);
  d.w = fma(a.w, b.w, c.w);
  return d;
}

inline __device__ bf16_8_t fma(__nv_bfloat16 a, bf16_8_t b, bf16_8_t c) {
  __nv_bfloat162 s = bf162bf162(a);
  bf16_8_t d;
  d.x = fma(s, b.x, c.x);
  d.y = fma(s, b.y, c.y);
  d.z = fma(s, b.z, c.z);
  d.w = fma(s, b.w, c.w);
  return d;
}

inline __device__ float fma(__nv_bfloat16 a, __nv_bfloat16 b, float fc) {
  return __bfloat162float(a) * __bfloat162float(b) + fc;
}

inline __device__ float2 fma(__nv_bfloat162 a, __nv_bfloat162 b, float2 fc) {
  float2 fa = bf1622float2(a);
  float2 fb = bf1622float2(b);
  return fma(fa, fb, fc);
}

inline __device__ float2 fma(__nv_bfloat16 a, __nv_bfloat162 b, float2 fc) {
  return fma(bf162bf162(a), b, fc);
}

inline __device__ Float4_ fma(bf16_4_t a, bf16_4_t b, Float4_ fc) {
  Float4_ fd;
  fd.x = fma(a.x, b.x, fc.x);
  fd.y = fma(a.y, b.y, fc.y);
  return fd;
}

inline __device__ Float4_ fma(__nv_bfloat16 a, bf16_4_t b, Float4_ fc) {
  __nv_bfloat162 s = bf162bf162(a);
  Float4_ fd;
  fd.x = fma(s, b.x, fc.x);
  fd.y = fma(s, b.y, fc.y);
  return fd;
}

inline __device__ Float8_ fma(bf16_8_t a, bf16_8_t b, Float8_ fc) {
  Float8_ fd;
  fd.x = fma(a.x, b.x, fc.x);
  fd.y = fma(a.y, b.y, fc.y);
  fd.z = fma(a.z, b.z, fc.z);
  fd.w = fma(a.w, b.w, fc.w);
  return fd;
}

inline __device__ Float8_ fma(__nv_bfloat16 a, bf16_8_t b, Float8_ fc) {
  __nv_bfloat162 s = bf162bf162(a);
  Float8_ fd;
  fd.x = fma(s, b.x, fc.x);
  fd.y = fma(s, b.y, fc.y);
  fd.z = fma(s, b.z, fc.z);
  fd.w = fma(s, b.w, fc.w);
  return fd;
}

// Vector sum.
template<>
inline __device__ float sum(__nv_bfloat16 v) {
  return __bfloat162float(v);
}

template<>
inline __device__ float sum(__nv_bfloat162 v) {
  float2 vf = bf1622float2(v);
  return vf.x + vf.y;
}

template<>
inline __device__ float sum(bf16_4_t v) {
  return sum(v.x) + sum(v.y);
}

template<>
inline __device__ float sum(bf16_8_t v) {
  return sum(v.x) + sum(v.y) + sum(v.z) + sum(v.w);
}

// From float32 to bfloat16.
inline __device__ void from_float(__nv_bfloat16& dst, float src) {
  dst = __float2bfloat16(src);
}

inline __device__ void from_float(__nv_bfloat162& dst, float2 src) {
  dst = __float22bfloat162_rn(src);
}

inline __device__ void from_float(bf16_4_t& dst, Float4_ src) {
  dst.x = __float22bfloat162_rn(src.x);
  dst.y = __float22bfloat162_rn(src.y);
}

inline __device__ void from_float(bf16_8_t& dst, Float8_ src) {
  dst.x = __float22bfloat162_rn(src.x);
  dst.y = __float22bfloat162_rn(src.y);
  dst.z = __float22bfloat162_rn(src.z);
  dst.w = __float22bfloat162_rn(src.w);
}

} // namespace cacheflow
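Taken together, the new header defines packed BF16 vectors (Vec), their FP32 accumulator counterparts (FloatVec), and the add/mul/fma/sum/from_float overloads the attention kernel needs to accumulate in FP32 and round back to BF16 only at the end. A small illustrative sketch of that flow, assuming attention_dtypes.h is on the include path and ENABLE_BF16 is defined; the two helper functions are hypothetical, only mul and from_float come from the header above:

// Illustration only; not part of this commit.
#include "attention_dtypes.h"

#ifdef ENABLE_BF16
namespace cacheflow {

inline __device__ float bf162_dot_sketch(__nv_bfloat162 q, __nv_bfloat162 k) {
  // Widen to float2 while multiplying, then reduce the two lanes in FP32.
  float2 qk = mul<float2, __nv_bfloat162, __nv_bfloat162>(q, k);
  return qk.x + qk.y;
}

inline __device__ __nv_bfloat16 scale_to_bf16_sketch(float acc, float inv_scale) {
  __nv_bfloat16 out;
  from_float(out, acc * inv_scale);  // round-to-nearest conversion back to BF16
  return out;
}

} // namespace cacheflow
#endif // ENABLE_BF16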
@@ -6,7 +6,7 @@

namespace cacheflow {

-// Define FP32 vector data types.
+// Define custom FP32 vector data types.
struct Float4_ {
  float2 x;
  float2 y;
@@ -14,14 +14,16 @@ void swap_blocks(
  torch::Device dst_device = dst.device();
  cudaMemcpyKind memcpy_type;
  if (src_device.is_cuda() && dst_device.is_cuda()) {
-    assert(src_device.index() == dst_device.index());
+    TORCH_CHECK(
+      src_device.index() == dst_device.index(),
+      "src and dst must be on the same GPU");
    memcpy_type = cudaMemcpyDeviceToDevice;
  } else if (src_device.is_cuda() && dst_device.is_cpu()) {
    memcpy_type = cudaMemcpyDeviceToHost;
  } else if (src_device.is_cpu() && dst_device.is_cuda()) {
    memcpy_type = cudaMemcpyHostToDevice;
  } else {
-    assert(false);
+    TORCH_CHECK(false, "Invalid device combination");
  }

  void *src_ptr = src.data_ptr();
@@ -29,6 +31,7 @@ void swap_blocks(

  const int64_t block_size_in_bytes = src.element_size() * src[0].numel();
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+  // NOTE(woosuk): This can be slow if the number of blocks is large.
  for (const auto& pair : block_mapping) {
    int64_t src_block_number = pair.first;
    int64_t dst_block_number = pair.second;
@@ -122,7 +125,9 @@ void copy_blocks(
  dim3 grid(num_layers, num_pairs);
  dim3 block(std::min(1024, numel_per_block));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
-  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+  AT_DISPATCH_FLOATING_TYPES_AND2(
+    at::ScalarType::Half,
+    at::ScalarType::BFloat16,
    key_caches[0].scalar_type(), "copy_blocks_kernel", ([&] {
      cacheflow::copy_blocks_kernel<scalar_t><<<grid, block, 0, stream>>>(
        key_cache_ptrs_tensor.data_ptr<int64_t>(),
@@ -176,6 +181,50 @@ __global__ void reshape_and_cache_kernel(
  }
}

+} // namespace cacheflow
+
+void reshape_and_cache(
+  torch::Tensor& key,           // [num_tokens, num_heads, head_size]
+  torch::Tensor& value,         // [num_tokens, num_heads, head_size]
+  torch::Tensor& key_cache,     // [num_blocks, num_heads, head_size/x, block_size, x]
+  torch::Tensor& value_cache,   // [num_blocks, num_heads, head_size, block_size]
+  torch::Tensor& slot_mapping)  // [num_tokens]
+{
+  int num_tokens = key.size(0);
+  int num_heads = key.size(1);
+  int head_size = key.size(2);
+  int block_size = key_cache.size(3);
+  int x = key_cache.size(4);
+
+  int key_stride = key.stride(0);
+  int value_stride = value.stride(0);
+
+  dim3 grid(num_tokens);
+  dim3 block(std::min(num_heads * head_size, 512));
+  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+  AT_DISPATCH_FLOATING_TYPES_AND2(
+    at::ScalarType::Half,
+    at::ScalarType::BFloat16,
+    key.scalar_type(),
+    "reshape_and_cache_kernel",
+    [&] {
+      cacheflow::reshape_and_cache_kernel<scalar_t><<<grid, block, 0, stream>>>(
+        key.data_ptr<scalar_t>(),
+        value.data_ptr<scalar_t>(),
+        key_cache.data_ptr<scalar_t>(),
+        value_cache.data_ptr<scalar_t>(),
+        slot_mapping.data_ptr<int>(),
+        key_stride,
+        value_stride,
+        num_heads,
+        head_size,
+        block_size,
+        x);
+    });
+}
+
+namespace cacheflow {
+
// Grid: (num_blocks, block_size).
template<typename scalar_t>
__global__ void gather_cached_kv_kernel(
@@ -296,45 +345,6 @@ __global__ void gather_cached_kv_kernel_optimized(

} // namespace cacheflow

-void reshape_and_cache(
-  torch::Tensor& key,           // [num_tokens, num_heads, head_size]
-  torch::Tensor& value,         // [num_tokens, num_heads, head_size]
-  torch::Tensor& key_cache,     // [num_blocks, num_heads, head_size/x, block_size, x]
-  torch::Tensor& value_cache,   // [num_blocks, num_heads, head_size, block_size]
-  torch::Tensor& slot_mapping)  // [num_tokens]
-{
-  int num_tokens = key.size(0);
-  int num_heads = key.size(1);
-  int head_size = key.size(2);
-  int block_size = key_cache.size(3);
-  int x = key_cache.size(4);
-
-  int key_stride = key.stride(0);
-  int value_stride = value.stride(0);
-
-  dim3 grid(num_tokens);
-  dim3 block(std::min(num_heads * head_size, 512));
-  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
-  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
-    key.scalar_type(),
-    "reshape_and_cache_kernel",
-    [&] {
-      cacheflow::reshape_and_cache_kernel<scalar_t><<<grid, block, 0, stream>>>(
-        key.data_ptr<scalar_t>(),
-        value.data_ptr<scalar_t>(),
-        key_cache.data_ptr<scalar_t>(),
-        value_cache.data_ptr<scalar_t>(),
-        slot_mapping.data_ptr<int>(),
-        key_stride,
-        value_stride,
-        num_heads,
-        head_size,
-        block_size,
-        x);
-    });
-}
-

void gather_cached_kv(
  torch::Tensor& key,           // [out] [num_tokens, num_heads, head_size]
  torch::Tensor& value,         // [out] [num_tokens, num_heads, head_size]
@@ -354,7 +364,9 @@ void gather_cached_kv(
  dim3 grid(num_tokens);
  dim3 block(std::min(num_heads * head_size, 512));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
-  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+  AT_DISPATCH_FLOATING_TYPES_AND2(
+    at::ScalarType::Half,
+    at::ScalarType::BFloat16,
    key.scalar_type(),
    "gather_cached_kv_kernel_optimized",
    [&] {
@@ -46,7 +46,9 @@ void rms_norm(
  dim3 grid(num_tokens);
  dim3 block(std::min(hidden_size, 1024));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
-  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+  AT_DISPATCH_FLOATING_TYPES_AND2(
+    at::ScalarType::Half,
+    at::ScalarType::BFloat16,
    input.scalar_type(),
    "rms_norm_kernel",
    [&] {
@@ -64,7 +64,9 @@ void rotary_embedding_neox(
  dim3 grid(num_tokens);
  dim3 block(std::min(num_heads * rot_dim / 2, 512));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
-  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+  AT_DISPATCH_FLOATING_TYPES_AND2(
+    at::ScalarType::Half,
+    at::ScalarType::BFloat16,
    query.scalar_type(),
    "rotary_embedding_neox",
    [&] {
setup.py
@@ -1,9 +1,22 @@
import setuptools
+import torch
from torch.utils import cpp_extension

CXX_FLAGS = ['-g']
NVCC_FLAGS = ['-O2']

+if not torch.cuda.is_available():
+    raise RuntimeError(
+        f'Cannot find CUDA at CUDA_HOME: {cpp_extension.CUDA_HOME}. '
+        'CUDA must be available in order to build the package.')
+
+# FIXME(woosuk): Consider the case where the machine has multiple GPUs with
+# different compute capabilities.
+compute_capability = torch.cuda.get_device_capability()
+major, minor = compute_capability
+# Enable bfloat16 support if the compute capability is >= 8.0.
+if major >= 8:
+    NVCC_FLAGS.append('-DENABLE_BF16')

ext_modules = []

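The -DENABLE_BF16 define is added only when the GPU visible at build time reports compute capability 8.0 or higher, so the BF16 code paths simply compile out on older architectures. A hypothetical sketch (not code from this repository) of what that means for a dispatcher built without the flag, mirroring the guard in single_query_cached_kv_attention above: a bfloat16 input falls through to the TORCH_CHECK error.

// Hypothetical dispatcher showing the ENABLE_BF16 compile-time gate.
#include <torch/extension.h>

void dispatch_by_dtype_sketch(const torch::Tensor& query) {
  if (query.dtype() == at::ScalarType::Half) {
    // launch the FP16 kernel (always compiled)
#ifdef ENABLE_BF16
  } else if (query.dtype() == at::ScalarType::BFloat16) {
    // launch the BF16 kernel (compiled only when built on compute capability >= 8.0)
#endif
  } else {
    TORCH_CHECK(false, "Unsupported data type: ", query.dtype());
  }
}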
@@ -23,7 +36,7 @@ attention_extension = cpp_extension.CUDAExtension(
)
ext_modules.append(attention_extension)

-# Positional encodings.
+# Positional encoding kernels.
positional_encoding_extension = cpp_extension.CUDAExtension(
    name='cacheflow.pos_encoding_ops',
    sources=['csrc/pos_encoding.cpp', 'csrc/pos_encoding_kernels.cu'],
@@ -39,6 +52,7 @@ layernorm_extension = cpp_extension.CUDAExtension(
)
ext_modules.append(layernorm_extension)

+# Activation kernels.
activation_extension = cpp_extension.CUDAExtension(
    name='cacheflow.activation_ops',
    sources=['csrc/activation.cpp', 'csrc/activation_kernels.cu'],