2024-12-18 09:57:16 -05:00
|
|
|
#pragma once
|
|
|
|
|
|
|
|
#include "cutlass/cutlass.h"
|
|
|
|
#include <climits>
|
|
|
|
#include "cuda_runtime.h"
|
|
|
|
#include <iostream>
|
|
|
|
|
|
|
|
/**
 * Helper macro for checking CUTLASS errors.
 * Evaluates `status` exactly once; raises via TORCH_CHECK with the
 * human-readable CUTLASS status string if it is not kSuccess.
 */
// do { ... } while (0) makes the macro a single statement, so
// `if (cond) CUTLASS_CHECK(x); else ...` parses correctly (a bare
// `{ ... }` block plus the caller's semicolon would detach the else).
#define CUTLASS_CHECK(status)                                 \
  do {                                                        \
    cutlass::Status error = (status);                         \
    TORCH_CHECK(error == cutlass::Status::kSuccess,           \
                cutlassGetStatusString(error));               \
  } while (0)
|
|
|
|
|
|
|
|
/**
 * Panic wrapper for unwinding CUDA runtime errors.
 * Evaluates `status` exactly once; raises via TORCH_CHECK with the
 * CUDA runtime's error string if it is not cudaSuccess.
 */
// do { ... } while (0) makes the macro a single statement, so
// `if (cond) CUDA_CHECK(x); else ...` parses correctly (a bare
// `{ ... }` block plus the caller's semicolon would detach the else).
#define CUDA_CHECK(status)                                          \
  do {                                                              \
    cudaError_t error = (status);                                   \
    TORCH_CHECK(error == cudaSuccess, cudaGetErrorString(error));   \
  } while (0)
|
|
|
|
|
|
|
|
/**
 * Queries cudaDevAttrMaxSharedMemoryPerBlockOptin for `device`: the maximum
 * shared memory per block (in bytes) a kernel may opt in to, beyond the
 * default per-block limit.
 *
 * \param device  CUDA device ordinal to query.
 * \return        The opt-in shared-memory limit in bytes.
 */
inline int get_cuda_max_shared_memory_per_block_opt_in(int const device) {
  int max_shared_mem_per_block_opt_in = 0;
  // Check the runtime result instead of ignoring it: an invalid device
  // ordinal would otherwise make this silently return 0.
  CUDA_CHECK(cudaDeviceGetAttribute(&max_shared_mem_per_block_opt_in,
                                    cudaDevAttrMaxSharedMemoryPerBlockOptin,
                                    device));
  return max_shared_mem_per_block_opt_in;
}
|
|
|
|
|
|
|
|
// Forward declaration (defined in a separate translation unit): returns the
// SM version number of the current device — presumably compute capability
// encoded as major * 10 + minor; confirm against the definition.
int32_t get_sm_version_num();
|
2025-01-30 21:33:00 -05:00
|
|
|
|
|
|
|
/**
 * Wraps a kernel so that its body is only compiled for architectures that
 * can actually run it, shrinking the compiled binary.
 *
 * __CUDA_ARCH__ is undefined in host code and set per-architecture in each
 * device compilation pass, so putting the ifdef inside the device operator()
 * smuggles the guard into exactly the code paths where it is meaningful:
 * host passes and pre-SM90 device passes compile an empty body, while SM90+
 * passes forward the call to the wrapped kernel.
 */
template <typename Kernel>
struct enable_sm90_or_later : Kernel {
  template <typename... KernelArgs>
  CUTLASS_DEVICE void operator()(KernelArgs&&... kernel_args) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900
    Kernel::operator()(std::forward<KernelArgs>(kernel_args)...);
#endif
  }
};
|