[Bugfix] Fix potentially unsafe custom allreduce synchronization (#8558)
This commit is contained in:
parent
8ff7ced996
commit
cc4325b66a
@ -6,6 +6,7 @@
|
||||
#include <cuda_runtime.h>
|
||||
|
||||
#include <iostream>
|
||||
#include <array>
|
||||
#include <limits>
|
||||
#include <map>
|
||||
#include <unordered_map>
|
||||
@ -23,17 +24,23 @@
|
||||
|
||||
namespace vllm {
|
||||
|
||||
constexpr int kMaxBlocks = 64;
|
||||
// note: we don't want to use atomics for signals because peer atomics are no
|
||||
// supported on PCIe links
|
||||
constexpr int kMaxBlocks = 36;
|
||||
// Counter may overflow, but it's fine since unsigned int overflow is
|
||||
// well-defined behavior.
|
||||
using FlagType = uint32_t;
|
||||
struct Signal {
|
||||
alignas(128) uint32_t start[kMaxBlocks][8];
|
||||
alignas(128) uint32_t end[kMaxBlocks][8];
|
||||
alignas(128) FlagType self_counter[kMaxBlocks][8];
|
||||
// Two sets of peer counters are needed for two syncs. The reason is that
|
||||
// it's possible for peer GPU block to arrive at the second sync point while
|
||||
// the current GPU block haven't passed the first sync point. Thus, peer GPU
|
||||
// may write counter+1 while current GPU is busy waiting for counter. We use
|
||||
// alternating counter array to avoid this possibility.
|
||||
alignas(128) FlagType peer_counter[2][kMaxBlocks][8];
|
||||
};
|
||||
|
||||
struct __align__(16) RankData { const void* __restrict__ ptrs[8]; };
|
||||
|
||||
struct __align__(16) RankSignals { volatile Signal* signals[8]; };
|
||||
struct __align__(16) RankSignals { Signal* signals[8]; };
|
||||
|
||||
// like std::array, but aligned
|
||||
template <typename T, int sz>
|
||||
@ -123,47 +130,60 @@ DINLINE O downcast(array_t<float, O::size> val) {
|
||||
}
|
||||
}
|
||||
|
||||
// This function is meant to be used as the first synchronization in the all
|
||||
// reduce kernel. Thus, it doesn't need to make any visibility guarantees for
|
||||
// prior memory accesses. Note: volatile writes will not be reordered against
|
||||
// other volatile writes.
|
||||
template <int ngpus>
|
||||
DINLINE void start_sync(const RankSignals& sg, volatile Signal* self_sg,
|
||||
int rank) {
|
||||
if (threadIdx.x < ngpus) {
|
||||
// reset flag for next time
|
||||
self_sg->end[blockIdx.x][threadIdx.x] = 0;
|
||||
// simultaneously write to the corresponding flag of all ranks.
|
||||
// Latency = 1 p2p write
|
||||
sg.signals[threadIdx.x]->start[blockIdx.x][rank] = 1;
|
||||
// wait until we got true from all ranks
|
||||
while (!self_sg->start[blockIdx.x][threadIdx.x]);
|
||||
}
|
||||
__syncthreads();
|
||||
static DINLINE void st_flag_release(FlagType* flag_addr, FlagType flag) {
|
||||
asm volatile("st.release.sys.global.u32 [%1], %0;" ::"r"(flag),
|
||||
"l"(flag_addr));
|
||||
}
|
||||
|
||||
// This function is meant to be used as the second or the final synchronization
|
||||
// barrier in the all reduce kernel. If it's the final synchronization barrier,
|
||||
// we don't need to make any visibility guarantees for prior memory accesses.
|
||||
template <int ngpus, bool final_sync = false>
|
||||
DINLINE void end_sync(const RankSignals& sg, volatile Signal* self_sg,
|
||||
int rank) {
|
||||
__syncthreads();
|
||||
// eliminate the case that prior writes are not visible after signals become
|
||||
// visible. Note that I did not managed to make this happen through a lot of
|
||||
// testing. Might be the case that hardware provides stronger guarantee than
|
||||
// the memory model.
|
||||
if constexpr (!final_sync) __threadfence_system();
|
||||
static DINLINE FlagType ld_flag_acquire(FlagType* flag_addr) {
|
||||
FlagType flag;
|
||||
asm volatile("ld.acquire.sys.global.u32 %0, [%1];"
|
||||
: "=r"(flag)
|
||||
: "l"(flag_addr));
|
||||
return flag;
|
||||
}
|
||||
|
||||
static DINLINE void st_flag_volatile(FlagType* flag_addr, FlagType flag) {
|
||||
asm volatile("st.volatile.global.u32 [%1], %0;" ::"r"(flag), "l"(flag_addr));
|
||||
}
|
||||
|
||||
static DINLINE FlagType ld_flag_volatile(FlagType* flag_addr) {
|
||||
FlagType flag;
|
||||
asm volatile("ld.volatile.global.u32 %0, [%1];"
|
||||
: "=r"(flag)
|
||||
: "l"(flag_addr));
|
||||
return flag;
|
||||
}
|
||||
|
||||
// is_start: whether this is the very first synchronization barrier.
|
||||
// need_fence: whether a memory fence is needed. If true, a release-acquire
|
||||
// semantic is used to enforce memory access order before and after this
|
||||
// barrier.
|
||||
template <int ngpus, bool is_start, bool need_fence = false>
|
||||
DINLINE void multi_gpu_barrier(const RankSignals& sg, Signal* self_sg,
|
||||
int rank) {
|
||||
if constexpr (!is_start) __syncthreads();
|
||||
static_assert(
|
||||
!(is_start && need_fence)); // Start barrier shouldn't need fence.
|
||||
if (threadIdx.x < ngpus) {
|
||||
// reset flag for next time
|
||||
self_sg->start[blockIdx.x][threadIdx.x] = 0;
|
||||
// simultaneously write to the corresponding flag of all ranks.
|
||||
// Latency = 1 p2p write
|
||||
sg.signals[threadIdx.x]->end[blockIdx.x][rank] = 1;
|
||||
// wait until we got true from all ranks
|
||||
while (!self_sg->end[blockIdx.x][threadIdx.x]);
|
||||
// Increment the counter. Technically we only need one counter, but we use
|
||||
// multiple per block to eliminate the need to share the counter via smem.
|
||||
auto val = self_sg->self_counter[blockIdx.x][threadIdx.x] += 1;
|
||||
// Write the expected counter value to peer and wait for correct value from
|
||||
// peer.
|
||||
auto peer_counter_ptr =
|
||||
&sg.signals[threadIdx.x]->peer_counter[val % 2][blockIdx.x][rank];
|
||||
auto self_counter_ptr =
|
||||
&self_sg->peer_counter[val % 2][blockIdx.x][threadIdx.x];
|
||||
if constexpr (need_fence) {
|
||||
st_flag_release(peer_counter_ptr, val);
|
||||
while (ld_flag_acquire(self_counter_ptr) != val);
|
||||
} else {
|
||||
st_flag_volatile(peer_counter_ptr, val);
|
||||
while (ld_flag_volatile(self_counter_ptr) != val);
|
||||
}
|
||||
}
|
||||
if constexpr (!final_sync) __syncthreads();
|
||||
if constexpr (is_start || need_fence) __syncthreads();
|
||||
}
|
||||
|
||||
template <typename P, int ngpus, typename A>
|
||||
@ -178,33 +198,31 @@ DINLINE P packed_reduce(const P* ptrs[], int idx) {
|
||||
|
||||
template <typename T, int ngpus>
|
||||
__global__ void __launch_bounds__(512, 1)
|
||||
cross_device_reduce_1stage(RankData* _dp, RankSignals sg,
|
||||
volatile Signal* self_sg, T* __restrict__ result,
|
||||
int rank, int size) {
|
||||
cross_device_reduce_1stage(RankData* _dp, RankSignals sg, Signal* self_sg,
|
||||
T* __restrict__ result, int rank, int size) {
|
||||
using P = typename packed_t<T>::P;
|
||||
using A = typename packed_t<T>::A;
|
||||
// note: we don't reorder the address so the accumulation order is the same
|
||||
// for all ranks, ensuring bitwise identical results
|
||||
auto dp = *_dp;
|
||||
start_sync<ngpus>(sg, self_sg, rank);
|
||||
multi_gpu_barrier<ngpus, true>(sg, self_sg, rank);
|
||||
// do the actual reduction
|
||||
for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size;
|
||||
idx += gridDim.x * blockDim.x) {
|
||||
((P*)result)[idx] = packed_reduce<P, ngpus, A>((const P**)&dp.ptrs[0], idx);
|
||||
}
|
||||
end_sync<ngpus, true>(sg, self_sg, rank);
|
||||
multi_gpu_barrier<ngpus, false>(sg, self_sg, rank);
|
||||
}
|
||||
|
||||
template <typename P>
|
||||
DINLINE P* get_tmp_buf(volatile Signal* sg) {
|
||||
DINLINE P* get_tmp_buf(Signal* sg) {
|
||||
return (P*)(((Signal*)sg) + 1);
|
||||
}
|
||||
|
||||
template <typename T, int ngpus>
|
||||
__global__ void __launch_bounds__(512, 1)
|
||||
cross_device_reduce_2stage(RankData* _dp, RankSignals sg,
|
||||
volatile Signal* self_sg, T* __restrict__ result,
|
||||
int rank, int size) {
|
||||
cross_device_reduce_2stage(RankData* _dp, RankSignals sg, Signal* self_sg,
|
||||
T* __restrict__ result, int rank, int size) {
|
||||
int tid = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
int stride = gridDim.x * blockDim.x;
|
||||
using P = typename packed_t<T>::P;
|
||||
@ -222,12 +240,12 @@ __global__ void __launch_bounds__(512, 1)
|
||||
tmps[i] = get_tmp_buf<P>(sg.signals[target]);
|
||||
}
|
||||
auto tmp_out = tmps[0];
|
||||
start_sync<ngpus>(sg, self_sg, rank);
|
||||
multi_gpu_barrier<ngpus, true>(sg, self_sg, rank);
|
||||
// stage 1: reduce scatter
|
||||
for (int idx = start + tid; idx < end; idx += stride) {
|
||||
tmp_out[idx - start] = packed_reduce<P, ngpus, A>(ptrs, idx);
|
||||
}
|
||||
end_sync<ngpus>(sg, self_sg, rank);
|
||||
multi_gpu_barrier<ngpus, false, true>(sg, self_sg, rank);
|
||||
|
||||
// stage 2: allgather. Note: it's important to match the tid between
|
||||
// the two stages, because visibility across devices is only guaranteed
|
||||
@ -437,6 +455,8 @@ class CustomAllreduce {
|
||||
#define KL(ngpus, name) \
|
||||
name<T, ngpus><<<blocks, threads, 0, stream>>>(ptrs, sg_, self_sg_, output, \
|
||||
rank_, size);
|
||||
// TODO(hanzhi713): Threshold is different for A100 and H100.
|
||||
// Add per device threshold.
|
||||
#define REDUCE_CASE(ngpus) \
|
||||
case ngpus: { \
|
||||
if (world_size_ == 2) { \
|
||||
|
@ -1,15 +1,15 @@
|
||||
/**
|
||||
* This is a standalone test for custom allreduce.
|
||||
* To compile, make sure you have MPI and NCCL installed in your system.
|
||||
* export MPI_HOME=XXX
|
||||
* export MPI_HOME=xxx
|
||||
* nvcc -O2 -arch=native -std=c++17 custom_all_reduce_test.cu -o
|
||||
* custom_all_reduce_test -lnccl -I${MPI_HOME}/include -lmpi
|
||||
* custom_all_reduce_test -lnccl -I${MPI_HOME} -lmpi
|
||||
*
|
||||
* Warning: this C++ test is not designed to be very readable and was used
|
||||
* during the rapid prototyping process.
|
||||
*
|
||||
* To run:
|
||||
* mpirun -np 8 ./custom_all_reduce_test
|
||||
* mpirun --allow-run-as-root -np 8 ./custom_all_reduce_test
|
||||
*/
|
||||
#include <cuda.h>
|
||||
#include <curand_kernel.h>
|
||||
@ -302,15 +302,19 @@ int main(int argc, char** argv) {
|
||||
|
||||
bool performance_test = true;
|
||||
cudaProfilerStart();
|
||||
// for (int threads : {256, 512}) {
|
||||
// Uncomment to scan through different block size configs.
|
||||
// for (int threads : {256, 512, 1024}) {
|
||||
// for (int block_limit = 16; block_limit < 112; block_limit += 4) {
|
||||
// run<half>(myRank, nRanks, comm, threads, block_limit, 4096 * 1024);
|
||||
// run<half>(myRank, nRanks, comm, threads, block_limit, 1024 * 1024,
|
||||
// performance_test);
|
||||
// }
|
||||
// }
|
||||
// Scan through different sizes to test performance.
|
||||
for (int sz = 512; sz <= (8 << 20); sz *= 2) {
|
||||
run<half>(myRank, nRanks, comm, 512, 36, sz + 8 * 47, performance_test);
|
||||
}
|
||||
|
||||
cudaProfilerStop();
|
||||
MPICHECK(MPI_Finalize());
|
||||
return EXIT_SUCCESS;
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user