from typing import List

import setuptools
import torch
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
from torch.utils.cpp_extension import CUDA_HOME

# Build custom operators.
CXX_FLAGS = ["-g"]
# TODO(woosuk): Should we use -O3?
NVCC_FLAGS = ["-O2"]
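# (-g keeps debug symbols in the host-compiled C++ sources; -O2 enables
# optimization for the nvcc-compiled sources. Both flag lists are shared by
# every extension defined below.)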
if not torch.cuda.is_available():
    raise RuntimeError(
        f"Cannot find CUDA at CUDA_HOME: {CUDA_HOME}. "
        "CUDA must be available in order to build the package.")
# FIXME(woosuk): Consider the case where the machine has multiple GPUs with
# different compute capabilities.
compute_capability = torch.cuda.get_device_capability()
major, minor = compute_capability
# Enable bfloat16 support if the compute capability is >= 8.0.
if major >= 8:
    NVCC_FLAGS.append("-DENABLE_BF16")
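# For reference, get_device_capability() returns (major, minor): V100 reports
# (7, 0), T4 (7, 5), A100 (8, 0), and H100 (9, 0). Hardware bfloat16 support
# begins with compute capability 8.0 (Ampere).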
ext_modules = []

# Cache operations.
cache_extension = CUDAExtension(
    name="cacheflow.cache_ops",
    sources=["csrc/cache.cpp", "csrc/cache_kernels.cu"],
    extra_compile_args={"cxx": CXX_FLAGS, "nvcc": NVCC_FLAGS},
)
ext_modules.append(cache_extension)
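# Each CUDAExtension compiles its C++/CUDA sources into a single Python
# extension module named by its dotted path; once the package is installed,
# a kernel module is importable as, e.g., `from cacheflow import cache_ops`.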
# Attention kernels.
attention_extension = CUDAExtension(
    name="cacheflow.attention_ops",
    sources=["csrc/attention.cpp", "csrc/attention/attention_kernels.cu"],
    extra_compile_args={"cxx": CXX_FLAGS, "nvcc": NVCC_FLAGS},
)
ext_modules.append(attention_extension)

# Positional encoding kernels.
positional_encoding_extension = CUDAExtension(
    name="cacheflow.pos_encoding_ops",
    sources=["csrc/pos_encoding.cpp", "csrc/pos_encoding_kernels.cu"],
    extra_compile_args={"cxx": CXX_FLAGS, "nvcc": NVCC_FLAGS},
)
ext_modules.append(positional_encoding_extension)

# Layer normalization kernels.
layernorm_extension = CUDAExtension(
    name="cacheflow.layernorm_ops",
    sources=["csrc/layernorm.cpp", "csrc/layernorm_kernels.cu"],
    extra_compile_args={"cxx": CXX_FLAGS, "nvcc": NVCC_FLAGS},
)
ext_modules.append(layernorm_extension)

# Activation kernels.
activation_extension = CUDAExtension(
    name="cacheflow.activation_ops",
    sources=["csrc/activation.cpp", "csrc/activation_kernels.cu"],
    extra_compile_args={"cxx": CXX_FLAGS, "nvcc": NVCC_FLAGS},
)
ext_modules.append(activation_extension)


def get_requirements() -> List[str]:
    """Get Python package dependencies from requirements.txt."""
    with open("requirements.txt") as f:
        requirements = f.read().strip().split("\n")
    return requirements
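# A note on the parser above: it assumes requirements.txt holds one bare
# requirement specifier per line; any blank or comment lines inside the file
# would be passed through to setuptools verbatim.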
setuptools.setup(
    name="cacheflow",
    python_requires=">=3.8",
    install_requires=get_requirements(),
    ext_modules=ext_modules,
    cmdclass={"build_ext": BuildExtension},
)
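# A sketch of typical build invocations (assumes a CUDA toolkit and a
# CUDA-enabled PyTorch are already installed):
#   pip install -e .                      # editable install; compiles all extensions
#   python setup.py build_ext --inplace   # build the extension modules in the source tree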