[CPU][Bugfix] Using custom allreduce for CPU backend (#15934)
Signed-off-by: jiang1.li <jiang1.li@intel.com>
This commit is contained in:
parent
cefb9e5a28
commit
550b2801ad
@ -197,6 +197,7 @@ set(VLLM_EXT_SRC
|
||||
if (AVX512_FOUND AND NOT AVX512_DISABLED)
|
||||
set(VLLM_EXT_SRC
|
||||
"csrc/cpu/quant.cpp"
|
||||
"csrc/cpu/shm.cpp"
|
||||
${VLLM_EXT_SRC})
|
||||
endif()
|
||||
|
||||
|
@ -78,9 +78,14 @@ struct FP16Vec16 : public Vec<FP16Vec16> {
|
||||
|
||||
__m256i reg;
|
||||
|
||||
// normal load
|
||||
explicit FP16Vec16(const void* ptr)
|
||||
: reg((__m256i)_mm256_loadu_si256((__m256i*)ptr)) {}
|
||||
|
||||
// non-temproal load
|
||||
explicit FP16Vec16(bool, void* ptr)
|
||||
: reg(_mm256_stream_load_si256((__m256i*)ptr)) {}
|
||||
|
||||
explicit FP16Vec16(const FP32Vec16&);
|
||||
|
||||
void save(void* ptr) const { *reinterpret_cast<__m256i*>(ptr) = reg; }
|
||||
@ -110,9 +115,14 @@ struct BF16Vec16 : public Vec<BF16Vec16> {
|
||||
|
||||
__m256i reg;
|
||||
|
||||
// normal load
|
||||
explicit BF16Vec16(const void* ptr)
|
||||
: reg((__m256i)_mm256_loadu_si256((__m256i*)ptr)) {}
|
||||
|
||||
// non-temproal load
|
||||
explicit BF16Vec16(bool, void* ptr)
|
||||
: reg(_mm256_stream_load_si256((__m256i*)ptr)) {}
|
||||
|
||||
explicit BF16Vec16(const FP32Vec16&);
|
||||
|
||||
void save(void* ptr) const { *reinterpret_cast<__m256i*>(ptr) = reg; }
|
||||
@ -313,8 +323,13 @@ struct FP32Vec16 : public Vec<FP32Vec16> {
|
||||
|
||||
explicit FP32Vec16() : reg(_mm512_set1_ps(0.0)) {}
|
||||
|
||||
// normal load
|
||||
explicit FP32Vec16(const float* ptr) : reg(_mm512_loadu_ps(ptr)) {}
|
||||
|
||||
// non-temproal load
|
||||
explicit FP32Vec16(bool, void* ptr)
|
||||
: reg((__m512)_mm512_stream_load_si512(ptr)) {}
|
||||
|
||||
explicit FP32Vec16(__m512 data) : reg(data) {}
|
||||
|
||||
explicit FP32Vec16(const FP32Vec4& data)
|
||||
@ -547,6 +562,33 @@ struct INT8Vec16 : public Vec<INT8Vec16> {
|
||||
_mm_mask_storeu_epi8(ptr, mask, reg);
|
||||
}
|
||||
};
|
||||
|
||||
struct INT8Vec64 : public Vec<INT8Vec64> {
|
||||
constexpr static int VEC_ELEM_NUM = 64;
|
||||
union AliasReg {
|
||||
__m512i reg;
|
||||
int8_t values[VEC_ELEM_NUM];
|
||||
};
|
||||
|
||||
__m512i reg;
|
||||
|
||||
// normal load
|
||||
explicit INT8Vec64(void* ptr) : reg(_mm512_loadu_epi8(ptr)) {}
|
||||
|
||||
// non-temproal load
|
||||
explicit INT8Vec64(bool, void* ptr) : reg(_mm512_stream_load_si512(ptr)) {}
|
||||
|
||||
void save(void* ptr) const { _mm512_storeu_epi8(ptr, reg); }
|
||||
|
||||
void save(int8_t* ptr, const int elem_num) const {
|
||||
constexpr uint64_t M = 0xFFFFFFFFFFFFFFFF;
|
||||
__mmask64 mask = _cvtu64_mask64(M >> (64 - elem_num));
|
||||
_mm512_mask_storeu_epi8(ptr, mask, reg);
|
||||
}
|
||||
|
||||
// non-temproal save
|
||||
void nt_save(int8_t* ptr) { _mm512_stream_si512((__m512i*)ptr, reg); }
|
||||
};
|
||||
#endif
|
||||
|
||||
template <typename T>
|
||||
@ -657,6 +699,22 @@ inline BF16Vec16::BF16Vec16(const FP32Vec16& v) {
|
||||
|
||||
inline void prefetch(const void* addr) { _mm_prefetch(addr, _MM_HINT_T1); }
|
||||
|
||||
#ifdef __AVX512F__
|
||||
inline void non_temporal_save(FP16Vec16& vec, void* ptr) {
|
||||
_mm256_stream_si256((__m256i*)ptr, vec.reg);
|
||||
}
|
||||
inline void non_temporal_save(BF16Vec32& vec, void* ptr) {
|
||||
_mm512_stream_si512((__m512i*)ptr, vec.reg);
|
||||
}
|
||||
inline void non_temporal_save(BF16Vec16& vec, void* ptr) {
|
||||
_mm256_stream_si256((__m256i*)ptr, vec.reg);
|
||||
}
|
||||
inline void non_temporal_save(FP32Vec16& vec, void* ptr) {
|
||||
_mm512_stream_ps((float*)ptr, vec.reg);
|
||||
}
|
||||
#endif
|
||||
|
||||
inline void mem_barrier() { _mm_mfence(); }
|
||||
}; // namespace vec_op
|
||||
|
||||
#endif
|
||||
|
781
csrc/cpu/shm.cpp
Normal file
781
csrc/cpu/shm.cpp
Normal file
@ -0,0 +1,781 @@
|
||||
#include "cpu/cpu_types.hpp"
|
||||
|
||||
#include <fcntl.h>
|
||||
#include <sys/mman.h>
|
||||
#include <sys/stat.h>
|
||||
#include <unistd.h>
|
||||
|
||||
namespace {
|
||||
#define MAX_SHM_RANK_NUM 8
|
||||
#define MAX_THREAD_NUM 12
|
||||
#define PER_THREAD_SHM_BUFFER_BYTES (4 * 1024 * 1024)
|
||||
#define MIN_THREAD_PROCESS_SIZE (8 * 1024)
|
||||
#define MAX_P2P_SEND_TENSOR_NUM 8
|
||||
|
||||
template <typename scalar_t>
|
||||
struct KernelVecType {
|
||||
using scalar_vec_t = void;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct KernelVecType<float> {
|
||||
using scalar_vec_t = vec_op::FP32Vec16;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct KernelVecType<c10::BFloat16> {
|
||||
using scalar_vec_t = vec_op::BF16Vec16;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct KernelVecType<c10::Half> {
|
||||
using scalar_vec_t = vec_op::FP16Vec16;
|
||||
};
|
||||
|
||||
enum class ThreadSHMStat : char { THREAD_READY = 0, SHM_DATA_READY, DONE };
|
||||
|
||||
struct ThreadSHMContext {
|
||||
volatile ThreadSHMStat thread_stats[MAX_SHM_RANK_NUM];
|
||||
int thread_id;
|
||||
int thread_num;
|
||||
int rank;
|
||||
int group_size;
|
||||
size_t _spinning_count;
|
||||
int swizzled_ranks[MAX_SHM_RANK_NUM];
|
||||
void* thread_shm_ptrs[MAX_SHM_RANK_NUM];
|
||||
ThreadSHMContext* shm_contexts[MAX_SHM_RANK_NUM];
|
||||
|
||||
ThreadSHMContext(const int thread_id, const int thread_num, const int rank,
|
||||
const int group_size, void* thread_shm_ptr)
|
||||
: thread_id(thread_id),
|
||||
thread_num(thread_num),
|
||||
rank(rank),
|
||||
group_size(group_size),
|
||||
_spinning_count(0) {
|
||||
static_assert(sizeof(ThreadSHMContext) % 64 == 0);
|
||||
TORCH_CHECK(group_size <= MAX_SHM_RANK_NUM);
|
||||
TORCH_CHECK((size_t)this % 64 == 0);
|
||||
TORCH_CHECK((size_t)thread_shm_ptr % 64 == 0);
|
||||
for (int i = 0; i < MAX_SHM_RANK_NUM; ++i) {
|
||||
shm_contexts[i] = nullptr;
|
||||
thread_shm_ptrs[i] = nullptr;
|
||||
swizzled_ranks[i] = (i + rank) % group_size;
|
||||
thread_stats[i] = ThreadSHMStat::DONE;
|
||||
}
|
||||
set_context(rank, this, thread_shm_ptr);
|
||||
}
|
||||
|
||||
void set_context(int rank, ThreadSHMContext* ptr, void* thread_shm_ptr) {
|
||||
TORCH_CHECK(rank < MAX_SHM_RANK_NUM);
|
||||
TORCH_CHECK(ptr);
|
||||
TORCH_CHECK(thread_shm_ptr);
|
||||
TORCH_CHECK_EQ(ptr->thread_num, thread_num);
|
||||
TORCH_CHECK_EQ(ptr->thread_id, thread_id);
|
||||
shm_contexts[rank] = ptr;
|
||||
thread_shm_ptrs[rank] = thread_shm_ptr;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
T* get_thread_shm_ptr(int rank) {
|
||||
return reinterpret_cast<T*>(thread_shm_ptrs[rank]);
|
||||
}
|
||||
|
||||
int get_swizzled_rank(int idx) { return swizzled_ranks[idx]; }
|
||||
|
||||
void wait_for_all(ThreadSHMStat prev_stat) {
|
||||
for (int idx = 0; idx < group_size; ++idx) {
|
||||
int rank = get_swizzled_rank(idx);
|
||||
while (thread_stats[rank] == prev_stat) {
|
||||
++_spinning_count;
|
||||
_mm_pause();
|
||||
}
|
||||
}
|
||||
vec_op::mem_barrier();
|
||||
}
|
||||
|
||||
void wait_for_one(int rank, ThreadSHMStat prev_stat) {
|
||||
while (thread_stats[rank] == prev_stat) {
|
||||
++_spinning_count;
|
||||
_mm_pause();
|
||||
}
|
||||
vec_op::mem_barrier();
|
||||
}
|
||||
|
||||
void set_thread_stat(ThreadSHMStat stat) {
|
||||
for (int idx = 0; idx < group_size; ++idx) {
|
||||
int rank = get_swizzled_rank(idx);
|
||||
shm_contexts[rank]->thread_stats[this->rank] = stat;
|
||||
}
|
||||
}
|
||||
|
||||
void set_thread_stat(int target_rank, ThreadSHMStat stat) {
|
||||
for (int idx = 0; idx < group_size; ++idx) {
|
||||
int rank = get_swizzled_rank(idx);
|
||||
shm_contexts[rank]->thread_stats[target_rank] = stat;
|
||||
}
|
||||
}
|
||||
|
||||
// barrier for all ranks in the group, used for all2all ops
|
||||
// DONE -> THREAD_READY -> SHM_DATA_READY -> DONE -> ...
|
||||
void barrier(ThreadSHMStat next_stat) {
|
||||
if (next_stat == ThreadSHMStat::THREAD_READY) {
|
||||
set_thread_stat(ThreadSHMStat::THREAD_READY);
|
||||
wait_for_all(ThreadSHMStat::DONE);
|
||||
} else if (next_stat == ThreadSHMStat::SHM_DATA_READY) {
|
||||
set_thread_stat(ThreadSHMStat::SHM_DATA_READY);
|
||||
wait_for_all(ThreadSHMStat::THREAD_READY);
|
||||
} else if (next_stat == ThreadSHMStat::DONE) {
|
||||
set_thread_stat(ThreadSHMStat::DONE);
|
||||
wait_for_all(ThreadSHMStat::SHM_DATA_READY);
|
||||
} else {
|
||||
TORCH_CHECK(false, "Invalid next_stat to barrier.");
|
||||
}
|
||||
}
|
||||
|
||||
std::string to_string() const {
|
||||
std::stringstream ss;
|
||||
ss << "SHMContext:";
|
||||
ss << "\nrank: " << rank;
|
||||
ss << "\ngroup_size: " << group_size;
|
||||
ss << "\nthread_num: " << thread_num;
|
||||
ss << "\nthread_id: " << thread_id;
|
||||
|
||||
ss << "\nshm_ctx_stat_loop_seq: [";
|
||||
for (int i = 0; i < group_size; ++i) {
|
||||
ss << swizzled_ranks[i] << ", ";
|
||||
}
|
||||
ss << "]";
|
||||
|
||||
ss << "\nshm_contexts: [";
|
||||
for (int i = 0; i < group_size; ++i) {
|
||||
if (shm_contexts[i]) {
|
||||
ss << shm_contexts[i]->rank << ", ";
|
||||
}
|
||||
}
|
||||
ss << "]";
|
||||
|
||||
return ss.str();
|
||||
}
|
||||
};
|
||||
|
||||
class SHMManager {
|
||||
public:
|
||||
explicit SHMManager(const std::string& name, const int rank,
|
||||
const int group_size)
|
||||
: _rank(rank),
|
||||
_group_size(group_size),
|
||||
_thread_num(std::min(torch::get_num_threads(), MAX_THREAD_NUM)),
|
||||
_shm_names({""}),
|
||||
_shared_mem_ptrs({nullptr}),
|
||||
_shm_ctx(nullptr) {
|
||||
_shm_names[rank] = get_shm_name(name, rank);
|
||||
_shared_mem_ptrs[rank] = init_shm(rank);
|
||||
_shm_ctx = reinterpret_cast<ThreadSHMContext*>(_shared_mem_ptrs[rank]);
|
||||
|
||||
for (int i = 0; i < _thread_num; ++i) {
|
||||
ThreadSHMContext* ctx = new (_shm_ctx + i)
|
||||
ThreadSHMContext(i, _thread_num, _rank, _group_size,
|
||||
compute_thread_shm_ptr(_shm_ctx, i));
|
||||
}
|
||||
}
|
||||
|
||||
void join(const std::string& name) {
|
||||
for (int rank_idx = 0; rank_idx < _group_size; ++rank_idx) {
|
||||
if (rank_idx != _rank) {
|
||||
TORCH_CHECK(_shm_names[rank_idx].empty());
|
||||
TORCH_CHECK(_shared_mem_ptrs[rank_idx] == nullptr);
|
||||
_shm_names[rank_idx] = get_shm_name(name, rank_idx);
|
||||
_shared_mem_ptrs[rank_idx] = init_shm(rank_idx);
|
||||
ThreadSHMContext* target_ctx =
|
||||
reinterpret_cast<ThreadSHMContext*>(_shared_mem_ptrs[rank_idx]);
|
||||
for (int thread_idx = 0; thread_idx < _thread_num; ++thread_idx) {
|
||||
_shm_ctx[thread_idx].set_context(
|
||||
rank_idx, target_ctx + thread_idx,
|
||||
compute_thread_shm_ptr(target_ctx, thread_idx));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
~SHMManager() { destroy_shm(); }
|
||||
|
||||
ThreadSHMContext* get_shm_ctx() const { return _shm_ctx; }
|
||||
|
||||
static std::string get_shm_name(const std::string& name, int rank) {
|
||||
return name + "_" + std::to_string(rank);
|
||||
}
|
||||
|
||||
static int64_t create_singleton_instance(const std::string& name,
|
||||
const int group_size,
|
||||
const int rank) {
|
||||
std::lock_guard<std::mutex> guard(SingletonInstancesLock);
|
||||
SingletonInstances.emplace_back(
|
||||
std::make_unique<SHMManager>(name, rank, group_size));
|
||||
return static_cast<int64_t>(SingletonInstances.size() - 1);
|
||||
}
|
||||
|
||||
static SHMManager* get_singleton_instance(int64_t handle) {
|
||||
return SingletonInstances[handle].get();
|
||||
}
|
||||
|
||||
protected:
|
||||
static std::vector<std::unique_ptr<SHMManager>> SingletonInstances;
|
||||
static std::mutex SingletonInstancesLock;
|
||||
|
||||
private:
|
||||
static size_t round_to_alignment(size_t num) {
|
||||
return ((num + 63) / 64) * 64;
|
||||
}
|
||||
|
||||
int8_t* compute_thread_shm_ptr(ThreadSHMContext* ctx, int thread_id) {
|
||||
int8_t* thread_shm_ptr =
|
||||
reinterpret_cast<int8_t*>(ctx) +
|
||||
round_to_alignment(_thread_num * sizeof(ThreadSHMContext));
|
||||
return thread_shm_ptr +
|
||||
thread_id * round_to_alignment(PER_THREAD_SHM_BUFFER_BYTES);
|
||||
}
|
||||
|
||||
size_t compute_shm_size() {
|
||||
const size_t rounded_rank_buffer_size =
|
||||
round_to_alignment(PER_THREAD_SHM_BUFFER_BYTES) * _thread_num;
|
||||
const size_t rounded_thread_shm_ctx_size =
|
||||
round_to_alignment(_thread_num * sizeof(ThreadSHMContext));
|
||||
const size_t shm_size =
|
||||
rounded_thread_shm_ctx_size + rounded_rank_buffer_size;
|
||||
return shm_size;
|
||||
}
|
||||
|
||||
void* init_shm(int target_rank) {
|
||||
const std::string& shm_name = _shm_names[target_rank];
|
||||
const int local_rank = _rank;
|
||||
const size_t shm_size = compute_shm_size();
|
||||
|
||||
int fd = -1;
|
||||
if (local_rank == target_rank) {
|
||||
fd = shm_open(shm_name.c_str(), O_CREAT | O_EXCL | O_RDWR,
|
||||
S_IRUSR | S_IWUSR);
|
||||
|
||||
if (fd == -1)
|
||||
TORCH_CHECK(false, "create shm in SHMManager failed. errno: " +
|
||||
std::to_string(errno));
|
||||
|
||||
if (ftruncate(fd, shm_size) == -1)
|
||||
TORCH_CHECK(false, "ftruncate in SHMManager failed. errno: " +
|
||||
std::to_string(errno));
|
||||
} else {
|
||||
fd = shm_open(shm_name.c_str(), O_RDWR, S_IRUSR | S_IWUSR);
|
||||
|
||||
if (fd == -1)
|
||||
TORCH_CHECK(false, "open shm in SHMManager failed. errno: " +
|
||||
std::to_string(errno));
|
||||
}
|
||||
|
||||
void* shm_ptr = mmap(nullptr, shm_size, PROT_READ | PROT_WRITE,
|
||||
MAP_SHARED | MAP_POPULATE, fd, 0);
|
||||
|
||||
if (shm_ptr == MAP_FAILED) {
|
||||
TORCH_CHECK(false,
|
||||
"mmap in SHMManager failed. errno: " + std::to_string(errno));
|
||||
}
|
||||
|
||||
if (close(fd) != 0) {
|
||||
TORCH_CHECK(
|
||||
false, "close in SHMManager failed. errno: " + std::to_string(errno));
|
||||
}
|
||||
|
||||
TORCH_CHECK((size_t)shm_ptr % 64 == 0);
|
||||
|
||||
return shm_ptr;
|
||||
}
|
||||
|
||||
void destroy_shm() {
|
||||
std::stringstream ss;
|
||||
ss << "local rank " << _rank << ": [";
|
||||
for (int thread_id = 0; thread_id < _thread_num; ++thread_id) {
|
||||
ss << _shm_ctx[thread_id]._spinning_count << ", ";
|
||||
}
|
||||
ss << "]\n";
|
||||
|
||||
for (int i = 0; i < MAX_SHM_RANK_NUM; ++i) {
|
||||
if (_shared_mem_ptrs[i] != nullptr) {
|
||||
munmap(_shared_mem_ptrs[i], compute_shm_size());
|
||||
}
|
||||
|
||||
if (!_shm_names[i].empty()) {
|
||||
shm_unlink(_shm_names[i].c_str());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
int _rank;
|
||||
int _group_size;
|
||||
int _thread_num;
|
||||
std::array<std::string, MAX_SHM_RANK_NUM> _shm_names;
|
||||
std::array<void*, MAX_SHM_RANK_NUM> _shared_mem_ptrs;
|
||||
ThreadSHMContext* _shm_ctx;
|
||||
};
|
||||
|
||||
namespace shm_cc_ops {
|
||||
template <typename scalar_t, typename F>
|
||||
void shm_cc_loop(ThreadSHMContext* ctx, int64_t elem_num, F&& inner_func) {
|
||||
int thread_num = ctx->thread_num;
|
||||
int64_t total_bytes = elem_num * sizeof(scalar_t);
|
||||
int64_t total_units_num =
|
||||
(total_bytes + MIN_THREAD_PROCESS_SIZE - 1) / MIN_THREAD_PROCESS_SIZE;
|
||||
int64_t per_thread_units_num =
|
||||
(total_units_num + thread_num - 1) / thread_num;
|
||||
int64_t per_unit_elem_num = MIN_THREAD_PROCESS_SIZE / sizeof(scalar_t);
|
||||
int64_t max_per_thread_iteration_elem_num =
|
||||
PER_THREAD_SHM_BUFFER_BYTES / sizeof(scalar_t);
|
||||
int64_t per_thread_elem_num = per_unit_elem_num * per_thread_units_num;
|
||||
|
||||
#pragma omp parallel for schedule(static, 1)
|
||||
for (int i = 0; i < thread_num; ++i) {
|
||||
int64_t offset = i * per_thread_elem_num;
|
||||
int64_t end = std::min(elem_num, offset + per_thread_elem_num);
|
||||
int64_t curr_elem_num =
|
||||
std::min(max_per_thread_iteration_elem_num, end - offset);
|
||||
ThreadSHMContext* thread_ctx = ctx + i;
|
||||
|
||||
while (curr_elem_num > 0) {
|
||||
inner_func(thread_ctx, offset, curr_elem_num);
|
||||
|
||||
offset += max_per_thread_iteration_elem_num;
|
||||
curr_elem_num = std::min(max_per_thread_iteration_elem_num, end - offset);
|
||||
}
|
||||
}
|
||||
}
|
||||
}; // namespace shm_cc_ops
|
||||
|
||||
namespace shm_cc_ops {
|
||||
|
||||
void memcpy_from_shm(void* dst, void* src, const int64_t bytes) {
|
||||
const int64_t aligned_bytes = ((bytes >> 6) << 6); // 64 bytes aligned
|
||||
int64_t i = 0;
|
||||
#pragma GCC unroll 4
|
||||
for (; i < aligned_bytes; i += 64) {
|
||||
vec_op::INT8Vec64 data(
|
||||
true, (int8_t*)src + i); // stream loading shm to avoid caching
|
||||
data.save((int8_t*)dst + i);
|
||||
}
|
||||
if (aligned_bytes < bytes) {
|
||||
vec_op::INT8Vec64 data(true, (int8_t*)src + aligned_bytes);
|
||||
data.save((int8_t*)dst + aligned_bytes, bytes - aligned_bytes);
|
||||
}
|
||||
}
|
||||
|
||||
void memcpy_to_shm(void* dst, void* src, const int64_t bytes) {
|
||||
#pragma GCC unroll 4
|
||||
for (int64_t i = 0; i < bytes; i += 64) {
|
||||
vec_op::INT8Vec64 data((int8_t*)src + i);
|
||||
data.nt_save((int8_t*)dst + i);
|
||||
}
|
||||
}
|
||||
|
||||
void memcpy(void* dst, void* src, const int64_t bytes) {
|
||||
const int64_t aligned_bytes = ((bytes >> 6) << 6); // 64 bytes aligned
|
||||
int64_t i = 0;
|
||||
#pragma GCC unroll 4
|
||||
for (; i < aligned_bytes; i += 64) {
|
||||
vec_op::INT8Vec64 data((int8_t*)src + i);
|
||||
data.save((int8_t*)dst + i);
|
||||
}
|
||||
if (aligned_bytes < bytes) {
|
||||
vec_op::INT8Vec64 data((int8_t*)src + aligned_bytes);
|
||||
data.save((int8_t*)dst + aligned_bytes, bytes - aligned_bytes);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename scalar_t, int RANKS>
|
||||
void all_reduce_sum_impl(ThreadSHMContext* ctx, scalar_t* data,
|
||||
size_t elem_num) {
|
||||
CPU_KERNEL_GUARD_IN(all_reduce_sum_impl)
|
||||
using vec_t = typename KernelVecType<scalar_t>::scalar_vec_t;
|
||||
constexpr int64_t vec_elem_num = vec_t::get_elem_num();
|
||||
const int worldsize = ctx->group_size;
|
||||
|
||||
shm_cc_ops::shm_cc_loop<scalar_t>(
|
||||
ctx, elem_num,
|
||||
[&](ThreadSHMContext* thread_ctx, int64_t data_offset,
|
||||
int64_t data_elem_num) {
|
||||
int rank = thread_ctx->rank;
|
||||
scalar_t* thread_shm_ptr =
|
||||
thread_ctx->get_thread_shm_ptr<scalar_t>(rank);
|
||||
scalar_t* thread_data_ptr = data + data_offset;
|
||||
int64_t thread_data_elem_num = data_elem_num * sizeof(scalar_t);
|
||||
|
||||
scalar_t* remote_data_ptrs[RANKS - 1];
|
||||
vec_op::unroll_loop<int, RANKS - 1>([&](int idx) {
|
||||
remote_data_ptrs[idx] = thread_ctx->get_thread_shm_ptr<scalar_t>(
|
||||
thread_ctx->get_swizzled_rank(idx + 1));
|
||||
});
|
||||
|
||||
thread_ctx->barrier(ThreadSHMStat::THREAD_READY);
|
||||
|
||||
shm_cc_ops::memcpy_to_shm(thread_shm_ptr, thread_data_ptr,
|
||||
thread_data_elem_num);
|
||||
|
||||
thread_ctx->barrier(ThreadSHMStat::SHM_DATA_READY);
|
||||
|
||||
int64_t aligned_data_elem_num =
|
||||
(data_elem_num / vec_elem_num) * vec_elem_num;
|
||||
int64_t i = 0;
|
||||
#pragma GCC unroll 4
|
||||
for (; i < aligned_data_elem_num; i += vec_elem_num) {
|
||||
vec_t local_data(thread_data_ptr + i); // load from cache
|
||||
vec_op::FP32Vec16 local_data_fp32(local_data);
|
||||
vec_op::unroll_loop<int, RANKS - 1>([&](int idx) {
|
||||
vec_t remote_data(
|
||||
true, remote_data_ptrs[idx] + i); // stream load from shm
|
||||
vec_op::FP32Vec16 remote_data_fp32(remote_data);
|
||||
local_data_fp32 = local_data_fp32 + remote_data_fp32; // sum reduce
|
||||
});
|
||||
vec_t reduced_data(local_data_fp32);
|
||||
reduced_data.save(thread_data_ptr + i);
|
||||
}
|
||||
|
||||
if (i < data_elem_num) {
|
||||
vec_t local_data(thread_data_ptr + i); // load from cache
|
||||
vec_op::FP32Vec16 local_data_fp32(local_data);
|
||||
vec_op::unroll_loop<int, RANKS - 1>([&](int idx) {
|
||||
vec_t remote_data(
|
||||
true, remote_data_ptrs[idx] + i); // stream load from shm
|
||||
vec_op::FP32Vec16 remote_data_fp32(remote_data);
|
||||
local_data_fp32 = local_data_fp32 + remote_data_fp32; // sum reduce
|
||||
});
|
||||
vec_t reduced_data(local_data_fp32);
|
||||
reduced_data.save(thread_data_ptr + i,
|
||||
data_elem_num - aligned_data_elem_num);
|
||||
}
|
||||
|
||||
thread_ctx->barrier(ThreadSHMStat::DONE);
|
||||
});
|
||||
|
||||
return;
|
||||
}
|
||||
}; // namespace shm_cc_ops
|
||||
|
||||
std::vector<std::unique_ptr<SHMManager>> SHMManager::SingletonInstances = {};
|
||||
std::mutex SHMManager::SingletonInstancesLock = {};
|
||||
|
||||
template <typename scalar_t>
|
||||
void shm_allreduce_sum(ThreadSHMContext* ctx, scalar_t* data, size_t elem_num) {
|
||||
switch (ctx->group_size) {
|
||||
case 2:
|
||||
shm_cc_ops::all_reduce_sum_impl<scalar_t, 2>(ctx, data, elem_num);
|
||||
break;
|
||||
case 3:
|
||||
shm_cc_ops::all_reduce_sum_impl<scalar_t, 3>(ctx, data, elem_num);
|
||||
break;
|
||||
case 4:
|
||||
shm_cc_ops::all_reduce_sum_impl<scalar_t, 4>(ctx, data, elem_num);
|
||||
break;
|
||||
case 8:
|
||||
shm_cc_ops::all_reduce_sum_impl<scalar_t, 8>(ctx, data, elem_num);
|
||||
break;
|
||||
default:
|
||||
TORCH_CHECK(false,
|
||||
"Invalid world size: " + std::to_string(ctx->group_size));
|
||||
}
|
||||
}
|
||||
|
||||
template <typename scalar_t>
|
||||
void shm_gather_impl(ThreadSHMContext* ctx, scalar_t* data, size_t elem_num,
|
||||
scalar_t** outputs, const int dst) {
|
||||
CPU_KERNEL_GUARD_IN(shm_gather_impl)
|
||||
const int worldsize = ctx->group_size;
|
||||
TORCH_CHECK_LT(dst, worldsize);
|
||||
shm_cc_ops::shm_cc_loop<scalar_t>(
|
||||
ctx, elem_num,
|
||||
[&](ThreadSHMContext* thread_ctx, int64_t data_offset,
|
||||
int64_t data_elem_num) {
|
||||
int rank = thread_ctx->rank;
|
||||
scalar_t* thread_shm_ptr =
|
||||
thread_ctx->get_thread_shm_ptr<scalar_t>(rank);
|
||||
|
||||
thread_ctx->barrier(ThreadSHMStat::THREAD_READY);
|
||||
|
||||
shm_cc_ops::memcpy_to_shm(thread_shm_ptr, data + data_offset,
|
||||
data_elem_num * sizeof(scalar_t));
|
||||
|
||||
thread_ctx->barrier(ThreadSHMStat::SHM_DATA_READY);
|
||||
|
||||
if (rank == dst) {
|
||||
shm_cc_ops::memcpy(outputs[rank] + data_offset, data + data_offset,
|
||||
data_elem_num * sizeof(scalar_t));
|
||||
for (int i = 1; i < worldsize; ++i) {
|
||||
int src_rank = thread_ctx->get_swizzled_rank(i);
|
||||
scalar_t* src_ptr =
|
||||
thread_ctx->get_thread_shm_ptr<scalar_t>(src_rank); // shm
|
||||
scalar_t* dst_ptr = outputs[src_rank] + data_offset;
|
||||
shm_cc_ops::memcpy_from_shm(dst_ptr, src_ptr,
|
||||
data_elem_num * sizeof(scalar_t));
|
||||
}
|
||||
}
|
||||
|
||||
thread_ctx->barrier(ThreadSHMStat::DONE);
|
||||
});
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
struct MemPiece {
|
||||
void* ptr;
|
||||
int64_t size;
|
||||
|
||||
template <typename T>
|
||||
T* data_ptr() {
|
||||
return reinterpret_cast<T*>(ptr);
|
||||
}
|
||||
};
|
||||
|
||||
struct TensorListMeta {
|
||||
int64_t tensor_bytes[MAX_P2P_SEND_TENSOR_NUM];
|
||||
torch::ScalarType tensor_types[MAX_P2P_SEND_TENSOR_NUM];
|
||||
int64_t tensor_num;
|
||||
int64_t total_bytes;
|
||||
|
||||
TensorListMeta() : tensor_num(0), total_bytes(0) {
|
||||
static_assert(sizeof(TensorListMeta) % 64 == 0);
|
||||
static_assert(sizeof(TensorListMeta) <
|
||||
MIN_THREAD_PROCESS_SIZE); // To ensure the metadata always
|
||||
// hold by the thread 0
|
||||
for (int i = 0; i < MAX_P2P_SEND_TENSOR_NUM; ++i) {
|
||||
tensor_bytes[i] = 0;
|
||||
tensor_ptrs[i] = nullptr;
|
||||
tensor_types[i] = torch::ScalarType::Undefined;
|
||||
}
|
||||
}
|
||||
|
||||
// For send and recv
|
||||
void bind_tensor_list(std::vector<torch::Tensor>& tensor_list) {
|
||||
TORCH_CHECK(tensor_types[0] == torch::ScalarType::Undefined,
|
||||
"Re-bind TensorListMeta is not allowed.")
|
||||
TORCH_CHECK_LE(tensor_list.size(), MAX_P2P_SEND_TENSOR_NUM);
|
||||
tensor_num = tensor_list.size();
|
||||
int64_t bytes_sum = 0;
|
||||
for (int i = 0; i < tensor_list.size(); ++i) {
|
||||
torch::Tensor& t = tensor_list[i];
|
||||
TORCH_CHECK(t.is_contiguous());
|
||||
tensor_bytes[i] = t.nbytes();
|
||||
tensor_types[i] = t.scalar_type();
|
||||
tensor_ptrs[i] = t.data_ptr();
|
||||
bytes_sum += t.nbytes();
|
||||
}
|
||||
total_bytes = bytes_sum;
|
||||
}
|
||||
|
||||
// For recv
|
||||
std::vector<torch::Tensor> generate_tensor_list() {
|
||||
std::vector<torch::Tensor> tensor_list;
|
||||
tensor_list.reserve(tensor_num);
|
||||
|
||||
for (int i = 0; i < tensor_num; ++i) {
|
||||
int64_t bytes = tensor_bytes[i];
|
||||
auto type = tensor_types[i];
|
||||
int64_t elem_bytes = torch::elementSize(type);
|
||||
|
||||
TORCH_CHECK_EQ(bytes % elem_bytes, 0);
|
||||
int64_t elem_num = bytes / elem_bytes;
|
||||
auto options = torch::TensorOptions().dtype(type).device(torch::kCPU);
|
||||
tensor_list.emplace_back(torch::empty({elem_num}, options));
|
||||
}
|
||||
return tensor_list;
|
||||
}
|
||||
|
||||
MemPiece get_data(int64_t offset) {
|
||||
for (int i = 0; i < tensor_num; ++i) {
|
||||
if (offset < tensor_bytes[i]) {
|
||||
return {reinterpret_cast<int8_t*>(tensor_ptrs[i]) + offset,
|
||||
tensor_bytes[i] - offset};
|
||||
}
|
||||
offset -= tensor_bytes[i];
|
||||
}
|
||||
return {nullptr, 0};
|
||||
}
|
||||
|
||||
private:
|
||||
void* tensor_ptrs[MAX_P2P_SEND_TENSOR_NUM];
|
||||
int8_t _padding[40];
|
||||
};
|
||||
|
||||
void shm_send_tensor_list_impl(ThreadSHMContext* ctx,
|
||||
const std::vector<torch::Tensor>& tensor_list) {
|
||||
CPU_KERNEL_GUARD_IN(shm_send_tensor_list_impl)
|
||||
std::vector<torch::Tensor> tensor_list_with_metadata;
|
||||
tensor_list_with_metadata.reserve(1 + tensor_list.size());
|
||||
|
||||
auto options = torch::TensorOptions().dtype(torch::kInt8).device(torch::kCPU);
|
||||
tensor_list_with_metadata.emplace_back(
|
||||
torch::empty({sizeof(TensorListMeta)}, options));
|
||||
tensor_list_with_metadata.insert(tensor_list_with_metadata.end(),
|
||||
tensor_list.begin(), tensor_list.end());
|
||||
|
||||
torch::Tensor& metadata_tensor = tensor_list_with_metadata[0];
|
||||
TORCH_CHECK_EQ(metadata_tensor.nbytes(), sizeof(TensorListMeta));
|
||||
|
||||
TensorListMeta* metadata = new (metadata_tensor.data_ptr()) TensorListMeta();
|
||||
metadata->bind_tensor_list(tensor_list_with_metadata);
|
||||
|
||||
shm_cc_ops::shm_cc_loop<int8_t>(
|
||||
ctx, metadata->total_bytes,
|
||||
[&](ThreadSHMContext* thread_ctx, int64_t data_offset,
|
||||
int64_t data_elem_num) {
|
||||
int rank = thread_ctx->rank;
|
||||
// Wait until the receiver set the stat to DONE
|
||||
thread_ctx->wait_for_one(rank, ThreadSHMStat::SHM_DATA_READY);
|
||||
|
||||
int64_t curr_shm_offset = 0;
|
||||
while (curr_shm_offset < data_elem_num) {
|
||||
MemPiece frag = metadata->get_data(data_offset + curr_shm_offset);
|
||||
frag.size = std::min(frag.size, data_elem_num - curr_shm_offset);
|
||||
shm_cc_ops::memcpy(
|
||||
thread_ctx->get_thread_shm_ptr<int8_t>(rank) + curr_shm_offset,
|
||||
frag.ptr, frag.size);
|
||||
curr_shm_offset += frag.size;
|
||||
}
|
||||
|
||||
thread_ctx->set_thread_stat(rank, ThreadSHMStat::SHM_DATA_READY);
|
||||
});
|
||||
}
|
||||
|
||||
std::vector<torch::Tensor> shm_recv_tensor_list_impl(ThreadSHMContext* ctx,
|
||||
int64_t src) {
|
||||
CPU_KERNEL_GUARD_IN(shm_recv_tensor_list_impl)
|
||||
auto options = torch::TensorOptions().dtype(torch::kInt8).device(torch::kCPU);
|
||||
torch::Tensor metadata_tensor =
|
||||
torch::empty({sizeof(TensorListMeta)}, options);
|
||||
|
||||
// Wait until the sender set the stat of the thread 0 to SHM_DATA_READY
|
||||
ctx->wait_for_one(src, ThreadSHMStat::DONE);
|
||||
shm_cc_ops::memcpy(metadata_tensor.data_ptr(),
|
||||
ctx->get_thread_shm_ptr<void>(src),
|
||||
sizeof(TensorListMeta));
|
||||
TensorListMeta* src_metadata =
|
||||
reinterpret_cast<TensorListMeta*>(metadata_tensor.data_ptr());
|
||||
std::vector<torch::Tensor> tensor_list_with_metadata =
|
||||
src_metadata->generate_tensor_list();
|
||||
|
||||
TensorListMeta metadata;
|
||||
metadata.bind_tensor_list(tensor_list_with_metadata);
|
||||
TORCH_CHECK_EQ(metadata.tensor_num, src_metadata->tensor_num);
|
||||
TORCH_CHECK_EQ(metadata.total_bytes, src_metadata->total_bytes);
|
||||
|
||||
shm_cc_ops::shm_cc_loop<int8_t>(
|
||||
ctx, metadata.total_bytes,
|
||||
[&](ThreadSHMContext* thread_ctx, int64_t data_offset,
|
||||
int64_t data_elem_num) {
|
||||
// Wait until the sender set the stat to SHM_DATA_READY
|
||||
thread_ctx->wait_for_one(src, ThreadSHMStat::DONE);
|
||||
int64_t curr_shm_offset = 0;
|
||||
while (curr_shm_offset < data_elem_num) {
|
||||
MemPiece frag = metadata.get_data(data_offset + curr_shm_offset);
|
||||
frag.size = std::min(frag.size, data_elem_num - curr_shm_offset);
|
||||
shm_cc_ops::memcpy(
|
||||
frag.ptr,
|
||||
thread_ctx->get_thread_shm_ptr<int8_t>(src) + curr_shm_offset,
|
||||
frag.size);
|
||||
curr_shm_offset += frag.size;
|
||||
}
|
||||
|
||||
thread_ctx->set_thread_stat(src, ThreadSHMStat::DONE);
|
||||
});
|
||||
|
||||
std::vector<torch::Tensor> tensor_list;
|
||||
tensor_list.reserve(metadata.tensor_num - 1);
|
||||
tensor_list.insert(tensor_list.begin(), tensor_list_with_metadata.begin() + 1,
|
||||
tensor_list_with_metadata.end());
|
||||
|
||||
return tensor_list;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
void shm_gather(int64_t handle, torch::Tensor& data,
|
||||
const std::optional<std::vector<torch::Tensor>>& outputs,
|
||||
int64_t dst) {
|
||||
TORCH_CHECK(data.is_contiguous())
|
||||
VLLM_DISPATCH_FLOATING_TYPES(data.scalar_type(), "shm_gather_impl", [&] {
|
||||
CPU_KERNEL_GUARD_IN(shm_gather_impl)
|
||||
|
||||
if (outputs.has_value()) {
|
||||
TORCH_CHECK_LE(outputs->size(), MAX_SHM_RANK_NUM);
|
||||
scalar_t* output_ptrs[MAX_SHM_RANK_NUM] = {nullptr};
|
||||
for (int i = 0; i < outputs->size(); ++i) {
|
||||
output_ptrs[i] = outputs->at(i).data_ptr<scalar_t>();
|
||||
}
|
||||
shm_gather_impl(SHMManager::get_singleton_instance(handle)->get_shm_ctx(),
|
||||
data.data_ptr<scalar_t>(), data.numel(), output_ptrs,
|
||||
dst);
|
||||
} else {
|
||||
shm_gather_impl(SHMManager::get_singleton_instance(handle)->get_shm_ctx(),
|
||||
data.data_ptr<scalar_t>(), data.numel(), (scalar_t**)(0),
|
||||
dst);
|
||||
}
|
||||
|
||||
CPU_KERNEL_GUARD_OUT(shm_gather_impl)
|
||||
});
|
||||
}
|
||||
|
||||
void shm_all_gather(int64_t handle, const torch::Tensor& data,
|
||||
torch::Tensor& output) {
|
||||
TORCH_CHECK(data.is_contiguous())
|
||||
TORCH_CHECK(output.is_contiguous())
|
||||
|
||||
const int64_t input_elem_num = data.numel();
|
||||
const int64_t output_elem_num = output.numel();
|
||||
TORCH_CHECK_EQ(output_elem_num % input_elem_num, 0);
|
||||
const int world_size = output_elem_num / input_elem_num;
|
||||
|
||||
VLLM_DISPATCH_FLOATING_TYPES(data.scalar_type(), "shm_all_gather_impl", [&] {
|
||||
CPU_KERNEL_GUARD_IN(shm_all_gather_impl)
|
||||
auto ctx = SHMManager::get_singleton_instance(handle)->get_shm_ctx();
|
||||
TORCH_CHECK_EQ(ctx->group_size, world_size);
|
||||
|
||||
scalar_t* output_ptrs[MAX_SHM_RANK_NUM] = {nullptr};
|
||||
for (int i = 0; i < world_size; ++i) {
|
||||
output_ptrs[i] = output.data_ptr<scalar_t>() + i * input_elem_num;
|
||||
}
|
||||
shm_gather_impl(ctx, data.data_ptr<scalar_t>(), data.numel(), output_ptrs,
|
||||
ctx->rank);
|
||||
CPU_KERNEL_GUARD_OUT(shm_all_gather_impl)
|
||||
});
|
||||
}
|
||||
|
||||
void shm_allreduce(int64_t handle, torch::Tensor& data) {
|
||||
TORCH_CHECK(data.is_contiguous())
|
||||
VLLM_DISPATCH_FLOATING_TYPES(data.scalar_type(), "shm_allreduce_sum", [&] {
|
||||
CPU_KERNEL_GUARD_IN(shm_allreduce_sum)
|
||||
shm_allreduce_sum(SHMManager::get_singleton_instance(handle)->get_shm_ctx(),
|
||||
data.data_ptr<scalar_t>(), data.numel());
|
||||
CPU_KERNEL_GUARD_OUT(shm_allreduce_sum)
|
||||
});
|
||||
}
|
||||
|
||||
void shm_send_tensor_list(int64_t handle,
|
||||
const std::vector<torch::Tensor>& tensor_list,
|
||||
int64_t dst) {
|
||||
CPU_KERNEL_GUARD_IN(shm_send_tensor_list)
|
||||
shm_send_tensor_list_impl(
|
||||
SHMManager::get_singleton_instance(handle)->get_shm_ctx(), tensor_list);
|
||||
CPU_KERNEL_GUARD_OUT(shm_send_tensor_list)
|
||||
}
|
||||
|
||||
std::vector<torch::Tensor> shm_recv_tensor_list(int64_t handle, int64_t src) {
|
||||
CPU_KERNEL_GUARD_IN(shm_recv_tensor_list)
|
||||
auto tensor_list = shm_recv_tensor_list_impl(
|
||||
SHMManager::get_singleton_instance(handle)->get_shm_ctx(), src);
|
||||
CPU_KERNEL_GUARD_OUT(shm_recv_tensor_list)
|
||||
return tensor_list;
|
||||
}
|
||||
|
||||
int64_t init_shm_manager(const std::string& name, const int64_t group_size,
|
||||
const int64_t rank) {
|
||||
return SHMManager::create_singleton_instance(name, group_size, rank);
|
||||
}
|
||||
|
||||
std::string join_shm_manager(int64_t handle, const std::string& name) {
|
||||
auto shm_manager = SHMManager::get_singleton_instance(handle);
|
||||
TORCH_CHECK(shm_manager);
|
||||
shm_manager->join(name);
|
||||
return shm_manager->get_shm_ctx()->to_string();
|
||||
}
|
@ -22,6 +22,26 @@ void mla_decode_kvcache(torch::Tensor& out, torch::Tensor& query,
|
||||
torch::Tensor& kv_cache, double scale,
|
||||
torch::Tensor& block_tables, torch::Tensor& seq_lens);
|
||||
|
||||
int64_t init_shm_manager(const std::string& name, const int64_t group_size,
|
||||
const int64_t rank);
|
||||
|
||||
std::string join_shm_manager(int64_t handle, const std::string& name);
|
||||
|
||||
void shm_allreduce(int64_t handle, torch::Tensor& data);
|
||||
|
||||
void shm_gather(int64_t handle, torch::Tensor& data,
|
||||
const std::optional<std::vector<torch::Tensor>>& outputs,
|
||||
int64_t dst);
|
||||
|
||||
void shm_all_gather(int64_t handle, const torch::Tensor& data,
|
||||
torch::Tensor& output);
|
||||
|
||||
void shm_send_tensor_list(int64_t handle,
|
||||
const std::vector<torch::Tensor>& tensor_list,
|
||||
int64_t dst);
|
||||
|
||||
std::vector<torch::Tensor> shm_recv_tensor_list(int64_t handle, int64_t src);
|
||||
|
||||
TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
||||
// vLLM custom ops
|
||||
|
||||
@ -131,6 +151,29 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
||||
" Tensor? azp, Tensor? bias) -> ()");
|
||||
ops.impl("cutlass_scaled_mm_azp", torch::kCPU, &int8_scaled_mm_azp);
|
||||
#endif
|
||||
|
||||
// SHM CCL
|
||||
#ifdef __AVX512F__
|
||||
ops.def("init_shm_manager(str name, int group_size, int rank) -> int",
|
||||
&init_shm_manager);
|
||||
ops.def("join_shm_manager(int handle, str name) -> str", &join_shm_manager);
|
||||
ops.def("shm_allreduce(int handle, Tensor! data) -> ()");
|
||||
ops.impl("shm_allreduce", torch::kCPU, &shm_allreduce);
|
||||
ops.def(
|
||||
"shm_gather(int handle, Tensor data, Tensor[](a!)? outputs, int dst) -> "
|
||||
"()");
|
||||
ops.impl("shm_gather", torch::kCPU, &shm_gather);
|
||||
ops.def(
|
||||
"shm_all_gather(int handle, Tensor data, Tensor! output) -> "
|
||||
"()");
|
||||
ops.impl("shm_all_gather", torch::kCPU, &shm_all_gather);
|
||||
ops.def(
|
||||
"shm_send_tensor_list(int handle, Tensor[](a) tensor_list, int dst) -> "
|
||||
"()");
|
||||
ops.impl("shm_send_tensor_list", torch::kCPU, &shm_send_tensor_list);
|
||||
ops.def("shm_recv_tensor_list(int handle, int src) -> Tensor[](a)",
|
||||
&shm_recv_tensor_list);
|
||||
#endif
|
||||
}
|
||||
|
||||
TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) {
|
||||
|
@ -18,7 +18,7 @@ std::string init_cpu_threads_env(const std::string& cpu_ids) {
|
||||
|
||||
#ifndef VLLM_NUMA_DISABLED
|
||||
std::string init_cpu_threads_env(const std::string& cpu_ids) {
|
||||
bitmask* omp_cpu_mask = numa_parse_cpustring(cpu_ids.c_str());
|
||||
bitmask* omp_cpu_mask = numa_parse_cpustring_all(cpu_ids.c_str());
|
||||
TORCH_CHECK(omp_cpu_mask->size > 0);
|
||||
std::vector<int> omp_cpu_ids;
|
||||
omp_cpu_ids.reserve(omp_cpu_mask->size);
|
||||
|
@ -272,12 +272,14 @@ $ python examples/offline_inference/basic/basic.py
|
||||
|
||||
- Decouple the HTTP serving components from the inference components. In a GPU backend configuration, the HTTP serving and tokenization tasks operate on the CPU, while inference runs on the GPU, which typically does not pose a problem. However, in a CPU-based setup, the HTTP serving and tokenization can cause significant context switching and reduced cache efficiency. Therefore, it is strongly recommended to segregate these two components for improved performance.
|
||||
|
||||
- On CPU based setup with NUMA enabled, the memory access performance may be largely impacted by the [topology](https://github.com/intel/intel-extension-for-pytorch/blob/main/docs/tutorials/performance_tuning/tuning_guide.inc.md#non-uniform-memory-access-numa). For NUMA architecture, two optimizations are to recommended: Tensor Parallel or Data Parallel.
|
||||
- On CPU based setup with NUMA enabled, the memory access performance may be largely impacted by the [topology](https://github.com/intel/intel-extension-for-pytorch/blob/main/docs/tutorials/performance_tuning/tuning_guide.inc.md#non-uniform-memory-access-numa). For NUMA architecture, Tensor Parallel is a option for better performance.
|
||||
|
||||
- Using Tensor Parallel for a latency constraints deployment: following GPU backend design, a Megatron-LM's parallel algorithm will be used to shard the model, based on the number of NUMA nodes (e.g. TP = 2 for a two NUMA node system). With [TP feature on CPU](gh-pr:6125) merged, Tensor Parallel is supported for serving and offline inferencing. In general each NUMA node is treated as one GPU card. Below is the example script to enable Tensor Parallel = 2 for serving:
|
||||
- Tensor Parallel is supported for serving and offline inferencing. In general each NUMA node is treated as one GPU card. Below is the example script to enable Tensor Parallel = 2 for serving:
|
||||
|
||||
```console
|
||||
VLLM_CPU_KVCACHE_SPACE=40 VLLM_CPU_OMP_THREADS_BIND="0-31|32-63" vllm serve meta-llama/Llama-2-7b-chat-hf -tp=2 --distributed-executor-backend mp
|
||||
```
|
||||
|
||||
- Using Data Parallel for maximum throughput: to launch an LLM serving endpoint on each NUMA node along with one additional load balancer to dispatch the requests to those endpoints. Common solutions like [Nginx](#nginxloadbalancer) or HAProxy are recommended. Anyscale Ray project provides the feature on LLM [serving](https://docs.ray.io/en/latest/serve/index.html). Here is the example to setup a scalable LLM serving with [Ray Serve](https://github.com/intel/llm-on-ray/blob/main/docs/setup.inc.md).
|
||||
- For each thread id list in `VLLM_CPU_OMP_THREADS_BIND`, users should guarantee threads in the list belong to a same NUMA node.
|
||||
|
||||
- Meanwhile, users should also take care of memory capacity of each NUMA node. The memory usage of each TP rank is the sum of `weight shard size` and `VLLM_CPU_KVCACHE_SPACE`, if it exceeds the capacity of a single NUMA node, TP worker will be killed due to out-of-memory.
|
||||
|
@ -1,10 +1,14 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from typing import Optional
|
||||
import os
|
||||
from typing import List, Optional
|
||||
|
||||
import torch
|
||||
from torch.distributed import ProcessGroup
|
||||
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.platforms.interface import CpuArchEnum
|
||||
|
||||
from .base_device_communicator import DeviceCommunicatorBase
|
||||
|
||||
|
||||
@ -16,19 +20,120 @@ class CpuCommunicator(DeviceCommunicatorBase):
|
||||
device_group: Optional[ProcessGroup] = None,
|
||||
unique_name: str = ""):
|
||||
super().__init__(cpu_group, device, device_group, unique_name)
|
||||
self.ipex_available = False
|
||||
self.dist_module = torch.distributed
|
||||
try:
|
||||
import intel_extension_for_pytorch as ipex
|
||||
self.ipex_available = True
|
||||
self.dist_module = ipex.distributed
|
||||
except ImportError:
|
||||
"""
|
||||
Intel IPEX not found. Falling back to PyTorch native
|
||||
all_reduce for CPU (e.g. MacOS)
|
||||
"""
|
||||
pass
|
||||
|
||||
if current_platform.get_cpu_architecture() == CpuArchEnum.X86:
|
||||
self.dist_module = _CPUSHMDistributed(self)
|
||||
|
||||
def all_reduce(self, input_):
|
||||
self.dist_module.all_reduce(input_, group=self.device_group)
|
||||
return input_
|
||||
|
||||
def gather(self,
|
||||
input_: torch.Tensor,
|
||||
dst: int = 0,
|
||||
dim: int = -1) -> Optional[torch.Tensor]:
|
||||
"""
|
||||
NOTE: We assume that the input tensor is on the same device across
|
||||
all the ranks.
|
||||
NOTE: `dst` is the local rank of the destination rank.
|
||||
"""
|
||||
world_size = self.world_size
|
||||
assert -input_.dim() <= dim < input_.dim(), (
|
||||
f"Invalid dim ({dim}) for input tensor with shape {input_.size()}")
|
||||
if dim < 0:
|
||||
# Convert negative dim to positive.
|
||||
dim += input_.dim()
|
||||
|
||||
# Allocate output tensor.
|
||||
if self.rank_in_group == dst:
|
||||
gather_list = [torch.empty_like(input_) for _ in range(world_size)]
|
||||
else:
|
||||
gather_list = None
|
||||
|
||||
# Gather.
|
||||
self.dist_module.gather(input_,
|
||||
gather_list,
|
||||
dst=self.ranks[dst],
|
||||
group=self.device_group)
|
||||
|
||||
if self.rank_in_group == dst:
|
||||
output_tensor = torch.cat(gather_list, dim=dim)
|
||||
else:
|
||||
output_tensor = None
|
||||
return output_tensor
|
||||
|
||||
def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor:
|
||||
if dim < 0:
|
||||
# Convert negative dim to positive.
|
||||
dim += input_.dim()
|
||||
input_size = input_.size()
|
||||
# NOTE: we have to use concat-style all-gather here,
|
||||
# stack-style all-gather has compatibility issues with
|
||||
# torch.compile . see https://github.com/pytorch/pytorch/issues/138795
|
||||
output_size = (input_size[0] * self.world_size, ) + input_size[1:]
|
||||
# Allocate output tensor.
|
||||
output_tensor = torch.empty(output_size,
|
||||
dtype=input_.dtype,
|
||||
device=input_.device)
|
||||
# All-gather.
|
||||
self.dist_module.all_gather_into_tensor(output_tensor,
|
||||
input_,
|
||||
group=self.device_group)
|
||||
|
||||
# Reshape
|
||||
output_tensor = output_tensor.reshape((self.world_size, ) + input_size)
|
||||
output_tensor = output_tensor.movedim(0, dim)
|
||||
output_tensor = output_tensor.reshape(input_size[:dim] +
|
||||
(self.world_size *
|
||||
input_size[dim], ) +
|
||||
input_size[dim + 1:])
|
||||
return output_tensor
|
||||
|
||||
|
||||
class _CPUSHMDistributed:
|
||||
|
||||
def __init__(self, communicator: CpuCommunicator):
|
||||
instance_identifier = os.environ["VLLM_DIST_IDENT"]
|
||||
self.communicator = communicator
|
||||
|
||||
group_ranks = [str(rank) for rank in self.communicator.ranks]
|
||||
shm_group_identifier = f"[{'-'.join(group_ranks)}]"
|
||||
self.group_name = f"{instance_identifier}-{shm_group_identifier}-cpushm"
|
||||
|
||||
self.handle = self._init_cpu_shm()
|
||||
|
||||
def _init_cpu_shm(self) -> int:
|
||||
handle = torch.ops._C.init_shm_manager(
|
||||
self.group_name,
|
||||
self.communicator.world_size,
|
||||
self.communicator.rank,
|
||||
)
|
||||
torch.distributed.barrier(self.communicator.device_group)
|
||||
torch.ops._C.join_shm_manager(
|
||||
handle,
|
||||
self.group_name,
|
||||
)
|
||||
torch.distributed.barrier(self.communicator.device_group)
|
||||
|
||||
return handle
|
||||
|
||||
def all_reduce(self,
|
||||
input: torch.Tensor,
|
||||
group: Optional[ProcessGroup] = None) -> None:
|
||||
torch.ops._C.shm_allreduce(self.handle, input)
|
||||
|
||||
def gather(self,
|
||||
input: torch.Tensor,
|
||||
gather_list: Optional[List[torch.Tensor]],
|
||||
dst: int = -1,
|
||||
group: Optional[ProcessGroup] = None) -> None:
|
||||
# Note: different from the torch gather, here we use local dst rank.
|
||||
torch.ops._C.shm_gather(self.handle, input, gather_list,
|
||||
torch.distributed.get_group_rank(group, dst))
|
||||
|
||||
def all_gather_into_tensor(self,
|
||||
output: torch.Tensor,
|
||||
input: torch.Tensor,
|
||||
group: Optional[ProcessGroup] = None) -> None:
|
||||
torch.ops._C.shm_all_gather(self.handle, input, output)
|
||||
|
@ -1,5 +1,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
"""A CPU worker class."""
|
||||
import os
|
||||
from typing import Dict, List, Optional, Set, Tuple, Type
|
||||
|
||||
import torch
|
||||
@ -139,6 +140,8 @@ class CPUWorker(LocalOrDistributedWorkerBase):
|
||||
|
||||
self.local_rank = local_rank
|
||||
self.rank = rank
|
||||
vllm_config.parallel_config.rank = rank
|
||||
|
||||
self.distributed_init_method = distributed_init_method
|
||||
|
||||
self.is_driver_worker = is_driver_worker
|
||||
@ -217,6 +220,10 @@ class CPUWorker(LocalOrDistributedWorkerBase):
|
||||
ret = torch.ops._C_utils.init_cpu_threads_env(self.local_omp_cpuid)
|
||||
if ret:
|
||||
logger.info(ret)
|
||||
|
||||
# Note: unique identifier for creating allreduce shared memory
|
||||
os.environ["VLLM_DIST_IDENT"] = self.distributed_init_method.split(
|
||||
":")[-1]
|
||||
self.device = torch.device("cpu")
|
||||
self.init_distributed_environment()
|
||||
# Set random seed.
|
||||
|
Loading…
x
Reference in New Issue
Block a user