Improve setup script & Add a guard for bfloat16 kernels (#130)
This commit is contained in:
parent
4a151dd453
commit
d721168449
@ -3,7 +3,4 @@
|
||||
#include "attention_generic.cuh"
|
||||
#include "dtype_float16.cuh"
|
||||
#include "dtype_float32.cuh"
|
||||
|
||||
#ifdef ENABLE_BF16
|
||||
#include "dtype_bfloat16.cuh"
|
||||
#endif // ENABLE_BF16
|
||||
|
@ -458,10 +458,8 @@ void single_query_cached_kv_attention(
|
||||
// TODO(woosuk): Support FP32.
|
||||
if (query.dtype() == at::ScalarType::Half) {
|
||||
CALL_KERNEL_LAUNCHER_BLOCK_SIZE(uint16_t);
|
||||
#ifdef ENABLE_BF16
|
||||
} else if (query.dtype() == at::ScalarType::BFloat16) {
|
||||
CALL_KERNEL_LAUNCHER_BLOCK_SIZE(__nv_bfloat16);
|
||||
#endif
|
||||
} else {
|
||||
TORCH_CHECK(false, "Unsupported data type: ", query.dtype());
|
||||
}
|
||||
|
@ -78,20 +78,36 @@ struct FloatVec<bf16_8_t> {
|
||||
|
||||
// Utility functions for type conversions.
|
||||
inline __device__ float2 bf1622float2(const __nv_bfloat162 val) {
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
|
||||
assert(false);
|
||||
#else
|
||||
return __bfloat1622float2(val);
|
||||
#endif
|
||||
}
|
||||
|
||||
inline __device__ __nv_bfloat162 bf162bf162(const __nv_bfloat16 val) {
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
|
||||
assert(false);
|
||||
#else
|
||||
return __bfloat162bfloat162(val);
|
||||
#endif
|
||||
}
|
||||
|
||||
// Vector addition.
|
||||
inline __device__ __nv_bfloat16 add(__nv_bfloat16 a, __nv_bfloat16 b) {
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
|
||||
assert(false);
|
||||
#else
|
||||
return a + b;
|
||||
#endif
|
||||
}
|
||||
|
||||
inline __device__ __nv_bfloat162 add(__nv_bfloat162 a, __nv_bfloat162 b) {
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
|
||||
assert(false);
|
||||
#else
|
||||
return __hadd2(a, b);
|
||||
#endif
|
||||
}
|
||||
|
||||
inline __device__ bf16_4_t add(bf16_4_t a, bf16_4_t b) {
|
||||
@ -134,12 +150,20 @@ inline __device__ Float8_ add(bf16_8_t a, Float8_ fb) {
|
||||
// Vector multiplication.
|
||||
template<>
|
||||
inline __device__ __nv_bfloat16 mul(__nv_bfloat16 a, __nv_bfloat16 b) {
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
|
||||
assert(false);
|
||||
#else
|
||||
return __hmul(a, b);
|
||||
#endif
|
||||
}
|
||||
|
||||
template<>
|
||||
inline __device__ __nv_bfloat162 mul(__nv_bfloat162 a, __nv_bfloat162 b) {
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
|
||||
assert(false);
|
||||
#else
|
||||
return __hmul2(a, b);
|
||||
#endif
|
||||
}
|
||||
|
||||
template<>
|
||||
@ -244,11 +268,19 @@ inline __device__ Float8_ mul(__nv_bfloat16 a, bf16_8_t b) {
|
||||
|
||||
// Vector fused multiply-add.
|
||||
inline __device__ __nv_bfloat162 fma(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bfloat162 c) {
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
|
||||
assert(false);
|
||||
#else
|
||||
return __hfma2(a, b, c);
|
||||
#endif
|
||||
}
|
||||
|
||||
inline __device__ __nv_bfloat162 fma(__nv_bfloat16 a, __nv_bfloat162 b, __nv_bfloat162 c) {
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
|
||||
assert(false);
|
||||
#else
|
||||
return __hfma2(bf162bf162(a), b, c);
|
||||
#endif
|
||||
}
|
||||
|
||||
inline __device__ bf16_4_t fma(bf16_4_t a, bf16_4_t b, bf16_4_t c) {
|
||||
@ -361,19 +393,31 @@ inline __device__ void from_float(__nv_bfloat16& dst, float src) {
|
||||
}
|
||||
|
||||
inline __device__ void from_float(__nv_bfloat162& dst, float2 src) {
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
|
||||
assert(false);
|
||||
#else
|
||||
dst = __float22bfloat162_rn(src);
|
||||
#endif
|
||||
}
|
||||
|
||||
inline __device__ void from_float(bf16_4_t& dst, Float4_ src) {
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
|
||||
assert(false);
|
||||
#else
|
||||
dst.x = __float22bfloat162_rn(src.x);
|
||||
dst.y = __float22bfloat162_rn(src.y);
|
||||
#endif
|
||||
}
|
||||
|
||||
inline __device__ void from_float(bf16_8_t& dst, Float8_ src) {
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
|
||||
assert(false);
|
||||
#else
|
||||
dst.x = __float22bfloat162_rn(src.x);
|
||||
dst.y = __float22bfloat162_rn(src.y);
|
||||
dst.z = __float22bfloat162_rn(src.z);
|
||||
dst.w = __float22bfloat162_rn(src.w);
|
||||
#endif
|
||||
}
|
||||
|
||||
} // namespace cacheflow
|
||||
|
57
setup.py
57
setup.py
@ -1,28 +1,63 @@
|
||||
from typing import List
|
||||
import subprocess
|
||||
from typing import List, Set
|
||||
|
||||
from packaging.version import parse, Version
|
||||
import setuptools
|
||||
import torch
|
||||
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
|
||||
from torch.utils.cpp_extension import CUDA_HOME
|
||||
|
||||
|
||||
# Build custom operators.
|
||||
CXX_FLAGS = ["-g"]
|
||||
# Compiler flags.
|
||||
CXX_FLAGS = ["-g", "-O2"]
|
||||
# TODO(woosuk): Should we use -O3?
|
||||
NVCC_FLAGS = ["-O2"]
|
||||
|
||||
|
||||
if not torch.cuda.is_available():
|
||||
raise RuntimeError(
|
||||
f"Cannot find CUDA at CUDA_HOME: {CUDA_HOME}. "
|
||||
"CUDA must be available in order to build the package.")
|
||||
|
||||
# FIXME(woosuk): Consider the case where the machine has multiple GPUs with
|
||||
# different compute capabilities.
|
||||
compute_capability = torch.cuda.get_device_capability()
|
||||
major, minor = compute_capability
|
||||
# Enable bfloat16 support if the compute capability is >= 8.0.
|
||||
if major >= 8:
|
||||
NVCC_FLAGS.append("-DENABLE_BF16")
|
||||
|
||||
def get_nvcc_cuda_version(cuda_dir: str) -> Version:
|
||||
"""Get the CUDA version from nvcc.
|
||||
|
||||
Adapted from https://github.com/NVIDIA/apex/blob/8b7a1ff183741dd8f9b87e7bafd04cfde99cea28/setup.py
|
||||
"""
|
||||
nvcc_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"],
|
||||
universal_newlines=True)
|
||||
output = nvcc_output.split()
|
||||
release_idx = output.index("release") + 1
|
||||
nvcc_cuda_version = parse(output[release_idx].split(",")[0])
|
||||
return nvcc_cuda_version
|
||||
|
||||
|
||||
# Collect the compute capabilities of all available GPUs.
|
||||
device_count = torch.cuda.device_count()
|
||||
compute_capabilities: Set[int] = set()
|
||||
for i in range(device_count):
|
||||
major, minor = torch.cuda.get_device_capability(i)
|
||||
if major < 7:
|
||||
raise RuntimeError(
|
||||
"GPUs with compute capability less than 7.0 are not supported.")
|
||||
compute_capabilities.add(major * 10 + minor)
|
||||
# If no GPU is available, add all supported compute capabilities.
|
||||
if not compute_capabilities:
|
||||
compute_capabilities = {70, 75, 80, 86, 90}
|
||||
# Add target compute capabilities to NVCC flags.
|
||||
for capability in compute_capabilities:
|
||||
NVCC_FLAGS += ["-gencode", f"arch=compute_{capability},code=sm_{capability}"]
|
||||
|
||||
# Validate the NVCC CUDA version.
|
||||
nvcc_cuda_version = get_nvcc_cuda_version(CUDA_HOME)
|
||||
if nvcc_cuda_version < Version("11.0"):
|
||||
raise RuntimeError("CUDA 11.0 or higher is required to build the package.")
|
||||
if 86 in compute_capabilities and nvcc_cuda_version < Version("11.1"):
|
||||
raise RuntimeError(
|
||||
"CUDA 11.1 or higher is required for GPUs with compute capability 8.6.")
|
||||
if 90 in compute_capabilities and nvcc_cuda_version < Version("11.8"):
|
||||
raise RuntimeError(
|
||||
"CUDA 11.8 or higher is required for GPUs with compute capability 9.0.")
|
||||
|
||||
ext_modules = []
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user