"""Build script for the cacheflow CUDA extension modules.

Each extension pairs a C++ binding file with a CUDA kernel file under
``csrc/`` and is compiled with the shared host (CXX) and device (NVCC)
compiler flags defined below. Running ``python setup.py build_ext``
builds all four kernels via PyTorch's ``BuildExtension``.
"""

import setuptools

from torch.utils import cpp_extension

# Host-compiler flags: keep debug symbols in the C++ binding code.
CXX_FLAGS = ['-g']
# Device-compiler flags: optimize the CUDA kernels.
NVCC_FLAGS = ['-O2']

ext_modules = []

# Cache operations.
cache_extension = cpp_extension.CUDAExtension(
    name='cacheflow.cache_ops',
    sources=['csrc/cache.cpp', 'csrc/cache_kernels.cu'],
    extra_compile_args={'cxx': CXX_FLAGS, 'nvcc': NVCC_FLAGS},
)
ext_modules.append(cache_extension)

# Attention kernels.
attention_extension = cpp_extension.CUDAExtension(
    name='cacheflow.attention_ops',
    sources=['csrc/attention.cpp', 'csrc/attention_kernels.cu'],
    extra_compile_args={'cxx': CXX_FLAGS, 'nvcc': NVCC_FLAGS},
)
ext_modules.append(attention_extension)

# Positional encodings.
positional_encoding_extension = cpp_extension.CUDAExtension(
    name='cacheflow.pos_encoding_ops',
    sources=['csrc/pos_encoding.cpp', 'csrc/pos_encoding_kernels.cu'],
    extra_compile_args={'cxx': CXX_FLAGS, 'nvcc': NVCC_FLAGS},
)
ext_modules.append(positional_encoding_extension)

# Layer normalization kernels.
layernorm_extension = cpp_extension.CUDAExtension(
    name='cacheflow.layernorm_ops',
    sources=['csrc/layernorm.cpp', 'csrc/layernorm_kernels.cu'],
    extra_compile_args={'cxx': CXX_FLAGS, 'nvcc': NVCC_FLAGS},
)
ext_modules.append(layernorm_extension)

# Register all extensions and let PyTorch's BuildExtension drive the
# mixed C++/CUDA compilation.
setuptools.setup(
    name='cacheflow',
    ext_modules=ext_modules,
    cmdclass={'build_ext': cpp_extension.BuildExtension},
)