#include "cpu/cpu_types.hpp"
|
|
|
|
#include <fcntl.h>
|
|
#include <sys/mman.h>
|
|
#include <sys/stat.h>
|
|
#include <unistd.h>
|
|
|
|
namespace {
|
|
#define MAX_SHM_RANK_NUM 8
|
|
#define MAX_THREAD_NUM 12
|
|
#define PER_THREAD_SHM_BUFFER_BYTES (4 * 1024 * 1024)
|
|
#define MIN_THREAD_PROCESS_SIZE (8 * 1024)
|
|
#define MAX_P2P_SEND_TENSOR_NUM 8
|
|
|
|
template <typename scalar_t>
|
|
struct KernelVecType {
|
|
using scalar_vec_t = void;
|
|
};
|
|
|
|
template <>
|
|
struct KernelVecType<float> {
|
|
using scalar_vec_t = vec_op::FP32Vec16;
|
|
};
|
|
|
|
template <>
|
|
struct KernelVecType<c10::BFloat16> {
|
|
using scalar_vec_t = vec_op::BF16Vec16;
|
|
};
|
|
|
|
template <>
|
|
struct KernelVecType<c10::Half> {
|
|
using scalar_vec_t = vec_op::FP16Vec16;
|
|
};
|
|
|
|
enum class ThreadSHMStat : char { THREAD_READY = 0, SHM_DATA_READY, DONE };
|
|
|
|

struct ThreadSHMContext {
  volatile ThreadSHMStat thread_stats[MAX_SHM_RANK_NUM];
  int thread_id;
  int thread_num;
  int rank;
  int group_size;
  size_t _spinning_count;
  int swizzled_ranks[MAX_SHM_RANK_NUM];
  void* thread_shm_ptrs[MAX_SHM_RANK_NUM];
  ThreadSHMContext* shm_contexts[MAX_SHM_RANK_NUM];

  ThreadSHMContext(const int thread_id, const int thread_num, const int rank,
                   const int group_size, void* thread_shm_ptr)
      : thread_id(thread_id),
        thread_num(thread_num),
        rank(rank),
        group_size(group_size),
        _spinning_count(0) {
    static_assert(sizeof(ThreadSHMContext) % 64 == 0);
    TORCH_CHECK(group_size <= MAX_SHM_RANK_NUM);
    TORCH_CHECK((size_t)this % 64 == 0);
    TORCH_CHECK((size_t)thread_shm_ptr % 64 == 0);
    for (int i = 0; i < MAX_SHM_RANK_NUM; ++i) {
      shm_contexts[i] = nullptr;
      thread_shm_ptrs[i] = nullptr;
      swizzled_ranks[i] = (i + rank) % group_size;
      thread_stats[i] = ThreadSHMStat::DONE;
    }
    set_context(rank, this, thread_shm_ptr);
  }

  void set_context(int rank, ThreadSHMContext* ptr, void* thread_shm_ptr) {
    TORCH_CHECK(rank < MAX_SHM_RANK_NUM);
    TORCH_CHECK(ptr);
    TORCH_CHECK(thread_shm_ptr);
    TORCH_CHECK_EQ(ptr->thread_num, thread_num);
    TORCH_CHECK_EQ(ptr->thread_id, thread_id);
    shm_contexts[rank] = ptr;
    thread_shm_ptrs[rank] = thread_shm_ptr;
  }

  template <typename T>
  T* get_thread_shm_ptr(int rank) {
    return reinterpret_cast<T*>(thread_shm_ptrs[rank]);
  }

  int get_swizzled_rank(int idx) { return swizzled_ranks[idx]; }

  void wait_for_all(ThreadSHMStat prev_stat) {
    for (int idx = 0; idx < group_size; ++idx) {
      int rank = get_swizzled_rank(idx);
      while (thread_stats[rank] == prev_stat) {
        ++_spinning_count;
        _mm_pause();
      }
    }
    vec_op::mem_barrier();
  }

  void wait_for_one(int rank, ThreadSHMStat prev_stat) {
    while (thread_stats[rank] == prev_stat) {
      ++_spinning_count;
      _mm_pause();
    }
    vec_op::mem_barrier();
  }

  void set_thread_stat(ThreadSHMStat stat) {
    for (int idx = 0; idx < group_size; ++idx) {
      int rank = get_swizzled_rank(idx);
      shm_contexts[rank]->thread_stats[this->rank] = stat;
    }
  }

  void set_thread_stat(int target_rank, ThreadSHMStat stat) {
    for (int idx = 0; idx < group_size; ++idx) {
      int rank = get_swizzled_rank(idx);
      shm_contexts[rank]->thread_stats[target_rank] = stat;
    }
  }

  // Barrier for all ranks in the group, used for all-to-all ops.
  // State cycle: DONE -> THREAD_READY -> SHM_DATA_READY -> DONE -> ...
  void barrier(ThreadSHMStat next_stat) {
    if (next_stat == ThreadSHMStat::THREAD_READY) {
      set_thread_stat(ThreadSHMStat::THREAD_READY);
      wait_for_all(ThreadSHMStat::DONE);
    } else if (next_stat == ThreadSHMStat::SHM_DATA_READY) {
      set_thread_stat(ThreadSHMStat::SHM_DATA_READY);
      wait_for_all(ThreadSHMStat::THREAD_READY);
    } else if (next_stat == ThreadSHMStat::DONE) {
      set_thread_stat(ThreadSHMStat::DONE);
      wait_for_all(ThreadSHMStat::SHM_DATA_READY);
    } else {
      TORCH_CHECK(false, "Invalid next_stat to barrier.");
    }
  }

  std::string to_string() const {
    std::stringstream ss;
    ss << "SHMContext:";
    ss << "\nrank: " << rank;
    ss << "\ngroup_size: " << group_size;
    ss << "\nthread_num: " << thread_num;
    ss << "\nthread_id: " << thread_id;

    ss << "\nshm_ctx_stat_loop_seq: [";
    for (int i = 0; i < group_size; ++i) {
      ss << swizzled_ranks[i] << ", ";
    }
    ss << "]";

    ss << "\nshm_contexts: [";
    for (int i = 0; i < group_size; ++i) {
      if (shm_contexts[i]) {
        ss << shm_contexts[i]->rank << ", ";
      }
    }
    ss << "]";

    return ss.str();
  }
};
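
// Illustrative sketch (added commentary, not in the original source): a
// typical all-to-all style exchange drives each per-thread context through
// the DONE -> THREAD_READY -> SHM_DATA_READY -> DONE cycle, e.g.:
//
//   ctx->barrier(ThreadSHMStat::THREAD_READY);    // everyone has arrived
//   /* write this rank's payload into its own shm buffer */
//   ctx->barrier(ThreadSHMStat::SHM_DATA_READY);  // all payloads visible
//   /* read the other ranks' shm buffers */
//   ctx->barrier(ThreadSHMStat::DONE);            // safe to reuse buffers
//
// all_reduce_sum_impl and shm_gather_impl below follow exactly this pattern.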

class SHMManager {
 public:
  explicit SHMManager(const std::string& name, const int rank,
                      const int group_size)
      : _rank(rank),
        _group_size(group_size),
        _thread_num(std::min(torch::get_num_threads(), MAX_THREAD_NUM)),
        _shm_names({""}),
        _shared_mem_ptrs({nullptr}),
        _shm_ctx(nullptr) {
    _shm_names[rank] = get_shm_name(name, rank);
    _shared_mem_ptrs[rank] = init_shm(rank);
    _shm_ctx = reinterpret_cast<ThreadSHMContext*>(_shared_mem_ptrs[rank]);

    for (int i = 0; i < _thread_num; ++i) {
      ThreadSHMContext* ctx = new (_shm_ctx + i)
          ThreadSHMContext(i, _thread_num, _rank, _group_size,
                           compute_thread_shm_ptr(_shm_ctx, i));
    }
  }

  void join(const std::string& name) {
    for (int rank_idx = 0; rank_idx < _group_size; ++rank_idx) {
      if (rank_idx != _rank) {
        TORCH_CHECK(_shm_names[rank_idx].empty());
        TORCH_CHECK(_shared_mem_ptrs[rank_idx] == nullptr);
        _shm_names[rank_idx] = get_shm_name(name, rank_idx);
        _shared_mem_ptrs[rank_idx] = init_shm(rank_idx);
        ThreadSHMContext* target_ctx =
            reinterpret_cast<ThreadSHMContext*>(_shared_mem_ptrs[rank_idx]);
        for (int thread_idx = 0; thread_idx < _thread_num; ++thread_idx) {
          _shm_ctx[thread_idx].set_context(
              rank_idx, target_ctx + thread_idx,
              compute_thread_shm_ptr(target_ctx, thread_idx));
        }
      }
    }
  }

  ~SHMManager() { destroy_shm(); }

  ThreadSHMContext* get_shm_ctx() const { return _shm_ctx; }

  static std::string get_shm_name(const std::string& name, int rank) {
    return name + "_" + std::to_string(rank);
  }

  static int64_t create_singleton_instance(const std::string& name,
                                           const int group_size,
                                           const int rank) {
    std::lock_guard<std::mutex> guard(SingletonInstancesLock);
    SingletonInstances.emplace_back(
        std::make_unique<SHMManager>(name, rank, group_size));
    return static_cast<int64_t>(SingletonInstances.size() - 1);
  }

  static SHMManager* get_singleton_instance(int64_t handle) {
    return SingletonInstances[handle].get();
  }

 protected:
  static std::vector<std::unique_ptr<SHMManager>> SingletonInstances;
  static std::mutex SingletonInstancesLock;

 private:
  static size_t round_to_alignment(size_t num) {
    return ((num + 63) / 64) * 64;
  }

  int8_t* compute_thread_shm_ptr(ThreadSHMContext* ctx, int thread_id) {
    int8_t* thread_shm_ptr =
        reinterpret_cast<int8_t*>(ctx) +
        round_to_alignment(_thread_num * sizeof(ThreadSHMContext));
    return thread_shm_ptr +
           thread_id * round_to_alignment(PER_THREAD_SHM_BUFFER_BYTES);
  }

  size_t compute_shm_size() {
    const size_t rounded_rank_buffer_size =
        round_to_alignment(PER_THREAD_SHM_BUFFER_BYTES) * _thread_num;
    const size_t rounded_thread_shm_ctx_size =
        round_to_alignment(_thread_num * sizeof(ThreadSHMContext));
    const size_t shm_size =
        rounded_thread_shm_ctx_size + rounded_rank_buffer_size;
    return shm_size;
  }

  void* init_shm(int target_rank) {
    const std::string& shm_name = _shm_names[target_rank];
    const int local_rank = _rank;
    const size_t shm_size = compute_shm_size();

    int fd = -1;
    if (local_rank == target_rank) {
      fd = shm_open(shm_name.c_str(), O_CREAT | O_EXCL | O_RDWR,
                    S_IRUSR | S_IWUSR);

      if (fd == -1)
        TORCH_CHECK(false, "create shm in SHMManager failed. errno: " +
                               std::to_string(errno));

      if (ftruncate(fd, shm_size) == -1)
        TORCH_CHECK(false, "ftruncate in SHMManager failed. errno: " +
                               std::to_string(errno));
    } else {
      fd = shm_open(shm_name.c_str(), O_RDWR, S_IRUSR | S_IWUSR);

      if (fd == -1)
        TORCH_CHECK(false, "open shm in SHMManager failed. errno: " +
                               std::to_string(errno));
    }

    void* shm_ptr = mmap(nullptr, shm_size, PROT_READ | PROT_WRITE,
                         MAP_SHARED | MAP_POPULATE, fd, 0);

    if (shm_ptr == MAP_FAILED) {
      TORCH_CHECK(false,
                  "mmap in SHMManager failed. errno: " + std::to_string(errno));
    }

    if (close(fd) != 0) {
      TORCH_CHECK(
          false, "close in SHMManager failed. errno: " + std::to_string(errno));
    }

    TORCH_CHECK((size_t)shm_ptr % 64 == 0);

    return shm_ptr;
  }

  void destroy_shm() {
    std::stringstream ss;
    ss << "local rank " << _rank << ": [";
    for (int thread_id = 0; thread_id < _thread_num; ++thread_id) {
      ss << _shm_ctx[thread_id]._spinning_count << ", ";
    }
    ss << "]\n";

    for (int i = 0; i < MAX_SHM_RANK_NUM; ++i) {
      if (_shared_mem_ptrs[i] != nullptr) {
        munmap(_shared_mem_ptrs[i], compute_shm_size());
      }

      if (!_shm_names[i].empty()) {
        shm_unlink(_shm_names[i].c_str());
      }
    }
  }

  int _rank;
  int _group_size;
  int _thread_num;
  std::array<std::string, MAX_SHM_RANK_NUM> _shm_names;
  std::array<void*, MAX_SHM_RANK_NUM> _shared_mem_ptrs;
  ThreadSHMContext* _shm_ctx;
};
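
// Illustrative sketch (added commentary, not in the original source): each
// rank creates one POSIX shm segment laid out as
//
//   [ThreadSHMContext x thread_num | 4 MiB buffer x thread_num]
//
// with each region rounded up to 64-byte alignment. A typical lifecycle,
// assuming two ranks and a hypothetical group name "vllm_shm":
//
//   // every rank creates its own segment
//   int64_t handle = SHMManager::create_singleton_instance(
//       "vllm_shm", /*group_size=*/2, /*rank=*/r);
//   // once all segments exist, every rank maps its peers' segments
//   SHMManager::get_singleton_instance(handle)->join("vllm_shm");
//
// join() must not run before every rank has finished the constructor, since
// it opens the peers' shm files by name without O_CREAT.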

namespace shm_cc_ops {
template <typename scalar_t, typename F>
void shm_cc_loop(ThreadSHMContext* ctx, int64_t elem_num, F&& inner_func) {
  int thread_num = ctx->thread_num;
  int64_t total_bytes = elem_num * sizeof(scalar_t);
  int64_t total_units_num =
      (total_bytes + MIN_THREAD_PROCESS_SIZE - 1) / MIN_THREAD_PROCESS_SIZE;
  int64_t per_thread_units_num =
      (total_units_num + thread_num - 1) / thread_num;
  int64_t per_unit_elem_num = MIN_THREAD_PROCESS_SIZE / sizeof(scalar_t);
  int64_t max_per_thread_iteration_elem_num =
      PER_THREAD_SHM_BUFFER_BYTES / sizeof(scalar_t);
  int64_t per_thread_elem_num = per_unit_elem_num * per_thread_units_num;

#pragma omp parallel for schedule(static, 1)
  for (int i = 0; i < thread_num; ++i) {
    int64_t offset = i * per_thread_elem_num;
    int64_t end = std::min(elem_num, offset + per_thread_elem_num);
    int64_t curr_elem_num =
        std::min(max_per_thread_iteration_elem_num, end - offset);
    ThreadSHMContext* thread_ctx = ctx + i;

    while (curr_elem_num > 0) {
      inner_func(thread_ctx, offset, curr_elem_num);

      offset += max_per_thread_iteration_elem_num;
      curr_elem_num = std::min(max_per_thread_iteration_elem_num, end - offset);
    }
  }
}
};  // namespace shm_cc_ops
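
// Illustrative worked example (added commentary, not in the original source),
// assuming 4 OpenMP threads and a bf16 tensor with 6,000,000 elements:
//
//   total_bytes          = 6,000,000 * 2              = 12,000,000 B
//   total_units_num      = ceil(12,000,000 / 8192)    = 1,465
//   per_thread_units_num = ceil(1,465 / 4)            = 367
//   per_thread_elem_num  = 367 * (8192 / 2)           = 1,503,232 elements
//
// Each thread then walks its slice in chunks of at most
// PER_THREAD_SHM_BUFFER_BYTES / sizeof(scalar_t) = 2,097,152 elements, so
// here every thread calls inner_func exactly once (the last thread with a
// slightly shorter tail).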

namespace shm_cc_ops {

void memcpy_from_shm(void* dst, void* src, const int64_t bytes) {
  const int64_t aligned_bytes = ((bytes >> 6) << 6);  // 64-byte aligned part
  int64_t i = 0;
#pragma GCC unroll 4
  for (; i < aligned_bytes; i += 64) {
    vec_op::INT8Vec64 data(
        true, (int8_t*)src + i);  // stream load from shm to avoid caching
    data.save((int8_t*)dst + i);
  }
  if (aligned_bytes < bytes) {
    vec_op::INT8Vec64 data(true, (int8_t*)src + aligned_bytes);
    data.save((int8_t*)dst + aligned_bytes, bytes - aligned_bytes);
  }
}

void memcpy_to_shm(void* dst, void* src, const int64_t bytes) {
#pragma GCC unroll 4
  for (int64_t i = 0; i < bytes; i += 64) {
    vec_op::INT8Vec64 data((int8_t*)src + i);
    data.nt_save((int8_t*)dst + i);
  }
}

void memcpy(void* dst, void* src, const int64_t bytes) {
  const int64_t aligned_bytes = ((bytes >> 6) << 6);  // 64-byte aligned part
  int64_t i = 0;
#pragma GCC unroll 4
  for (; i < aligned_bytes; i += 64) {
    vec_op::INT8Vec64 data((int8_t*)src + i);
    data.save((int8_t*)dst + i);
  }
  if (aligned_bytes < bytes) {
    vec_op::INT8Vec64 data((int8_t*)src + aligned_bytes);
    data.save((int8_t*)dst + aligned_bytes, bytes - aligned_bytes);
  }
}

template <typename scalar_t, int RANKS>
void all_reduce_sum_impl(ThreadSHMContext* ctx, scalar_t* data,
                         size_t elem_num) {
  CPU_KERNEL_GUARD_IN(all_reduce_sum_impl)
  using vec_t = typename KernelVecType<scalar_t>::scalar_vec_t;
  constexpr int64_t vec_elem_num = vec_t::get_elem_num();
  const int worldsize = ctx->group_size;

  shm_cc_ops::shm_cc_loop<scalar_t>(
      ctx, elem_num,
      [&](ThreadSHMContext* thread_ctx, int64_t data_offset,
          int64_t data_elem_num) {
        int rank = thread_ctx->rank;
        scalar_t* thread_shm_ptr =
            thread_ctx->get_thread_shm_ptr<scalar_t>(rank);
        scalar_t* thread_data_ptr = data + data_offset;
        int64_t thread_data_elem_num = data_elem_num * sizeof(scalar_t);

        scalar_t* remote_data_ptrs[RANKS - 1];
        vec_op::unroll_loop<int, RANKS - 1>([&](int idx) {
          remote_data_ptrs[idx] = thread_ctx->get_thread_shm_ptr<scalar_t>(
              thread_ctx->get_swizzled_rank(idx + 1));
        });

        thread_ctx->barrier(ThreadSHMStat::THREAD_READY);

        shm_cc_ops::memcpy_to_shm(thread_shm_ptr, thread_data_ptr,
                                  thread_data_elem_num);

        thread_ctx->barrier(ThreadSHMStat::SHM_DATA_READY);

        int64_t aligned_data_elem_num =
            (data_elem_num / vec_elem_num) * vec_elem_num;
        int64_t i = 0;
#pragma GCC unroll 4
        for (; i < aligned_data_elem_num; i += vec_elem_num) {
          vec_t local_data(thread_data_ptr + i);  // load from cache
          vec_op::FP32Vec16 local_data_fp32(local_data);
          vec_op::unroll_loop<int, RANKS - 1>([&](int idx) {
            vec_t remote_data(
                true, remote_data_ptrs[idx] + i);  // stream load from shm
            vec_op::FP32Vec16 remote_data_fp32(remote_data);
            local_data_fp32 = local_data_fp32 + remote_data_fp32;  // sum reduce
          });
          vec_t reduced_data(local_data_fp32);
          reduced_data.save(thread_data_ptr + i);
        }

        if (i < data_elem_num) {
          vec_t local_data(thread_data_ptr + i);  // load from cache
          vec_op::FP32Vec16 local_data_fp32(local_data);
          vec_op::unroll_loop<int, RANKS - 1>([&](int idx) {
            vec_t remote_data(
                true, remote_data_ptrs[idx] + i);  // stream load from shm
            vec_op::FP32Vec16 remote_data_fp32(remote_data);
            local_data_fp32 = local_data_fp32 + remote_data_fp32;  // sum reduce
          });
          vec_t reduced_data(local_data_fp32);
          reduced_data.save(thread_data_ptr + i,
                            data_elem_num - aligned_data_elem_num);
        }

        thread_ctx->barrier(ThreadSHMStat::DONE);
      });

  return;
}
};  // namespace shm_cc_ops
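
// Note (added commentary, not in the original source): all_reduce_sum_impl is
// in-place. Per chunk, each rank (1) publishes its slice to its own shm
// buffer between the THREAD_READY and SHM_DATA_READY barriers, (2) streams
// the other RANKS - 1 buffers and accumulates them into the local slice in
// FP32 before converting back to scalar_t, and (3) passes the DONE barrier so
// the buffers can be reused for the next chunk.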

std::vector<std::unique_ptr<SHMManager>> SHMManager::SingletonInstances = {};
std::mutex SHMManager::SingletonInstancesLock = {};

template <typename scalar_t>
void shm_allreduce_sum(ThreadSHMContext* ctx, scalar_t* data, size_t elem_num) {
  switch (ctx->group_size) {
    case 2:
      shm_cc_ops::all_reduce_sum_impl<scalar_t, 2>(ctx, data, elem_num);
      break;
    case 3:
      shm_cc_ops::all_reduce_sum_impl<scalar_t, 3>(ctx, data, elem_num);
      break;
    case 4:
      shm_cc_ops::all_reduce_sum_impl<scalar_t, 4>(ctx, data, elem_num);
      break;
    case 8:
      shm_cc_ops::all_reduce_sum_impl<scalar_t, 8>(ctx, data, elem_num);
      break;
    default:
      TORCH_CHECK(false,
                  "Invalid world size: " + std::to_string(ctx->group_size));
  }
}

template <typename scalar_t>
void shm_gather_impl(ThreadSHMContext* ctx, scalar_t* data, size_t elem_num,
                     scalar_t** outputs, const int dst) {
  CPU_KERNEL_GUARD_IN(shm_gather_impl)
  const int worldsize = ctx->group_size;
  TORCH_CHECK_LT(dst, worldsize);
  shm_cc_ops::shm_cc_loop<scalar_t>(
      ctx, elem_num,
      [&](ThreadSHMContext* thread_ctx, int64_t data_offset,
          int64_t data_elem_num) {
        int rank = thread_ctx->rank;
        scalar_t* thread_shm_ptr =
            thread_ctx->get_thread_shm_ptr<scalar_t>(rank);

        thread_ctx->barrier(ThreadSHMStat::THREAD_READY);

        shm_cc_ops::memcpy_to_shm(thread_shm_ptr, data + data_offset,
                                  data_elem_num * sizeof(scalar_t));

        thread_ctx->barrier(ThreadSHMStat::SHM_DATA_READY);

        if (rank == dst) {
          shm_cc_ops::memcpy(outputs[rank] + data_offset, data + data_offset,
                             data_elem_num * sizeof(scalar_t));
          for (int i = 1; i < worldsize; ++i) {
            int src_rank = thread_ctx->get_swizzled_rank(i);
            scalar_t* src_ptr =
                thread_ctx->get_thread_shm_ptr<scalar_t>(src_rank);  // shm
            scalar_t* dst_ptr = outputs[src_rank] + data_offset;
            shm_cc_ops::memcpy_from_shm(dst_ptr, src_ptr,
                                        data_elem_num * sizeof(scalar_t));
          }
        }

        thread_ctx->barrier(ThreadSHMStat::DONE);
      });

  return;
}

struct MemPiece {
  void* ptr;
  int64_t size;

  template <typename T>
  T* data_ptr() {
    return reinterpret_cast<T*>(ptr);
  }
};

struct TensorListMeta {
  int64_t tensor_bytes[MAX_P2P_SEND_TENSOR_NUM];
  torch::ScalarType tensor_types[MAX_P2P_SEND_TENSOR_NUM];
  int64_t tensor_num;
  int64_t total_bytes;

  TensorListMeta() : tensor_num(0), total_bytes(0) {
    static_assert(sizeof(TensorListMeta) % 64 == 0);
    static_assert(sizeof(TensorListMeta) <
                  MIN_THREAD_PROCESS_SIZE);  // ensures the metadata is always
                                             // handled by thread 0
    for (int i = 0; i < MAX_P2P_SEND_TENSOR_NUM; ++i) {
      tensor_bytes[i] = 0;
      tensor_ptrs[i] = nullptr;
      tensor_types[i] = torch::ScalarType::Undefined;
    }
  }

  // For send and recv
  void bind_tensor_list(std::vector<torch::Tensor>& tensor_list) {
    TORCH_CHECK(tensor_types[0] == torch::ScalarType::Undefined,
                "Re-binding a TensorListMeta is not allowed.")
    TORCH_CHECK_LE(tensor_list.size(), MAX_P2P_SEND_TENSOR_NUM);
    tensor_num = tensor_list.size();
    int64_t bytes_sum = 0;
    for (int i = 0; i < tensor_list.size(); ++i) {
      torch::Tensor& t = tensor_list[i];
      TORCH_CHECK(t.is_contiguous());
      tensor_bytes[i] = t.nbytes();
      tensor_types[i] = t.scalar_type();
      tensor_ptrs[i] = t.data_ptr();
      bytes_sum += t.nbytes();
    }
    total_bytes = bytes_sum;
  }

  // For recv
  std::vector<torch::Tensor> generate_tensor_list() {
    std::vector<torch::Tensor> tensor_list;
    tensor_list.reserve(tensor_num);

    for (int i = 0; i < tensor_num; ++i) {
      int64_t bytes = tensor_bytes[i];
      auto type = tensor_types[i];
      int64_t elem_bytes = torch::elementSize(type);

      TORCH_CHECK_EQ(bytes % elem_bytes, 0);
      int64_t elem_num = bytes / elem_bytes;
      auto options = torch::TensorOptions().dtype(type).device(torch::kCPU);
      tensor_list.emplace_back(torch::empty({elem_num}, options));
    }
    return tensor_list;
  }

  MemPiece get_data(int64_t offset) {
    for (int i = 0; i < tensor_num; ++i) {
      if (offset < tensor_bytes[i]) {
        return {reinterpret_cast<int8_t*>(tensor_ptrs[i]) + offset,
                tensor_bytes[i] - offset};
      }
      offset -= tensor_bytes[i];
    }
    return {nullptr, 0};
  }

 private:
  void* tensor_ptrs[MAX_P2P_SEND_TENSOR_NUM];
  int8_t _padding[40];
};
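
// Illustrative sketch (added commentary, not in the original source):
// get_data() treats the bound tensors as one logical byte stream. For
// example, with two bound tensors of 64 and 128 bytes:
//
//   get_data(0)   -> { ptr of tensor 0 + 0, size 64 }
//   get_data(70)  -> { ptr of tensor 1 + 6, size 122 }
//   get_data(200) -> { nullptr, 0 }   // past the end of the stream
//
// The sender/receiver loops below clamp the returned size to the bytes that
// still fit in the current shm chunk.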

void shm_send_tensor_list_impl(ThreadSHMContext* ctx,
                               const std::vector<torch::Tensor>& tensor_list) {
  CPU_KERNEL_GUARD_IN(shm_send_tensor_list_impl)
  std::vector<torch::Tensor> tensor_list_with_metadata;
  tensor_list_with_metadata.reserve(1 + tensor_list.size());

  auto options = torch::TensorOptions().dtype(torch::kInt8).device(torch::kCPU);
  tensor_list_with_metadata.emplace_back(
      torch::empty({sizeof(TensorListMeta)}, options));
  tensor_list_with_metadata.insert(tensor_list_with_metadata.end(),
                                   tensor_list.begin(), tensor_list.end());

  torch::Tensor& metadata_tensor = tensor_list_with_metadata[0];
  TORCH_CHECK_EQ(metadata_tensor.nbytes(), sizeof(TensorListMeta));

  TensorListMeta* metadata = new (metadata_tensor.data_ptr()) TensorListMeta();
  metadata->bind_tensor_list(tensor_list_with_metadata);

  shm_cc_ops::shm_cc_loop<int8_t>(
      ctx, metadata->total_bytes,
      [&](ThreadSHMContext* thread_ctx, int64_t data_offset,
          int64_t data_elem_num) {
        int rank = thread_ctx->rank;
        // Wait until the receiver has set the stat back to DONE
        thread_ctx->wait_for_one(rank, ThreadSHMStat::SHM_DATA_READY);

        int64_t curr_shm_offset = 0;
        while (curr_shm_offset < data_elem_num) {
          MemPiece frag = metadata->get_data(data_offset + curr_shm_offset);
          frag.size = std::min(frag.size, data_elem_num - curr_shm_offset);
          shm_cc_ops::memcpy(
              thread_ctx->get_thread_shm_ptr<int8_t>(rank) + curr_shm_offset,
              frag.ptr, frag.size);
          curr_shm_offset += frag.size;
        }

        thread_ctx->set_thread_stat(rank, ThreadSHMStat::SHM_DATA_READY);
      });
}
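
// Note (added commentary, not in the original source): unlike the collectives
// above, send/recv do not use the full barrier(). The sender only toggles its
// own slot in the per-thread contexts: it waits for the slot to leave
// SHM_DATA_READY (i.e. the receiver has drained the previous chunk and set
// DONE), fills its shm buffer, then sets SHM_DATA_READY again. The receiver
// performs the mirror-image wait in shm_recv_tensor_list_impl below.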

std::vector<torch::Tensor> shm_recv_tensor_list_impl(ThreadSHMContext* ctx,
                                                     int64_t src) {
  CPU_KERNEL_GUARD_IN(shm_recv_tensor_list_impl)
  auto options = torch::TensorOptions().dtype(torch::kInt8).device(torch::kCPU);
  torch::Tensor metadata_tensor =
      torch::empty({sizeof(TensorListMeta)}, options);

  // Wait until the sender has set the stat of thread 0 to SHM_DATA_READY
  ctx->wait_for_one(src, ThreadSHMStat::DONE);
  shm_cc_ops::memcpy(metadata_tensor.data_ptr(),
                     ctx->get_thread_shm_ptr<void>(src),
                     sizeof(TensorListMeta));
  TensorListMeta* src_metadata =
      reinterpret_cast<TensorListMeta*>(metadata_tensor.data_ptr());
  std::vector<torch::Tensor> tensor_list_with_metadata =
      src_metadata->generate_tensor_list();

  TensorListMeta metadata;
  metadata.bind_tensor_list(tensor_list_with_metadata);
  TORCH_CHECK_EQ(metadata.tensor_num, src_metadata->tensor_num);
  TORCH_CHECK_EQ(metadata.total_bytes, src_metadata->total_bytes);

  shm_cc_ops::shm_cc_loop<int8_t>(
      ctx, metadata.total_bytes,
      [&](ThreadSHMContext* thread_ctx, int64_t data_offset,
          int64_t data_elem_num) {
        // Wait until the sender has set the stat to SHM_DATA_READY
        thread_ctx->wait_for_one(src, ThreadSHMStat::DONE);
        int64_t curr_shm_offset = 0;
        while (curr_shm_offset < data_elem_num) {
          MemPiece frag = metadata.get_data(data_offset + curr_shm_offset);
          frag.size = std::min(frag.size, data_elem_num - curr_shm_offset);
          shm_cc_ops::memcpy(
              frag.ptr,
              thread_ctx->get_thread_shm_ptr<int8_t>(src) + curr_shm_offset,
              frag.size);
          curr_shm_offset += frag.size;
        }

        thread_ctx->set_thread_stat(src, ThreadSHMStat::DONE);
      });

  std::vector<torch::Tensor> tensor_list;
  tensor_list.reserve(metadata.tensor_num - 1);
  tensor_list.insert(tensor_list.begin(), tensor_list_with_metadata.begin() + 1,
                     tensor_list_with_metadata.end());

  return tensor_list;
}
}  // namespace

void shm_gather(int64_t handle, torch::Tensor& data,
                const std::optional<std::vector<torch::Tensor>>& outputs,
                int64_t dst) {
  TORCH_CHECK(data.is_contiguous())
  VLLM_DISPATCH_FLOATING_TYPES(data.scalar_type(), "shm_gather_impl", [&] {
    CPU_KERNEL_GUARD_IN(shm_gather_impl)

    if (outputs.has_value()) {
      TORCH_CHECK_LE(outputs->size(), MAX_SHM_RANK_NUM);
      scalar_t* output_ptrs[MAX_SHM_RANK_NUM] = {nullptr};
      for (int i = 0; i < outputs->size(); ++i) {
        output_ptrs[i] = outputs->at(i).data_ptr<scalar_t>();
      }
      shm_gather_impl(SHMManager::get_singleton_instance(handle)->get_shm_ctx(),
                      data.data_ptr<scalar_t>(), data.numel(), output_ptrs,
                      dst);
    } else {
      shm_gather_impl(SHMManager::get_singleton_instance(handle)->get_shm_ctx(),
                      data.data_ptr<scalar_t>(), data.numel(), (scalar_t**)(0),
                      dst);
    }

    CPU_KERNEL_GUARD_OUT(shm_gather_impl)
  });
}

void shm_all_gather(int64_t handle, const torch::Tensor& data,
                    torch::Tensor& output) {
  TORCH_CHECK(data.is_contiguous())
  TORCH_CHECK(output.is_contiguous())

  const int64_t input_elem_num = data.numel();
  const int64_t output_elem_num = output.numel();
  TORCH_CHECK_EQ(output_elem_num % input_elem_num, 0);
  const int world_size = output_elem_num / input_elem_num;

  VLLM_DISPATCH_FLOATING_TYPES(data.scalar_type(), "shm_all_gather_impl", [&] {
    CPU_KERNEL_GUARD_IN(shm_all_gather_impl)
    auto ctx = SHMManager::get_singleton_instance(handle)->get_shm_ctx();
    TORCH_CHECK_EQ(ctx->group_size, world_size);

    scalar_t* output_ptrs[MAX_SHM_RANK_NUM] = {nullptr};
    for (int i = 0; i < world_size; ++i) {
      output_ptrs[i] = output.data_ptr<scalar_t>() + i * input_elem_num;
    }
    shm_gather_impl(ctx, data.data_ptr<scalar_t>(), data.numel(), output_ptrs,
                    ctx->rank);
    CPU_KERNEL_GUARD_OUT(shm_all_gather_impl)
  });
}

void shm_allreduce(int64_t handle, torch::Tensor& data) {
  TORCH_CHECK(data.is_contiguous())
  VLLM_DISPATCH_FLOATING_TYPES(data.scalar_type(), "shm_allreduce_sum", [&] {
    CPU_KERNEL_GUARD_IN(shm_allreduce_sum)
    shm_allreduce_sum(SHMManager::get_singleton_instance(handle)->get_shm_ctx(),
                      data.data_ptr<scalar_t>(), data.numel());
    CPU_KERNEL_GUARD_OUT(shm_allreduce_sum)
  });
}

void shm_send_tensor_list(int64_t handle,
                          const std::vector<torch::Tensor>& tensor_list,
                          int64_t dst) {
  CPU_KERNEL_GUARD_IN(shm_send_tensor_list)
  shm_send_tensor_list_impl(
      SHMManager::get_singleton_instance(handle)->get_shm_ctx(), tensor_list);
  CPU_KERNEL_GUARD_OUT(shm_send_tensor_list)
}

std::vector<torch::Tensor> shm_recv_tensor_list(int64_t handle, int64_t src) {
  CPU_KERNEL_GUARD_IN(shm_recv_tensor_list)
  auto tensor_list = shm_recv_tensor_list_impl(
      SHMManager::get_singleton_instance(handle)->get_shm_ctx(), src);
  CPU_KERNEL_GUARD_OUT(shm_recv_tensor_list)
  return tensor_list;
}

int64_t init_shm_manager(const std::string& name, const int64_t group_size,
                         const int64_t rank) {
  return SHMManager::create_singleton_instance(name, group_size, rank);
}

std::string join_shm_manager(int64_t handle, const std::string& name) {
  auto shm_manager = SHMManager::get_singleton_instance(handle);
  TORCH_CHECK(shm_manager);
  shm_manager->join(name);
  return shm_manager->get_shm_ctx()->to_string();
}
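
// Illustrative end-to-end sketch (added commentary, not in the original
// source), assuming these functions are exposed through the usual torch
// extension bindings; the group name "vllm_shm" and the tensor variable names
// below are hypothetical:
//
//   int64_t handle = init_shm_manager("vllm_shm", world_size, rank);
//   // ... synchronize across processes so every rank created its segment ...
//   join_shm_manager(handle, "vllm_shm");
//
//   shm_allreduce(handle, hidden_states);           // in-place sum
//   shm_all_gather(handle, local_chunk, gathered);  // gathered = world_size x chunk
//   shm_send_tensor_list(handle, {t0, t1}, /*dst=*/1);
//   auto tensors = shm_recv_tensor_list(handle, /*src=*/0);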