[ROCm] Fix warp and lane calculation in blockReduceSum (#3321)
This commit is contained in:
parent
4c922709b6
commit
c9415c19d3
@ -29,12 +29,22 @@ __inline__ __device__ T warpReduceSum(T val) {
|
|||||||
return val;
|
return val;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
__inline__ __device__ constexpr int _calculateLaneMask(int warp_size) {
|
||||||
|
return warp_size - 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
__inline__ __device__ constexpr int _calculateWidShift(int warp_size) {
|
||||||
|
return 5 + (warp_size >> 6);
|
||||||
|
}
|
||||||
|
|
||||||
/* Calculate the sum of all elements in a block */
|
/* Calculate the sum of all elements in a block */
|
||||||
template<typename T>
|
template<typename T>
|
||||||
__inline__ __device__ T blockReduceSum(T val) {
|
__inline__ __device__ T blockReduceSum(T val) {
|
||||||
static __shared__ T shared[WARP_SIZE];
|
static __shared__ T shared[WARP_SIZE];
|
||||||
int lane = threadIdx.x & 0x1f;
|
constexpr auto LANE_MASK = _calculateLaneMask(WARP_SIZE);
|
||||||
int wid = threadIdx.x >> 5;
|
constexpr auto WID_SHIFT = _calculateWidShift(WARP_SIZE);
|
||||||
|
int lane = threadIdx.x & LANE_MASK;
|
||||||
|
int wid = threadIdx.x >> WID_SHIFT;
|
||||||
|
|
||||||
val = warpReduceSum<T>(val);
|
val = warpReduceSum<T>(val);
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user