diff --git a/cmake/cpu_extension.cmake b/cmake/cpu_extension.cmake
index b57d9e22..fdc03a79 100644
--- a/cmake/cpu_extension.cmake
+++ b/cmake/cpu_extension.cmake
@@ -197,6 +197,7 @@ set(VLLM_EXT_SRC
 if (AVX512_FOUND AND NOT AVX512_DISABLED)
   set(VLLM_EXT_SRC
     "csrc/cpu/quant.cpp"
+    "csrc/cpu/shm.cpp"
     ${VLLM_EXT_SRC})
 endif()
diff --git a/csrc/cpu/cpu_types_x86.hpp b/csrc/cpu/cpu_types_x86.hpp
index 4568699b..cf67847b 100644
--- a/csrc/cpu/cpu_types_x86.hpp
+++ b/csrc/cpu/cpu_types_x86.hpp
@@ -78,9 +78,14 @@ struct FP16Vec16 : public Vec {
   __m256i reg;
 
+  // normal load
   explicit FP16Vec16(const void* ptr)
       : reg((__m256i)_mm256_loadu_si256((__m256i*)ptr)) {}
 
+  // non-temporal load
+  explicit FP16Vec16(bool, void* ptr)
+      : reg(_mm256_stream_load_si256((__m256i*)ptr)) {}
+
   explicit FP16Vec16(const FP32Vec16&);
 
   void save(void* ptr) const { *reinterpret_cast<__m256i*>(ptr) = reg; }
@@ -110,9 +115,14 @@ struct BF16Vec16 : public Vec {
   __m256i reg;
 
+  // normal load
   explicit BF16Vec16(const void* ptr)
       : reg((__m256i)_mm256_loadu_si256((__m256i*)ptr)) {}
 
+  // non-temporal load
+  explicit BF16Vec16(bool, void* ptr)
+      : reg(_mm256_stream_load_si256((__m256i*)ptr)) {}
+
   explicit BF16Vec16(const FP32Vec16&);
 
   void save(void* ptr) const { *reinterpret_cast<__m256i*>(ptr) = reg; }
@@ -313,8 +323,13 @@ struct FP32Vec16 : public Vec {
   explicit FP32Vec16() : reg(_mm512_set1_ps(0.0)) {}
 
+  // normal load
   explicit FP32Vec16(const float* ptr) : reg(_mm512_loadu_ps(ptr)) {}
 
+  // non-temporal load
+  explicit FP32Vec16(bool, void* ptr)
+      : reg((__m512)_mm512_stream_load_si512(ptr)) {}
+
   explicit FP32Vec16(__m512 data) : reg(data) {}
 
   explicit FP32Vec16(const FP32Vec4& data)
@@ -547,6 +562,33 @@ struct INT8Vec16 : public Vec {
     _mm_mask_storeu_epi8(ptr, mask, reg);
   }
 };
+
+struct INT8Vec64 : public Vec {
+  constexpr static int VEC_ELEM_NUM = 64;
+  union AliasReg {
+    __m512i reg;
+    int8_t values[VEC_ELEM_NUM];
+  };
+
+  __m512i reg;
+
+  // normal load
+  explicit INT8Vec64(void* ptr) : reg(_mm512_loadu_epi8(ptr)) {}
+
+  // non-temporal load
+  explicit INT8Vec64(bool, void* ptr) : reg(_mm512_stream_load_si512(ptr)) {}
+
+  void save(void* ptr) const { _mm512_storeu_epi8(ptr, reg); }
+
+  void save(int8_t* ptr, const int elem_num) const {
+    constexpr uint64_t M = 0xFFFFFFFFFFFFFFFF;
+    __mmask64 mask = _cvtu64_mask64(M >> (64 - elem_num));
+    _mm512_mask_storeu_epi8(ptr, mask, reg);
+  }
+
+  // non-temporal save
+  void nt_save(int8_t* ptr) { _mm512_stream_si512((__m512i*)ptr, reg); }
+};
 #endif
 
 template
@@ -657,6 +699,22 @@ inline BF16Vec16::BF16Vec16(const FP32Vec16& v) {
 
 inline void prefetch(const void* addr) { _mm_prefetch(addr, _MM_HINT_T1); }
 
+#ifdef __AVX512F__
+inline void non_temporal_save(FP16Vec16& vec, void* ptr) {
+  _mm256_stream_si256((__m256i*)ptr, vec.reg);
+}
+inline void non_temporal_save(BF16Vec32& vec, void* ptr) {
+  _mm512_stream_si512((__m512i*)ptr, vec.reg);
+}
+inline void non_temporal_save(BF16Vec16& vec, void* ptr) {
+  _mm256_stream_si256((__m256i*)ptr, vec.reg);
+}
+inline void non_temporal_save(FP32Vec16& vec, void* ptr) {
+  _mm512_stream_ps((float*)ptr, vec.reg);
+}
+#endif
+
+inline void mem_barrier() { _mm_mfence(); }
 };  // namespace vec_op
 
 #endif
diff --git a/csrc/cpu/shm.cpp b/csrc/cpu/shm.cpp
new file mode 100644
index 00000000..f55e96de
--- /dev/null
+++ b/csrc/cpu/shm.cpp
@@ -0,0 +1,781 @@
+#include "cpu/cpu_types.hpp"
+
+#include <fcntl.h>
+#include <sys/mman.h>
+#include <sys/stat.h>
+#include <unistd.h>
+
+namespace {
+#define MAX_SHM_RANK_NUM 8
+#define MAX_THREAD_NUM 12
+#define PER_THREAD_SHM_BUFFER_BYTES (4 * 1024 * 1024)
+#define 
MIN_THREAD_PROCESS_SIZE (8 * 1024) +#define MAX_P2P_SEND_TENSOR_NUM 8 + +template +struct KernelVecType { + using scalar_vec_t = void; +}; + +template <> +struct KernelVecType { + using scalar_vec_t = vec_op::FP32Vec16; +}; + +template <> +struct KernelVecType { + using scalar_vec_t = vec_op::BF16Vec16; +}; + +template <> +struct KernelVecType { + using scalar_vec_t = vec_op::FP16Vec16; +}; + +enum class ThreadSHMStat : char { THREAD_READY = 0, SHM_DATA_READY, DONE }; + +struct ThreadSHMContext { + volatile ThreadSHMStat thread_stats[MAX_SHM_RANK_NUM]; + int thread_id; + int thread_num; + int rank; + int group_size; + size_t _spinning_count; + int swizzled_ranks[MAX_SHM_RANK_NUM]; + void* thread_shm_ptrs[MAX_SHM_RANK_NUM]; + ThreadSHMContext* shm_contexts[MAX_SHM_RANK_NUM]; + + ThreadSHMContext(const int thread_id, const int thread_num, const int rank, + const int group_size, void* thread_shm_ptr) + : thread_id(thread_id), + thread_num(thread_num), + rank(rank), + group_size(group_size), + _spinning_count(0) { + static_assert(sizeof(ThreadSHMContext) % 64 == 0); + TORCH_CHECK(group_size <= MAX_SHM_RANK_NUM); + TORCH_CHECK((size_t)this % 64 == 0); + TORCH_CHECK((size_t)thread_shm_ptr % 64 == 0); + for (int i = 0; i < MAX_SHM_RANK_NUM; ++i) { + shm_contexts[i] = nullptr; + thread_shm_ptrs[i] = nullptr; + swizzled_ranks[i] = (i + rank) % group_size; + thread_stats[i] = ThreadSHMStat::DONE; + } + set_context(rank, this, thread_shm_ptr); + } + + void set_context(int rank, ThreadSHMContext* ptr, void* thread_shm_ptr) { + TORCH_CHECK(rank < MAX_SHM_RANK_NUM); + TORCH_CHECK(ptr); + TORCH_CHECK(thread_shm_ptr); + TORCH_CHECK_EQ(ptr->thread_num, thread_num); + TORCH_CHECK_EQ(ptr->thread_id, thread_id); + shm_contexts[rank] = ptr; + thread_shm_ptrs[rank] = thread_shm_ptr; + } + + template + T* get_thread_shm_ptr(int rank) { + return reinterpret_cast(thread_shm_ptrs[rank]); + } + + int get_swizzled_rank(int idx) { return swizzled_ranks[idx]; } + + void wait_for_all(ThreadSHMStat prev_stat) { + for (int idx = 0; idx < group_size; ++idx) { + int rank = get_swizzled_rank(idx); + while (thread_stats[rank] == prev_stat) { + ++_spinning_count; + _mm_pause(); + } + } + vec_op::mem_barrier(); + } + + void wait_for_one(int rank, ThreadSHMStat prev_stat) { + while (thread_stats[rank] == prev_stat) { + ++_spinning_count; + _mm_pause(); + } + vec_op::mem_barrier(); + } + + void set_thread_stat(ThreadSHMStat stat) { + for (int idx = 0; idx < group_size; ++idx) { + int rank = get_swizzled_rank(idx); + shm_contexts[rank]->thread_stats[this->rank] = stat; + } + } + + void set_thread_stat(int target_rank, ThreadSHMStat stat) { + for (int idx = 0; idx < group_size; ++idx) { + int rank = get_swizzled_rank(idx); + shm_contexts[rank]->thread_stats[target_rank] = stat; + } + } + + // barrier for all ranks in the group, used for all2all ops + // DONE -> THREAD_READY -> SHM_DATA_READY -> DONE -> ... 
+ void barrier(ThreadSHMStat next_stat) { + if (next_stat == ThreadSHMStat::THREAD_READY) { + set_thread_stat(ThreadSHMStat::THREAD_READY); + wait_for_all(ThreadSHMStat::DONE); + } else if (next_stat == ThreadSHMStat::SHM_DATA_READY) { + set_thread_stat(ThreadSHMStat::SHM_DATA_READY); + wait_for_all(ThreadSHMStat::THREAD_READY); + } else if (next_stat == ThreadSHMStat::DONE) { + set_thread_stat(ThreadSHMStat::DONE); + wait_for_all(ThreadSHMStat::SHM_DATA_READY); + } else { + TORCH_CHECK(false, "Invalid next_stat to barrier."); + } + } + + std::string to_string() const { + std::stringstream ss; + ss << "SHMContext:"; + ss << "\nrank: " << rank; + ss << "\ngroup_size: " << group_size; + ss << "\nthread_num: " << thread_num; + ss << "\nthread_id: " << thread_id; + + ss << "\nshm_ctx_stat_loop_seq: ["; + for (int i = 0; i < group_size; ++i) { + ss << swizzled_ranks[i] << ", "; + } + ss << "]"; + + ss << "\nshm_contexts: ["; + for (int i = 0; i < group_size; ++i) { + if (shm_contexts[i]) { + ss << shm_contexts[i]->rank << ", "; + } + } + ss << "]"; + + return ss.str(); + } +}; + +class SHMManager { + public: + explicit SHMManager(const std::string& name, const int rank, + const int group_size) + : _rank(rank), + _group_size(group_size), + _thread_num(std::min(torch::get_num_threads(), MAX_THREAD_NUM)), + _shm_names({""}), + _shared_mem_ptrs({nullptr}), + _shm_ctx(nullptr) { + _shm_names[rank] = get_shm_name(name, rank); + _shared_mem_ptrs[rank] = init_shm(rank); + _shm_ctx = reinterpret_cast(_shared_mem_ptrs[rank]); + + for (int i = 0; i < _thread_num; ++i) { + ThreadSHMContext* ctx = new (_shm_ctx + i) + ThreadSHMContext(i, _thread_num, _rank, _group_size, + compute_thread_shm_ptr(_shm_ctx, i)); + } + } + + void join(const std::string& name) { + for (int rank_idx = 0; rank_idx < _group_size; ++rank_idx) { + if (rank_idx != _rank) { + TORCH_CHECK(_shm_names[rank_idx].empty()); + TORCH_CHECK(_shared_mem_ptrs[rank_idx] == nullptr); + _shm_names[rank_idx] = get_shm_name(name, rank_idx); + _shared_mem_ptrs[rank_idx] = init_shm(rank_idx); + ThreadSHMContext* target_ctx = + reinterpret_cast(_shared_mem_ptrs[rank_idx]); + for (int thread_idx = 0; thread_idx < _thread_num; ++thread_idx) { + _shm_ctx[thread_idx].set_context( + rank_idx, target_ctx + thread_idx, + compute_thread_shm_ptr(target_ctx, thread_idx)); + } + } + } + } + + ~SHMManager() { destroy_shm(); } + + ThreadSHMContext* get_shm_ctx() const { return _shm_ctx; } + + static std::string get_shm_name(const std::string& name, int rank) { + return name + "_" + std::to_string(rank); + } + + static int64_t create_singleton_instance(const std::string& name, + const int group_size, + const int rank) { + std::lock_guard guard(SingletonInstancesLock); + SingletonInstances.emplace_back( + std::make_unique(name, rank, group_size)); + return static_cast(SingletonInstances.size() - 1); + } + + static SHMManager* get_singleton_instance(int64_t handle) { + return SingletonInstances[handle].get(); + } + + protected: + static std::vector> SingletonInstances; + static std::mutex SingletonInstancesLock; + + private: + static size_t round_to_alignment(size_t num) { + return ((num + 63) / 64) * 64; + } + + int8_t* compute_thread_shm_ptr(ThreadSHMContext* ctx, int thread_id) { + int8_t* thread_shm_ptr = + reinterpret_cast(ctx) + + round_to_alignment(_thread_num * sizeof(ThreadSHMContext)); + return thread_shm_ptr + + thread_id * round_to_alignment(PER_THREAD_SHM_BUFFER_BYTES); + } + + size_t compute_shm_size() { + const size_t rounded_rank_buffer_size = + 
round_to_alignment(PER_THREAD_SHM_BUFFER_BYTES) * _thread_num; + const size_t rounded_thread_shm_ctx_size = + round_to_alignment(_thread_num * sizeof(ThreadSHMContext)); + const size_t shm_size = + rounded_thread_shm_ctx_size + rounded_rank_buffer_size; + return shm_size; + } + + void* init_shm(int target_rank) { + const std::string& shm_name = _shm_names[target_rank]; + const int local_rank = _rank; + const size_t shm_size = compute_shm_size(); + + int fd = -1; + if (local_rank == target_rank) { + fd = shm_open(shm_name.c_str(), O_CREAT | O_EXCL | O_RDWR, + S_IRUSR | S_IWUSR); + + if (fd == -1) + TORCH_CHECK(false, "create shm in SHMManager failed. errno: " + + std::to_string(errno)); + + if (ftruncate(fd, shm_size) == -1) + TORCH_CHECK(false, "ftruncate in SHMManager failed. errno: " + + std::to_string(errno)); + } else { + fd = shm_open(shm_name.c_str(), O_RDWR, S_IRUSR | S_IWUSR); + + if (fd == -1) + TORCH_CHECK(false, "open shm in SHMManager failed. errno: " + + std::to_string(errno)); + } + + void* shm_ptr = mmap(nullptr, shm_size, PROT_READ | PROT_WRITE, + MAP_SHARED | MAP_POPULATE, fd, 0); + + if (shm_ptr == MAP_FAILED) { + TORCH_CHECK(false, + "mmap in SHMManager failed. errno: " + std::to_string(errno)); + } + + if (close(fd) != 0) { + TORCH_CHECK( + false, "close in SHMManager failed. errno: " + std::to_string(errno)); + } + + TORCH_CHECK((size_t)shm_ptr % 64 == 0); + + return shm_ptr; + } + + void destroy_shm() { + std::stringstream ss; + ss << "local rank " << _rank << ": ["; + for (int thread_id = 0; thread_id < _thread_num; ++thread_id) { + ss << _shm_ctx[thread_id]._spinning_count << ", "; + } + ss << "]\n"; + + for (int i = 0; i < MAX_SHM_RANK_NUM; ++i) { + if (_shared_mem_ptrs[i] != nullptr) { + munmap(_shared_mem_ptrs[i], compute_shm_size()); + } + + if (!_shm_names[i].empty()) { + shm_unlink(_shm_names[i].c_str()); + } + } + } + + int _rank; + int _group_size; + int _thread_num; + std::array _shm_names; + std::array _shared_mem_ptrs; + ThreadSHMContext* _shm_ctx; +}; + +namespace shm_cc_ops { +template +void shm_cc_loop(ThreadSHMContext* ctx, int64_t elem_num, F&& inner_func) { + int thread_num = ctx->thread_num; + int64_t total_bytes = elem_num * sizeof(scalar_t); + int64_t total_units_num = + (total_bytes + MIN_THREAD_PROCESS_SIZE - 1) / MIN_THREAD_PROCESS_SIZE; + int64_t per_thread_units_num = + (total_units_num + thread_num - 1) / thread_num; + int64_t per_unit_elem_num = MIN_THREAD_PROCESS_SIZE / sizeof(scalar_t); + int64_t max_per_thread_iteration_elem_num = + PER_THREAD_SHM_BUFFER_BYTES / sizeof(scalar_t); + int64_t per_thread_elem_num = per_unit_elem_num * per_thread_units_num; + +#pragma omp parallel for schedule(static, 1) + for (int i = 0; i < thread_num; ++i) { + int64_t offset = i * per_thread_elem_num; + int64_t end = std::min(elem_num, offset + per_thread_elem_num); + int64_t curr_elem_num = + std::min(max_per_thread_iteration_elem_num, end - offset); + ThreadSHMContext* thread_ctx = ctx + i; + + while (curr_elem_num > 0) { + inner_func(thread_ctx, offset, curr_elem_num); + + offset += max_per_thread_iteration_elem_num; + curr_elem_num = std::min(max_per_thread_iteration_elem_num, end - offset); + } + } +} +}; // namespace shm_cc_ops + +namespace shm_cc_ops { + +void memcpy_from_shm(void* dst, void* src, const int64_t bytes) { + const int64_t aligned_bytes = ((bytes >> 6) << 6); // 64 bytes aligned + int64_t i = 0; +#pragma GCC unroll 4 + for (; i < aligned_bytes; i += 64) { + vec_op::INT8Vec64 data( + true, (int8_t*)src + i); // stream loading shm to 
avoid caching + data.save((int8_t*)dst + i); + } + if (aligned_bytes < bytes) { + vec_op::INT8Vec64 data(true, (int8_t*)src + aligned_bytes); + data.save((int8_t*)dst + aligned_bytes, bytes - aligned_bytes); + } +} + +void memcpy_to_shm(void* dst, void* src, const int64_t bytes) { +#pragma GCC unroll 4 + for (int64_t i = 0; i < bytes; i += 64) { + vec_op::INT8Vec64 data((int8_t*)src + i); + data.nt_save((int8_t*)dst + i); + } +} + +void memcpy(void* dst, void* src, const int64_t bytes) { + const int64_t aligned_bytes = ((bytes >> 6) << 6); // 64 bytes aligned + int64_t i = 0; +#pragma GCC unroll 4 + for (; i < aligned_bytes; i += 64) { + vec_op::INT8Vec64 data((int8_t*)src + i); + data.save((int8_t*)dst + i); + } + if (aligned_bytes < bytes) { + vec_op::INT8Vec64 data((int8_t*)src + aligned_bytes); + data.save((int8_t*)dst + aligned_bytes, bytes - aligned_bytes); + } +} + +template +void all_reduce_sum_impl(ThreadSHMContext* ctx, scalar_t* data, + size_t elem_num) { + CPU_KERNEL_GUARD_IN(all_reduce_sum_impl) + using vec_t = typename KernelVecType::scalar_vec_t; + constexpr int64_t vec_elem_num = vec_t::get_elem_num(); + const int worldsize = ctx->group_size; + + shm_cc_ops::shm_cc_loop( + ctx, elem_num, + [&](ThreadSHMContext* thread_ctx, int64_t data_offset, + int64_t data_elem_num) { + int rank = thread_ctx->rank; + scalar_t* thread_shm_ptr = + thread_ctx->get_thread_shm_ptr(rank); + scalar_t* thread_data_ptr = data + data_offset; + int64_t thread_data_elem_num = data_elem_num * sizeof(scalar_t); + + scalar_t* remote_data_ptrs[RANKS - 1]; + vec_op::unroll_loop([&](int idx) { + remote_data_ptrs[idx] = thread_ctx->get_thread_shm_ptr( + thread_ctx->get_swizzled_rank(idx + 1)); + }); + + thread_ctx->barrier(ThreadSHMStat::THREAD_READY); + + shm_cc_ops::memcpy_to_shm(thread_shm_ptr, thread_data_ptr, + thread_data_elem_num); + + thread_ctx->barrier(ThreadSHMStat::SHM_DATA_READY); + + int64_t aligned_data_elem_num = + (data_elem_num / vec_elem_num) * vec_elem_num; + int64_t i = 0; +#pragma GCC unroll 4 + for (; i < aligned_data_elem_num; i += vec_elem_num) { + vec_t local_data(thread_data_ptr + i); // load from cache + vec_op::FP32Vec16 local_data_fp32(local_data); + vec_op::unroll_loop([&](int idx) { + vec_t remote_data( + true, remote_data_ptrs[idx] + i); // stream load from shm + vec_op::FP32Vec16 remote_data_fp32(remote_data); + local_data_fp32 = local_data_fp32 + remote_data_fp32; // sum reduce + }); + vec_t reduced_data(local_data_fp32); + reduced_data.save(thread_data_ptr + i); + } + + if (i < data_elem_num) { + vec_t local_data(thread_data_ptr + i); // load from cache + vec_op::FP32Vec16 local_data_fp32(local_data); + vec_op::unroll_loop([&](int idx) { + vec_t remote_data( + true, remote_data_ptrs[idx] + i); // stream load from shm + vec_op::FP32Vec16 remote_data_fp32(remote_data); + local_data_fp32 = local_data_fp32 + remote_data_fp32; // sum reduce + }); + vec_t reduced_data(local_data_fp32); + reduced_data.save(thread_data_ptr + i, + data_elem_num - aligned_data_elem_num); + } + + thread_ctx->barrier(ThreadSHMStat::DONE); + }); + + return; +} +}; // namespace shm_cc_ops + +std::vector> SHMManager::SingletonInstances = {}; +std::mutex SHMManager::SingletonInstancesLock = {}; + +template +void shm_allreduce_sum(ThreadSHMContext* ctx, scalar_t* data, size_t elem_num) { + switch (ctx->group_size) { + case 2: + shm_cc_ops::all_reduce_sum_impl(ctx, data, elem_num); + break; + case 3: + shm_cc_ops::all_reduce_sum_impl(ctx, data, elem_num); + break; + case 4: + 
shm_cc_ops::all_reduce_sum_impl(ctx, data, elem_num); + break; + case 8: + shm_cc_ops::all_reduce_sum_impl(ctx, data, elem_num); + break; + default: + TORCH_CHECK(false, + "Invalid world size: " + std::to_string(ctx->group_size)); + } +} + +template +void shm_gather_impl(ThreadSHMContext* ctx, scalar_t* data, size_t elem_num, + scalar_t** outputs, const int dst) { + CPU_KERNEL_GUARD_IN(shm_gather_impl) + const int worldsize = ctx->group_size; + TORCH_CHECK_LT(dst, worldsize); + shm_cc_ops::shm_cc_loop( + ctx, elem_num, + [&](ThreadSHMContext* thread_ctx, int64_t data_offset, + int64_t data_elem_num) { + int rank = thread_ctx->rank; + scalar_t* thread_shm_ptr = + thread_ctx->get_thread_shm_ptr(rank); + + thread_ctx->barrier(ThreadSHMStat::THREAD_READY); + + shm_cc_ops::memcpy_to_shm(thread_shm_ptr, data + data_offset, + data_elem_num * sizeof(scalar_t)); + + thread_ctx->barrier(ThreadSHMStat::SHM_DATA_READY); + + if (rank == dst) { + shm_cc_ops::memcpy(outputs[rank] + data_offset, data + data_offset, + data_elem_num * sizeof(scalar_t)); + for (int i = 1; i < worldsize; ++i) { + int src_rank = thread_ctx->get_swizzled_rank(i); + scalar_t* src_ptr = + thread_ctx->get_thread_shm_ptr(src_rank); // shm + scalar_t* dst_ptr = outputs[src_rank] + data_offset; + shm_cc_ops::memcpy_from_shm(dst_ptr, src_ptr, + data_elem_num * sizeof(scalar_t)); + } + } + + thread_ctx->barrier(ThreadSHMStat::DONE); + }); + + return; +} + +struct MemPiece { + void* ptr; + int64_t size; + + template + T* data_ptr() { + return reinterpret_cast(ptr); + } +}; + +struct TensorListMeta { + int64_t tensor_bytes[MAX_P2P_SEND_TENSOR_NUM]; + torch::ScalarType tensor_types[MAX_P2P_SEND_TENSOR_NUM]; + int64_t tensor_num; + int64_t total_bytes; + + TensorListMeta() : tensor_num(0), total_bytes(0) { + static_assert(sizeof(TensorListMeta) % 64 == 0); + static_assert(sizeof(TensorListMeta) < + MIN_THREAD_PROCESS_SIZE); // To ensure the metadata always + // hold by the thread 0 + for (int i = 0; i < MAX_P2P_SEND_TENSOR_NUM; ++i) { + tensor_bytes[i] = 0; + tensor_ptrs[i] = nullptr; + tensor_types[i] = torch::ScalarType::Undefined; + } + } + + // For send and recv + void bind_tensor_list(std::vector& tensor_list) { + TORCH_CHECK(tensor_types[0] == torch::ScalarType::Undefined, + "Re-bind TensorListMeta is not allowed.") + TORCH_CHECK_LE(tensor_list.size(), MAX_P2P_SEND_TENSOR_NUM); + tensor_num = tensor_list.size(); + int64_t bytes_sum = 0; + for (int i = 0; i < tensor_list.size(); ++i) { + torch::Tensor& t = tensor_list[i]; + TORCH_CHECK(t.is_contiguous()); + tensor_bytes[i] = t.nbytes(); + tensor_types[i] = t.scalar_type(); + tensor_ptrs[i] = t.data_ptr(); + bytes_sum += t.nbytes(); + } + total_bytes = bytes_sum; + } + + // For recv + std::vector generate_tensor_list() { + std::vector tensor_list; + tensor_list.reserve(tensor_num); + + for (int i = 0; i < tensor_num; ++i) { + int64_t bytes = tensor_bytes[i]; + auto type = tensor_types[i]; + int64_t elem_bytes = torch::elementSize(type); + + TORCH_CHECK_EQ(bytes % elem_bytes, 0); + int64_t elem_num = bytes / elem_bytes; + auto options = torch::TensorOptions().dtype(type).device(torch::kCPU); + tensor_list.emplace_back(torch::empty({elem_num}, options)); + } + return tensor_list; + } + + MemPiece get_data(int64_t offset) { + for (int i = 0; i < tensor_num; ++i) { + if (offset < tensor_bytes[i]) { + return {reinterpret_cast(tensor_ptrs[i]) + offset, + tensor_bytes[i] - offset}; + } + offset -= tensor_bytes[i]; + } + return {nullptr, 0}; + } + + private: + void* 
tensor_ptrs[MAX_P2P_SEND_TENSOR_NUM]; + int8_t _padding[40]; +}; + +void shm_send_tensor_list_impl(ThreadSHMContext* ctx, + const std::vector& tensor_list) { + CPU_KERNEL_GUARD_IN(shm_send_tensor_list_impl) + std::vector tensor_list_with_metadata; + tensor_list_with_metadata.reserve(1 + tensor_list.size()); + + auto options = torch::TensorOptions().dtype(torch::kInt8).device(torch::kCPU); + tensor_list_with_metadata.emplace_back( + torch::empty({sizeof(TensorListMeta)}, options)); + tensor_list_with_metadata.insert(tensor_list_with_metadata.end(), + tensor_list.begin(), tensor_list.end()); + + torch::Tensor& metadata_tensor = tensor_list_with_metadata[0]; + TORCH_CHECK_EQ(metadata_tensor.nbytes(), sizeof(TensorListMeta)); + + TensorListMeta* metadata = new (metadata_tensor.data_ptr()) TensorListMeta(); + metadata->bind_tensor_list(tensor_list_with_metadata); + + shm_cc_ops::shm_cc_loop( + ctx, metadata->total_bytes, + [&](ThreadSHMContext* thread_ctx, int64_t data_offset, + int64_t data_elem_num) { + int rank = thread_ctx->rank; + // Wait until the receiver set the stat to DONE + thread_ctx->wait_for_one(rank, ThreadSHMStat::SHM_DATA_READY); + + int64_t curr_shm_offset = 0; + while (curr_shm_offset < data_elem_num) { + MemPiece frag = metadata->get_data(data_offset + curr_shm_offset); + frag.size = std::min(frag.size, data_elem_num - curr_shm_offset); + shm_cc_ops::memcpy( + thread_ctx->get_thread_shm_ptr(rank) + curr_shm_offset, + frag.ptr, frag.size); + curr_shm_offset += frag.size; + } + + thread_ctx->set_thread_stat(rank, ThreadSHMStat::SHM_DATA_READY); + }); +} + +std::vector shm_recv_tensor_list_impl(ThreadSHMContext* ctx, + int64_t src) { + CPU_KERNEL_GUARD_IN(shm_recv_tensor_list_impl) + auto options = torch::TensorOptions().dtype(torch::kInt8).device(torch::kCPU); + torch::Tensor metadata_tensor = + torch::empty({sizeof(TensorListMeta)}, options); + + // Wait until the sender set the stat of the thread 0 to SHM_DATA_READY + ctx->wait_for_one(src, ThreadSHMStat::DONE); + shm_cc_ops::memcpy(metadata_tensor.data_ptr(), + ctx->get_thread_shm_ptr(src), + sizeof(TensorListMeta)); + TensorListMeta* src_metadata = + reinterpret_cast(metadata_tensor.data_ptr()); + std::vector tensor_list_with_metadata = + src_metadata->generate_tensor_list(); + + TensorListMeta metadata; + metadata.bind_tensor_list(tensor_list_with_metadata); + TORCH_CHECK_EQ(metadata.tensor_num, src_metadata->tensor_num); + TORCH_CHECK_EQ(metadata.total_bytes, src_metadata->total_bytes); + + shm_cc_ops::shm_cc_loop( + ctx, metadata.total_bytes, + [&](ThreadSHMContext* thread_ctx, int64_t data_offset, + int64_t data_elem_num) { + // Wait until the sender set the stat to SHM_DATA_READY + thread_ctx->wait_for_one(src, ThreadSHMStat::DONE); + int64_t curr_shm_offset = 0; + while (curr_shm_offset < data_elem_num) { + MemPiece frag = metadata.get_data(data_offset + curr_shm_offset); + frag.size = std::min(frag.size, data_elem_num - curr_shm_offset); + shm_cc_ops::memcpy( + frag.ptr, + thread_ctx->get_thread_shm_ptr(src) + curr_shm_offset, + frag.size); + curr_shm_offset += frag.size; + } + + thread_ctx->set_thread_stat(src, ThreadSHMStat::DONE); + }); + + std::vector tensor_list; + tensor_list.reserve(metadata.tensor_num - 1); + tensor_list.insert(tensor_list.begin(), tensor_list_with_metadata.begin() + 1, + tensor_list_with_metadata.end()); + + return tensor_list; +} +} // namespace + +void shm_gather(int64_t handle, torch::Tensor& data, + const std::optional>& outputs, + int64_t dst) { + TORCH_CHECK(data.is_contiguous()) + 
VLLM_DISPATCH_FLOATING_TYPES(data.scalar_type(), "shm_gather_impl", [&] { + CPU_KERNEL_GUARD_IN(shm_gather_impl) + + if (outputs.has_value()) { + TORCH_CHECK_LE(outputs->size(), MAX_SHM_RANK_NUM); + scalar_t* output_ptrs[MAX_SHM_RANK_NUM] = {nullptr}; + for (int i = 0; i < outputs->size(); ++i) { + output_ptrs[i] = outputs->at(i).data_ptr(); + } + shm_gather_impl(SHMManager::get_singleton_instance(handle)->get_shm_ctx(), + data.data_ptr(), data.numel(), output_ptrs, + dst); + } else { + shm_gather_impl(SHMManager::get_singleton_instance(handle)->get_shm_ctx(), + data.data_ptr(), data.numel(), (scalar_t**)(0), + dst); + } + + CPU_KERNEL_GUARD_OUT(shm_gather_impl) + }); +} + +void shm_all_gather(int64_t handle, const torch::Tensor& data, + torch::Tensor& output) { + TORCH_CHECK(data.is_contiguous()) + TORCH_CHECK(output.is_contiguous()) + + const int64_t input_elem_num = data.numel(); + const int64_t output_elem_num = output.numel(); + TORCH_CHECK_EQ(output_elem_num % input_elem_num, 0); + const int world_size = output_elem_num / input_elem_num; + + VLLM_DISPATCH_FLOATING_TYPES(data.scalar_type(), "shm_all_gather_impl", [&] { + CPU_KERNEL_GUARD_IN(shm_all_gather_impl) + auto ctx = SHMManager::get_singleton_instance(handle)->get_shm_ctx(); + TORCH_CHECK_EQ(ctx->group_size, world_size); + + scalar_t* output_ptrs[MAX_SHM_RANK_NUM] = {nullptr}; + for (int i = 0; i < world_size; ++i) { + output_ptrs[i] = output.data_ptr() + i * input_elem_num; + } + shm_gather_impl(ctx, data.data_ptr(), data.numel(), output_ptrs, + ctx->rank); + CPU_KERNEL_GUARD_OUT(shm_all_gather_impl) + }); +} + +void shm_allreduce(int64_t handle, torch::Tensor& data) { + TORCH_CHECK(data.is_contiguous()) + VLLM_DISPATCH_FLOATING_TYPES(data.scalar_type(), "shm_allreduce_sum", [&] { + CPU_KERNEL_GUARD_IN(shm_allreduce_sum) + shm_allreduce_sum(SHMManager::get_singleton_instance(handle)->get_shm_ctx(), + data.data_ptr(), data.numel()); + CPU_KERNEL_GUARD_OUT(shm_allreduce_sum) + }); +} + +void shm_send_tensor_list(int64_t handle, + const std::vector& tensor_list, + int64_t dst) { + CPU_KERNEL_GUARD_IN(shm_send_tensor_list) + shm_send_tensor_list_impl( + SHMManager::get_singleton_instance(handle)->get_shm_ctx(), tensor_list); + CPU_KERNEL_GUARD_OUT(shm_send_tensor_list) +} + +std::vector shm_recv_tensor_list(int64_t handle, int64_t src) { + CPU_KERNEL_GUARD_IN(shm_recv_tensor_list) + auto tensor_list = shm_recv_tensor_list_impl( + SHMManager::get_singleton_instance(handle)->get_shm_ctx(), src); + CPU_KERNEL_GUARD_OUT(shm_recv_tensor_list) + return tensor_list; +} + +int64_t init_shm_manager(const std::string& name, const int64_t group_size, + const int64_t rank) { + return SHMManager::create_singleton_instance(name, group_size, rank); +} + +std::string join_shm_manager(int64_t handle, const std::string& name) { + auto shm_manager = SHMManager::get_singleton_instance(handle); + TORCH_CHECK(shm_manager); + shm_manager->join(name); + return shm_manager->get_shm_ctx()->to_string(); +} \ No newline at end of file diff --git a/csrc/cpu/torch_bindings.cpp b/csrc/cpu/torch_bindings.cpp index ef5a2fb5..7ae7e338 100644 --- a/csrc/cpu/torch_bindings.cpp +++ b/csrc/cpu/torch_bindings.cpp @@ -22,6 +22,26 @@ void mla_decode_kvcache(torch::Tensor& out, torch::Tensor& query, torch::Tensor& kv_cache, double scale, torch::Tensor& block_tables, torch::Tensor& seq_lens); +int64_t init_shm_manager(const std::string& name, const int64_t group_size, + const int64_t rank); + +std::string join_shm_manager(int64_t handle, const std::string& name); + +void 
shm_allreduce(int64_t handle, torch::Tensor& data); + +void shm_gather(int64_t handle, torch::Tensor& data, + const std::optional>& outputs, + int64_t dst); + +void shm_all_gather(int64_t handle, const torch::Tensor& data, + torch::Tensor& output); + +void shm_send_tensor_list(int64_t handle, + const std::vector& tensor_list, + int64_t dst); + +std::vector shm_recv_tensor_list(int64_t handle, int64_t src); + TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { // vLLM custom ops @@ -131,6 +151,29 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { " Tensor? azp, Tensor? bias) -> ()"); ops.impl("cutlass_scaled_mm_azp", torch::kCPU, &int8_scaled_mm_azp); #endif + +// SHM CCL +#ifdef __AVX512F__ + ops.def("init_shm_manager(str name, int group_size, int rank) -> int", + &init_shm_manager); + ops.def("join_shm_manager(int handle, str name) -> str", &join_shm_manager); + ops.def("shm_allreduce(int handle, Tensor! data) -> ()"); + ops.impl("shm_allreduce", torch::kCPU, &shm_allreduce); + ops.def( + "shm_gather(int handle, Tensor data, Tensor[](a!)? outputs, int dst) -> " + "()"); + ops.impl("shm_gather", torch::kCPU, &shm_gather); + ops.def( + "shm_all_gather(int handle, Tensor data, Tensor! output) -> " + "()"); + ops.impl("shm_all_gather", torch::kCPU, &shm_all_gather); + ops.def( + "shm_send_tensor_list(int handle, Tensor[](a) tensor_list, int dst) -> " + "()"); + ops.impl("shm_send_tensor_list", torch::kCPU, &shm_send_tensor_list); + ops.def("shm_recv_tensor_list(int handle, int src) -> Tensor[](a)", + &shm_recv_tensor_list); +#endif } TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) { diff --git a/csrc/cpu/utils.cpp b/csrc/cpu/utils.cpp index 42a1c1d9..79771ecd 100644 --- a/csrc/cpu/utils.cpp +++ b/csrc/cpu/utils.cpp @@ -18,7 +18,7 @@ std::string init_cpu_threads_env(const std::string& cpu_ids) { #ifndef VLLM_NUMA_DISABLED std::string init_cpu_threads_env(const std::string& cpu_ids) { - bitmask* omp_cpu_mask = numa_parse_cpustring(cpu_ids.c_str()); + bitmask* omp_cpu_mask = numa_parse_cpustring_all(cpu_ids.c_str()); TORCH_CHECK(omp_cpu_mask->size > 0); std::vector omp_cpu_ids; omp_cpu_ids.reserve(omp_cpu_mask->size); diff --git a/docs/source/getting_started/installation/cpu.md b/docs/source/getting_started/installation/cpu.md index e7e12bd6..db22ef79 100644 --- a/docs/source/getting_started/installation/cpu.md +++ b/docs/source/getting_started/installation/cpu.md @@ -272,12 +272,14 @@ $ python examples/offline_inference/basic/basic.py - Decouple the HTTP serving components from the inference components. In a GPU backend configuration, the HTTP serving and tokenization tasks operate on the CPU, while inference runs on the GPU, which typically does not pose a problem. However, in a CPU-based setup, the HTTP serving and tokenization can cause significant context switching and reduced cache efficiency. Therefore, it is strongly recommended to segregate these two components for improved performance. -- On CPU based setup with NUMA enabled, the memory access performance may be largely impacted by the [topology](https://github.com/intel/intel-extension-for-pytorch/blob/main/docs/tutorials/performance_tuning/tuning_guide.inc.md#non-uniform-memory-access-numa). For NUMA architecture, two optimizations are to recommended: Tensor Parallel or Data Parallel. 
+- On CPU based setup with NUMA enabled, the memory access performance may be largely impacted by the [topology](https://github.com/intel/intel-extension-for-pytorch/blob/main/docs/tutorials/performance_tuning/tuning_guide.inc.md#non-uniform-memory-access-numa). For NUMA architecture, Tensor Parallel is an option for better performance.
 
-  - Using Tensor Parallel for a latency constraints deployment: following GPU backend design, a Megatron-LM's parallel algorithm will be used to shard the model, based on the number of NUMA nodes (e.g. TP = 2 for a two NUMA node system). With [TP feature on CPU](gh-pr:6125) merged, Tensor Parallel is supported for serving and offline inferencing. In general each NUMA node is treated as one GPU card. Below is the example script to enable Tensor Parallel = 2 for serving:
+  - Tensor Parallel is supported for serving and offline inference. In general, each NUMA node is treated as one GPU card. Below is an example script to enable Tensor Parallel = 2 for serving:
 
   ```console
   VLLM_CPU_KVCACHE_SPACE=40 VLLM_CPU_OMP_THREADS_BIND="0-31|32-63" vllm serve meta-llama/Llama-2-7b-chat-hf -tp=2 --distributed-executor-backend mp
   ```
 
-  - Using Data Parallel for maximum throughput: to launch an LLM serving endpoint on each NUMA node along with one additional load balancer to dispatch the requests to those endpoints. Common solutions like [Nginx](#nginxloadbalancer) or HAProxy are recommended. Anyscale Ray project provides the feature on LLM [serving](https://docs.ray.io/en/latest/serve/index.html). Here is the example to setup a scalable LLM serving with [Ray Serve](https://github.com/intel/llm-on-ray/blob/main/docs/setup.inc.md).
+  - For each thread id list in `VLLM_CPU_OMP_THREADS_BIND`, users should ensure that all threads in the list belong to the same NUMA node.
+
+  - Meanwhile, users should also take care of the memory capacity of each NUMA node. The memory usage of each TP rank is the sum of the weight shard size and `VLLM_CPU_KVCACHE_SPACE`; if it exceeds the capacity of a single NUMA node, the TP worker will be killed due to out-of-memory.
diff --git a/vllm/distributed/device_communicators/cpu_communicator.py b/vllm/distributed/device_communicators/cpu_communicator.py
index b920cd7e..1f4b4faf 100644
--- a/vllm/distributed/device_communicators/cpu_communicator.py
+++ b/vllm/distributed/device_communicators/cpu_communicator.py
@@ -1,10 +1,14 @@
 # SPDX-License-Identifier: Apache-2.0
 
-from typing import Optional
+import os
+from typing import List, Optional
 
 import torch
 from torch.distributed import ProcessGroup
 
+from vllm.platforms import current_platform
+from vllm.platforms.interface import CpuArchEnum
+
 from .base_device_communicator import DeviceCommunicatorBase
@@ -16,19 +20,120 @@ class CpuCommunicator(DeviceCommunicatorBase):
                  device_group: Optional[ProcessGroup] = None,
                  unique_name: str = ""):
         super().__init__(cpu_group, device, device_group, unique_name)
-        self.ipex_available = False
         self.dist_module = torch.distributed
-        try:
-            import intel_extension_for_pytorch as ipex
-            self.ipex_available = True
-            self.dist_module = ipex.distributed
-        except ImportError:
-            """
-            Intel IPEX not found. Falling back to PyTorch native
-            all_reduce for CPU (e.g. 
MacOS) - """ - pass + + if current_platform.get_cpu_architecture() == CpuArchEnum.X86: + self.dist_module = _CPUSHMDistributed(self) def all_reduce(self, input_): self.dist_module.all_reduce(input_, group=self.device_group) return input_ + + def gather(self, + input_: torch.Tensor, + dst: int = 0, + dim: int = -1) -> Optional[torch.Tensor]: + """ + NOTE: We assume that the input tensor is on the same device across + all the ranks. + NOTE: `dst` is the local rank of the destination rank. + """ + world_size = self.world_size + assert -input_.dim() <= dim < input_.dim(), ( + f"Invalid dim ({dim}) for input tensor with shape {input_.size()}") + if dim < 0: + # Convert negative dim to positive. + dim += input_.dim() + + # Allocate output tensor. + if self.rank_in_group == dst: + gather_list = [torch.empty_like(input_) for _ in range(world_size)] + else: + gather_list = None + + # Gather. + self.dist_module.gather(input_, + gather_list, + dst=self.ranks[dst], + group=self.device_group) + + if self.rank_in_group == dst: + output_tensor = torch.cat(gather_list, dim=dim) + else: + output_tensor = None + return output_tensor + + def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor: + if dim < 0: + # Convert negative dim to positive. + dim += input_.dim() + input_size = input_.size() + # NOTE: we have to use concat-style all-gather here, + # stack-style all-gather has compatibility issues with + # torch.compile . see https://github.com/pytorch/pytorch/issues/138795 + output_size = (input_size[0] * self.world_size, ) + input_size[1:] + # Allocate output tensor. + output_tensor = torch.empty(output_size, + dtype=input_.dtype, + device=input_.device) + # All-gather. + self.dist_module.all_gather_into_tensor(output_tensor, + input_, + group=self.device_group) + + # Reshape + output_tensor = output_tensor.reshape((self.world_size, ) + input_size) + output_tensor = output_tensor.movedim(0, dim) + output_tensor = output_tensor.reshape(input_size[:dim] + + (self.world_size * + input_size[dim], ) + + input_size[dim + 1:]) + return output_tensor + + +class _CPUSHMDistributed: + + def __init__(self, communicator: CpuCommunicator): + instance_identifier = os.environ["VLLM_DIST_IDENT"] + self.communicator = communicator + + group_ranks = [str(rank) for rank in self.communicator.ranks] + shm_group_identifier = f"[{'-'.join(group_ranks)}]" + self.group_name = f"{instance_identifier}-{shm_group_identifier}-cpushm" + + self.handle = self._init_cpu_shm() + + def _init_cpu_shm(self) -> int: + handle = torch.ops._C.init_shm_manager( + self.group_name, + self.communicator.world_size, + self.communicator.rank, + ) + torch.distributed.barrier(self.communicator.device_group) + torch.ops._C.join_shm_manager( + handle, + self.group_name, + ) + torch.distributed.barrier(self.communicator.device_group) + + return handle + + def all_reduce(self, + input: torch.Tensor, + group: Optional[ProcessGroup] = None) -> None: + torch.ops._C.shm_allreduce(self.handle, input) + + def gather(self, + input: torch.Tensor, + gather_list: Optional[List[torch.Tensor]], + dst: int = -1, + group: Optional[ProcessGroup] = None) -> None: + # Note: different from the torch gather, here we use local dst rank. 
+ torch.ops._C.shm_gather(self.handle, input, gather_list, + torch.distributed.get_group_rank(group, dst)) + + def all_gather_into_tensor(self, + output: torch.Tensor, + input: torch.Tensor, + group: Optional[ProcessGroup] = None) -> None: + torch.ops._C.shm_all_gather(self.handle, input, output) diff --git a/vllm/worker/cpu_worker.py b/vllm/worker/cpu_worker.py index 5f35c1af..1436a404 100644 --- a/vllm/worker/cpu_worker.py +++ b/vllm/worker/cpu_worker.py @@ -1,5 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 """A CPU worker class.""" +import os from typing import Dict, List, Optional, Set, Tuple, Type import torch @@ -139,6 +140,8 @@ class CPUWorker(LocalOrDistributedWorkerBase): self.local_rank = local_rank self.rank = rank + vllm_config.parallel_config.rank = rank + self.distributed_init_method = distributed_init_method self.is_driver_worker = is_driver_worker @@ -217,6 +220,10 @@ class CPUWorker(LocalOrDistributedWorkerBase): ret = torch.ops._C_utils.init_cpu_threads_env(self.local_omp_cpuid) if ret: logger.info(ret) + + # Note: unique identifier for creating allreduce shared memory + os.environ["VLLM_DIST_IDENT"] = self.distributed_init_method.split( + ":")[-1] self.device = torch.device("cpu") self.init_distributed_environment() # Set random seed.
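For reference, a minimal sketch of how the SHM ops registered in this patch fit together, mirroring `_CPUSHMDistributed._init_cpu_shm()` and `all_reduce()` above. It assumes the CPU extension is built with `__AVX512F__` and that `torch.distributed` is already initialized across the participating ranks; the helper name `shm_allreduce_example` is hypothetical.

# Sketch only: assumes torch.distributed is initialized and the ops above
# are compiled into torch.ops._C (AVX512 build of the CPU extension).
import torch
import torch.distributed as dist


def shm_allreduce_example(tensor: torch.Tensor, group_name: str) -> torch.Tensor:
    # Each rank creates its own shared-memory segment and per-thread contexts.
    handle = torch.ops._C.init_shm_manager(group_name, dist.get_world_size(),
                                           dist.get_rank())
    dist.barrier()  # all segments must exist before anyone maps them
    # Map the other ranks' segments into this process.
    torch.ops._C.join_shm_manager(handle, group_name)
    dist.barrier()
    # In-place sum across all ranks in the group.
    torch.ops._C.shm_allreduce(handle, tensor)
    return tensor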