[CPU][Bugfix] Using custom allreduce for CPU backend (#15934)

Signed-off-by: jiang1.li <jiang1.li@intel.com>
This commit is contained in:
Li, Jiang 2025-04-02 22:46:47 +08:00 committed by GitHub
parent cefb9e5a28
commit 550b2801ad
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 1013 additions and 16 deletions

View File

@ -197,6 +197,7 @@ set(VLLM_EXT_SRC
if (AVX512_FOUND AND NOT AVX512_DISABLED) if (AVX512_FOUND AND NOT AVX512_DISABLED)
set(VLLM_EXT_SRC set(VLLM_EXT_SRC
"csrc/cpu/quant.cpp" "csrc/cpu/quant.cpp"
"csrc/cpu/shm.cpp"
${VLLM_EXT_SRC}) ${VLLM_EXT_SRC})
endif() endif()

View File

@ -78,9 +78,14 @@ struct FP16Vec16 : public Vec<FP16Vec16> {
__m256i reg; __m256i reg;
// normal load
explicit FP16Vec16(const void* ptr) explicit FP16Vec16(const void* ptr)
: reg((__m256i)_mm256_loadu_si256((__m256i*)ptr)) {} : reg((__m256i)_mm256_loadu_si256((__m256i*)ptr)) {}
// non-temproal load
explicit FP16Vec16(bool, void* ptr)
: reg(_mm256_stream_load_si256((__m256i*)ptr)) {}
explicit FP16Vec16(const FP32Vec16&); explicit FP16Vec16(const FP32Vec16&);
void save(void* ptr) const { *reinterpret_cast<__m256i*>(ptr) = reg; } void save(void* ptr) const { *reinterpret_cast<__m256i*>(ptr) = reg; }
@ -110,9 +115,14 @@ struct BF16Vec16 : public Vec<BF16Vec16> {
__m256i reg; __m256i reg;
// normal load
explicit BF16Vec16(const void* ptr) explicit BF16Vec16(const void* ptr)
: reg((__m256i)_mm256_loadu_si256((__m256i*)ptr)) {} : reg((__m256i)_mm256_loadu_si256((__m256i*)ptr)) {}
// non-temproal load
explicit BF16Vec16(bool, void* ptr)
: reg(_mm256_stream_load_si256((__m256i*)ptr)) {}
explicit BF16Vec16(const FP32Vec16&); explicit BF16Vec16(const FP32Vec16&);
void save(void* ptr) const { *reinterpret_cast<__m256i*>(ptr) = reg; } void save(void* ptr) const { *reinterpret_cast<__m256i*>(ptr) = reg; }
@ -313,8 +323,13 @@ struct FP32Vec16 : public Vec<FP32Vec16> {
explicit FP32Vec16() : reg(_mm512_set1_ps(0.0)) {} explicit FP32Vec16() : reg(_mm512_set1_ps(0.0)) {}
// normal load
explicit FP32Vec16(const float* ptr) : reg(_mm512_loadu_ps(ptr)) {} explicit FP32Vec16(const float* ptr) : reg(_mm512_loadu_ps(ptr)) {}
// non-temproal load
explicit FP32Vec16(bool, void* ptr)
: reg((__m512)_mm512_stream_load_si512(ptr)) {}
explicit FP32Vec16(__m512 data) : reg(data) {} explicit FP32Vec16(__m512 data) : reg(data) {}
explicit FP32Vec16(const FP32Vec4& data) explicit FP32Vec16(const FP32Vec4& data)
@ -547,6 +562,33 @@ struct INT8Vec16 : public Vec<INT8Vec16> {
_mm_mask_storeu_epi8(ptr, mask, reg); _mm_mask_storeu_epi8(ptr, mask, reg);
} }
}; };
struct INT8Vec64 : public Vec<INT8Vec64> {
constexpr static int VEC_ELEM_NUM = 64;
union AliasReg {
__m512i reg;
int8_t values[VEC_ELEM_NUM];
};
__m512i reg;
// normal load
explicit INT8Vec64(void* ptr) : reg(_mm512_loadu_epi8(ptr)) {}
// non-temproal load
explicit INT8Vec64(bool, void* ptr) : reg(_mm512_stream_load_si512(ptr)) {}
void save(void* ptr) const { _mm512_storeu_epi8(ptr, reg); }
void save(int8_t* ptr, const int elem_num) const {
constexpr uint64_t M = 0xFFFFFFFFFFFFFFFF;
__mmask64 mask = _cvtu64_mask64(M >> (64 - elem_num));
_mm512_mask_storeu_epi8(ptr, mask, reg);
}
// non-temproal save
void nt_save(int8_t* ptr) { _mm512_stream_si512((__m512i*)ptr, reg); }
};
#endif #endif
template <typename T> template <typename T>
@ -657,6 +699,22 @@ inline BF16Vec16::BF16Vec16(const FP32Vec16& v) {
inline void prefetch(const void* addr) { _mm_prefetch(addr, _MM_HINT_T1); } inline void prefetch(const void* addr) { _mm_prefetch(addr, _MM_HINT_T1); }
#ifdef __AVX512F__
inline void non_temporal_save(FP16Vec16& vec, void* ptr) {
_mm256_stream_si256((__m256i*)ptr, vec.reg);
}
inline void non_temporal_save(BF16Vec32& vec, void* ptr) {
_mm512_stream_si512((__m512i*)ptr, vec.reg);
}
inline void non_temporal_save(BF16Vec16& vec, void* ptr) {
_mm256_stream_si256((__m256i*)ptr, vec.reg);
}
inline void non_temporal_save(FP32Vec16& vec, void* ptr) {
_mm512_stream_ps((float*)ptr, vec.reg);
}
#endif
inline void mem_barrier() { _mm_mfence(); }
}; // namespace vec_op }; // namespace vec_op
#endif #endif

781
csrc/cpu/shm.cpp Normal file
View File

@ -0,0 +1,781 @@
#include "cpu/cpu_types.hpp"
#include <fcntl.h>
#include <sys/mman.h>
#include <sys/stat.h>
#include <unistd.h>
namespace {
#define MAX_SHM_RANK_NUM 8
#define MAX_THREAD_NUM 12
#define PER_THREAD_SHM_BUFFER_BYTES (4 * 1024 * 1024)
#define MIN_THREAD_PROCESS_SIZE (8 * 1024)
#define MAX_P2P_SEND_TENSOR_NUM 8
template <typename scalar_t>
struct KernelVecType {
using scalar_vec_t = void;
};
template <>
struct KernelVecType<float> {
using scalar_vec_t = vec_op::FP32Vec16;
};
template <>
struct KernelVecType<c10::BFloat16> {
using scalar_vec_t = vec_op::BF16Vec16;
};
template <>
struct KernelVecType<c10::Half> {
using scalar_vec_t = vec_op::FP16Vec16;
};
enum class ThreadSHMStat : char { THREAD_READY = 0, SHM_DATA_READY, DONE };
struct ThreadSHMContext {
volatile ThreadSHMStat thread_stats[MAX_SHM_RANK_NUM];
int thread_id;
int thread_num;
int rank;
int group_size;
size_t _spinning_count;
int swizzled_ranks[MAX_SHM_RANK_NUM];
void* thread_shm_ptrs[MAX_SHM_RANK_NUM];
ThreadSHMContext* shm_contexts[MAX_SHM_RANK_NUM];
ThreadSHMContext(const int thread_id, const int thread_num, const int rank,
const int group_size, void* thread_shm_ptr)
: thread_id(thread_id),
thread_num(thread_num),
rank(rank),
group_size(group_size),
_spinning_count(0) {
static_assert(sizeof(ThreadSHMContext) % 64 == 0);
TORCH_CHECK(group_size <= MAX_SHM_RANK_NUM);
TORCH_CHECK((size_t)this % 64 == 0);
TORCH_CHECK((size_t)thread_shm_ptr % 64 == 0);
for (int i = 0; i < MAX_SHM_RANK_NUM; ++i) {
shm_contexts[i] = nullptr;
thread_shm_ptrs[i] = nullptr;
swizzled_ranks[i] = (i + rank) % group_size;
thread_stats[i] = ThreadSHMStat::DONE;
}
set_context(rank, this, thread_shm_ptr);
}
void set_context(int rank, ThreadSHMContext* ptr, void* thread_shm_ptr) {
TORCH_CHECK(rank < MAX_SHM_RANK_NUM);
TORCH_CHECK(ptr);
TORCH_CHECK(thread_shm_ptr);
TORCH_CHECK_EQ(ptr->thread_num, thread_num);
TORCH_CHECK_EQ(ptr->thread_id, thread_id);
shm_contexts[rank] = ptr;
thread_shm_ptrs[rank] = thread_shm_ptr;
}
template <typename T>
T* get_thread_shm_ptr(int rank) {
return reinterpret_cast<T*>(thread_shm_ptrs[rank]);
}
int get_swizzled_rank(int idx) { return swizzled_ranks[idx]; }
void wait_for_all(ThreadSHMStat prev_stat) {
for (int idx = 0; idx < group_size; ++idx) {
int rank = get_swizzled_rank(idx);
while (thread_stats[rank] == prev_stat) {
++_spinning_count;
_mm_pause();
}
}
vec_op::mem_barrier();
}
void wait_for_one(int rank, ThreadSHMStat prev_stat) {
while (thread_stats[rank] == prev_stat) {
++_spinning_count;
_mm_pause();
}
vec_op::mem_barrier();
}
void set_thread_stat(ThreadSHMStat stat) {
for (int idx = 0; idx < group_size; ++idx) {
int rank = get_swizzled_rank(idx);
shm_contexts[rank]->thread_stats[this->rank] = stat;
}
}
void set_thread_stat(int target_rank, ThreadSHMStat stat) {
for (int idx = 0; idx < group_size; ++idx) {
int rank = get_swizzled_rank(idx);
shm_contexts[rank]->thread_stats[target_rank] = stat;
}
}
// barrier for all ranks in the group, used for all2all ops
// DONE -> THREAD_READY -> SHM_DATA_READY -> DONE -> ...
void barrier(ThreadSHMStat next_stat) {
if (next_stat == ThreadSHMStat::THREAD_READY) {
set_thread_stat(ThreadSHMStat::THREAD_READY);
wait_for_all(ThreadSHMStat::DONE);
} else if (next_stat == ThreadSHMStat::SHM_DATA_READY) {
set_thread_stat(ThreadSHMStat::SHM_DATA_READY);
wait_for_all(ThreadSHMStat::THREAD_READY);
} else if (next_stat == ThreadSHMStat::DONE) {
set_thread_stat(ThreadSHMStat::DONE);
wait_for_all(ThreadSHMStat::SHM_DATA_READY);
} else {
TORCH_CHECK(false, "Invalid next_stat to barrier.");
}
}
std::string to_string() const {
std::stringstream ss;
ss << "SHMContext:";
ss << "\nrank: " << rank;
ss << "\ngroup_size: " << group_size;
ss << "\nthread_num: " << thread_num;
ss << "\nthread_id: " << thread_id;
ss << "\nshm_ctx_stat_loop_seq: [";
for (int i = 0; i < group_size; ++i) {
ss << swizzled_ranks[i] << ", ";
}
ss << "]";
ss << "\nshm_contexts: [";
for (int i = 0; i < group_size; ++i) {
if (shm_contexts[i]) {
ss << shm_contexts[i]->rank << ", ";
}
}
ss << "]";
return ss.str();
}
};
class SHMManager {
public:
explicit SHMManager(const std::string& name, const int rank,
const int group_size)
: _rank(rank),
_group_size(group_size),
_thread_num(std::min(torch::get_num_threads(), MAX_THREAD_NUM)),
_shm_names({""}),
_shared_mem_ptrs({nullptr}),
_shm_ctx(nullptr) {
_shm_names[rank] = get_shm_name(name, rank);
_shared_mem_ptrs[rank] = init_shm(rank);
_shm_ctx = reinterpret_cast<ThreadSHMContext*>(_shared_mem_ptrs[rank]);
for (int i = 0; i < _thread_num; ++i) {
ThreadSHMContext* ctx = new (_shm_ctx + i)
ThreadSHMContext(i, _thread_num, _rank, _group_size,
compute_thread_shm_ptr(_shm_ctx, i));
}
}
void join(const std::string& name) {
for (int rank_idx = 0; rank_idx < _group_size; ++rank_idx) {
if (rank_idx != _rank) {
TORCH_CHECK(_shm_names[rank_idx].empty());
TORCH_CHECK(_shared_mem_ptrs[rank_idx] == nullptr);
_shm_names[rank_idx] = get_shm_name(name, rank_idx);
_shared_mem_ptrs[rank_idx] = init_shm(rank_idx);
ThreadSHMContext* target_ctx =
reinterpret_cast<ThreadSHMContext*>(_shared_mem_ptrs[rank_idx]);
for (int thread_idx = 0; thread_idx < _thread_num; ++thread_idx) {
_shm_ctx[thread_idx].set_context(
rank_idx, target_ctx + thread_idx,
compute_thread_shm_ptr(target_ctx, thread_idx));
}
}
}
}
~SHMManager() { destroy_shm(); }
ThreadSHMContext* get_shm_ctx() const { return _shm_ctx; }
static std::string get_shm_name(const std::string& name, int rank) {
return name + "_" + std::to_string(rank);
}
static int64_t create_singleton_instance(const std::string& name,
const int group_size,
const int rank) {
std::lock_guard<std::mutex> guard(SingletonInstancesLock);
SingletonInstances.emplace_back(
std::make_unique<SHMManager>(name, rank, group_size));
return static_cast<int64_t>(SingletonInstances.size() - 1);
}
static SHMManager* get_singleton_instance(int64_t handle) {
return SingletonInstances[handle].get();
}
protected:
static std::vector<std::unique_ptr<SHMManager>> SingletonInstances;
static std::mutex SingletonInstancesLock;
private:
static size_t round_to_alignment(size_t num) {
return ((num + 63) / 64) * 64;
}
int8_t* compute_thread_shm_ptr(ThreadSHMContext* ctx, int thread_id) {
int8_t* thread_shm_ptr =
reinterpret_cast<int8_t*>(ctx) +
round_to_alignment(_thread_num * sizeof(ThreadSHMContext));
return thread_shm_ptr +
thread_id * round_to_alignment(PER_THREAD_SHM_BUFFER_BYTES);
}
size_t compute_shm_size() {
const size_t rounded_rank_buffer_size =
round_to_alignment(PER_THREAD_SHM_BUFFER_BYTES) * _thread_num;
const size_t rounded_thread_shm_ctx_size =
round_to_alignment(_thread_num * sizeof(ThreadSHMContext));
const size_t shm_size =
rounded_thread_shm_ctx_size + rounded_rank_buffer_size;
return shm_size;
}
void* init_shm(int target_rank) {
const std::string& shm_name = _shm_names[target_rank];
const int local_rank = _rank;
const size_t shm_size = compute_shm_size();
int fd = -1;
if (local_rank == target_rank) {
fd = shm_open(shm_name.c_str(), O_CREAT | O_EXCL | O_RDWR,
S_IRUSR | S_IWUSR);
if (fd == -1)
TORCH_CHECK(false, "create shm in SHMManager failed. errno: " +
std::to_string(errno));
if (ftruncate(fd, shm_size) == -1)
TORCH_CHECK(false, "ftruncate in SHMManager failed. errno: " +
std::to_string(errno));
} else {
fd = shm_open(shm_name.c_str(), O_RDWR, S_IRUSR | S_IWUSR);
if (fd == -1)
TORCH_CHECK(false, "open shm in SHMManager failed. errno: " +
std::to_string(errno));
}
void* shm_ptr = mmap(nullptr, shm_size, PROT_READ | PROT_WRITE,
MAP_SHARED | MAP_POPULATE, fd, 0);
if (shm_ptr == MAP_FAILED) {
TORCH_CHECK(false,
"mmap in SHMManager failed. errno: " + std::to_string(errno));
}
if (close(fd) != 0) {
TORCH_CHECK(
false, "close in SHMManager failed. errno: " + std::to_string(errno));
}
TORCH_CHECK((size_t)shm_ptr % 64 == 0);
return shm_ptr;
}
void destroy_shm() {
std::stringstream ss;
ss << "local rank " << _rank << ": [";
for (int thread_id = 0; thread_id < _thread_num; ++thread_id) {
ss << _shm_ctx[thread_id]._spinning_count << ", ";
}
ss << "]\n";
for (int i = 0; i < MAX_SHM_RANK_NUM; ++i) {
if (_shared_mem_ptrs[i] != nullptr) {
munmap(_shared_mem_ptrs[i], compute_shm_size());
}
if (!_shm_names[i].empty()) {
shm_unlink(_shm_names[i].c_str());
}
}
}
int _rank;
int _group_size;
int _thread_num;
std::array<std::string, MAX_SHM_RANK_NUM> _shm_names;
std::array<void*, MAX_SHM_RANK_NUM> _shared_mem_ptrs;
ThreadSHMContext* _shm_ctx;
};
namespace shm_cc_ops {
template <typename scalar_t, typename F>
void shm_cc_loop(ThreadSHMContext* ctx, int64_t elem_num, F&& inner_func) {
int thread_num = ctx->thread_num;
int64_t total_bytes = elem_num * sizeof(scalar_t);
int64_t total_units_num =
(total_bytes + MIN_THREAD_PROCESS_SIZE - 1) / MIN_THREAD_PROCESS_SIZE;
int64_t per_thread_units_num =
(total_units_num + thread_num - 1) / thread_num;
int64_t per_unit_elem_num = MIN_THREAD_PROCESS_SIZE / sizeof(scalar_t);
int64_t max_per_thread_iteration_elem_num =
PER_THREAD_SHM_BUFFER_BYTES / sizeof(scalar_t);
int64_t per_thread_elem_num = per_unit_elem_num * per_thread_units_num;
#pragma omp parallel for schedule(static, 1)
for (int i = 0; i < thread_num; ++i) {
int64_t offset = i * per_thread_elem_num;
int64_t end = std::min(elem_num, offset + per_thread_elem_num);
int64_t curr_elem_num =
std::min(max_per_thread_iteration_elem_num, end - offset);
ThreadSHMContext* thread_ctx = ctx + i;
while (curr_elem_num > 0) {
inner_func(thread_ctx, offset, curr_elem_num);
offset += max_per_thread_iteration_elem_num;
curr_elem_num = std::min(max_per_thread_iteration_elem_num, end - offset);
}
}
}
}; // namespace shm_cc_ops
namespace shm_cc_ops {
void memcpy_from_shm(void* dst, void* src, const int64_t bytes) {
const int64_t aligned_bytes = ((bytes >> 6) << 6); // 64 bytes aligned
int64_t i = 0;
#pragma GCC unroll 4
for (; i < aligned_bytes; i += 64) {
vec_op::INT8Vec64 data(
true, (int8_t*)src + i); // stream loading shm to avoid caching
data.save((int8_t*)dst + i);
}
if (aligned_bytes < bytes) {
vec_op::INT8Vec64 data(true, (int8_t*)src + aligned_bytes);
data.save((int8_t*)dst + aligned_bytes, bytes - aligned_bytes);
}
}
void memcpy_to_shm(void* dst, void* src, const int64_t bytes) {
#pragma GCC unroll 4
for (int64_t i = 0; i < bytes; i += 64) {
vec_op::INT8Vec64 data((int8_t*)src + i);
data.nt_save((int8_t*)dst + i);
}
}
void memcpy(void* dst, void* src, const int64_t bytes) {
const int64_t aligned_bytes = ((bytes >> 6) << 6); // 64 bytes aligned
int64_t i = 0;
#pragma GCC unroll 4
for (; i < aligned_bytes; i += 64) {
vec_op::INT8Vec64 data((int8_t*)src + i);
data.save((int8_t*)dst + i);
}
if (aligned_bytes < bytes) {
vec_op::INT8Vec64 data((int8_t*)src + aligned_bytes);
data.save((int8_t*)dst + aligned_bytes, bytes - aligned_bytes);
}
}
template <typename scalar_t, int RANKS>
void all_reduce_sum_impl(ThreadSHMContext* ctx, scalar_t* data,
size_t elem_num) {
CPU_KERNEL_GUARD_IN(all_reduce_sum_impl)
using vec_t = typename KernelVecType<scalar_t>::scalar_vec_t;
constexpr int64_t vec_elem_num = vec_t::get_elem_num();
const int worldsize = ctx->group_size;
shm_cc_ops::shm_cc_loop<scalar_t>(
ctx, elem_num,
[&](ThreadSHMContext* thread_ctx, int64_t data_offset,
int64_t data_elem_num) {
int rank = thread_ctx->rank;
scalar_t* thread_shm_ptr =
thread_ctx->get_thread_shm_ptr<scalar_t>(rank);
scalar_t* thread_data_ptr = data + data_offset;
int64_t thread_data_elem_num = data_elem_num * sizeof(scalar_t);
scalar_t* remote_data_ptrs[RANKS - 1];
vec_op::unroll_loop<int, RANKS - 1>([&](int idx) {
remote_data_ptrs[idx] = thread_ctx->get_thread_shm_ptr<scalar_t>(
thread_ctx->get_swizzled_rank(idx + 1));
});
thread_ctx->barrier(ThreadSHMStat::THREAD_READY);
shm_cc_ops::memcpy_to_shm(thread_shm_ptr, thread_data_ptr,
thread_data_elem_num);
thread_ctx->barrier(ThreadSHMStat::SHM_DATA_READY);
int64_t aligned_data_elem_num =
(data_elem_num / vec_elem_num) * vec_elem_num;
int64_t i = 0;
#pragma GCC unroll 4
for (; i < aligned_data_elem_num; i += vec_elem_num) {
vec_t local_data(thread_data_ptr + i); // load from cache
vec_op::FP32Vec16 local_data_fp32(local_data);
vec_op::unroll_loop<int, RANKS - 1>([&](int idx) {
vec_t remote_data(
true, remote_data_ptrs[idx] + i); // stream load from shm
vec_op::FP32Vec16 remote_data_fp32(remote_data);
local_data_fp32 = local_data_fp32 + remote_data_fp32; // sum reduce
});
vec_t reduced_data(local_data_fp32);
reduced_data.save(thread_data_ptr + i);
}
if (i < data_elem_num) {
vec_t local_data(thread_data_ptr + i); // load from cache
vec_op::FP32Vec16 local_data_fp32(local_data);
vec_op::unroll_loop<int, RANKS - 1>([&](int idx) {
vec_t remote_data(
true, remote_data_ptrs[idx] + i); // stream load from shm
vec_op::FP32Vec16 remote_data_fp32(remote_data);
local_data_fp32 = local_data_fp32 + remote_data_fp32; // sum reduce
});
vec_t reduced_data(local_data_fp32);
reduced_data.save(thread_data_ptr + i,
data_elem_num - aligned_data_elem_num);
}
thread_ctx->barrier(ThreadSHMStat::DONE);
});
return;
}
}; // namespace shm_cc_ops
std::vector<std::unique_ptr<SHMManager>> SHMManager::SingletonInstances = {};
std::mutex SHMManager::SingletonInstancesLock = {};
template <typename scalar_t>
void shm_allreduce_sum(ThreadSHMContext* ctx, scalar_t* data, size_t elem_num) {
switch (ctx->group_size) {
case 2:
shm_cc_ops::all_reduce_sum_impl<scalar_t, 2>(ctx, data, elem_num);
break;
case 3:
shm_cc_ops::all_reduce_sum_impl<scalar_t, 3>(ctx, data, elem_num);
break;
case 4:
shm_cc_ops::all_reduce_sum_impl<scalar_t, 4>(ctx, data, elem_num);
break;
case 8:
shm_cc_ops::all_reduce_sum_impl<scalar_t, 8>(ctx, data, elem_num);
break;
default:
TORCH_CHECK(false,
"Invalid world size: " + std::to_string(ctx->group_size));
}
}
template <typename scalar_t>
void shm_gather_impl(ThreadSHMContext* ctx, scalar_t* data, size_t elem_num,
scalar_t** outputs, const int dst) {
CPU_KERNEL_GUARD_IN(shm_gather_impl)
const int worldsize = ctx->group_size;
TORCH_CHECK_LT(dst, worldsize);
shm_cc_ops::shm_cc_loop<scalar_t>(
ctx, elem_num,
[&](ThreadSHMContext* thread_ctx, int64_t data_offset,
int64_t data_elem_num) {
int rank = thread_ctx->rank;
scalar_t* thread_shm_ptr =
thread_ctx->get_thread_shm_ptr<scalar_t>(rank);
thread_ctx->barrier(ThreadSHMStat::THREAD_READY);
shm_cc_ops::memcpy_to_shm(thread_shm_ptr, data + data_offset,
data_elem_num * sizeof(scalar_t));
thread_ctx->barrier(ThreadSHMStat::SHM_DATA_READY);
if (rank == dst) {
shm_cc_ops::memcpy(outputs[rank] + data_offset, data + data_offset,
data_elem_num * sizeof(scalar_t));
for (int i = 1; i < worldsize; ++i) {
int src_rank = thread_ctx->get_swizzled_rank(i);
scalar_t* src_ptr =
thread_ctx->get_thread_shm_ptr<scalar_t>(src_rank); // shm
scalar_t* dst_ptr = outputs[src_rank] + data_offset;
shm_cc_ops::memcpy_from_shm(dst_ptr, src_ptr,
data_elem_num * sizeof(scalar_t));
}
}
thread_ctx->barrier(ThreadSHMStat::DONE);
});
return;
}
struct MemPiece {
void* ptr;
int64_t size;
template <typename T>
T* data_ptr() {
return reinterpret_cast<T*>(ptr);
}
};
struct TensorListMeta {
int64_t tensor_bytes[MAX_P2P_SEND_TENSOR_NUM];
torch::ScalarType tensor_types[MAX_P2P_SEND_TENSOR_NUM];
int64_t tensor_num;
int64_t total_bytes;
TensorListMeta() : tensor_num(0), total_bytes(0) {
static_assert(sizeof(TensorListMeta) % 64 == 0);
static_assert(sizeof(TensorListMeta) <
MIN_THREAD_PROCESS_SIZE); // To ensure the metadata always
// hold by the thread 0
for (int i = 0; i < MAX_P2P_SEND_TENSOR_NUM; ++i) {
tensor_bytes[i] = 0;
tensor_ptrs[i] = nullptr;
tensor_types[i] = torch::ScalarType::Undefined;
}
}
// For send and recv
void bind_tensor_list(std::vector<torch::Tensor>& tensor_list) {
TORCH_CHECK(tensor_types[0] == torch::ScalarType::Undefined,
"Re-bind TensorListMeta is not allowed.")
TORCH_CHECK_LE(tensor_list.size(), MAX_P2P_SEND_TENSOR_NUM);
tensor_num = tensor_list.size();
int64_t bytes_sum = 0;
for (int i = 0; i < tensor_list.size(); ++i) {
torch::Tensor& t = tensor_list[i];
TORCH_CHECK(t.is_contiguous());
tensor_bytes[i] = t.nbytes();
tensor_types[i] = t.scalar_type();
tensor_ptrs[i] = t.data_ptr();
bytes_sum += t.nbytes();
}
total_bytes = bytes_sum;
}
// For recv
std::vector<torch::Tensor> generate_tensor_list() {
std::vector<torch::Tensor> tensor_list;
tensor_list.reserve(tensor_num);
for (int i = 0; i < tensor_num; ++i) {
int64_t bytes = tensor_bytes[i];
auto type = tensor_types[i];
int64_t elem_bytes = torch::elementSize(type);
TORCH_CHECK_EQ(bytes % elem_bytes, 0);
int64_t elem_num = bytes / elem_bytes;
auto options = torch::TensorOptions().dtype(type).device(torch::kCPU);
tensor_list.emplace_back(torch::empty({elem_num}, options));
}
return tensor_list;
}
MemPiece get_data(int64_t offset) {
for (int i = 0; i < tensor_num; ++i) {
if (offset < tensor_bytes[i]) {
return {reinterpret_cast<int8_t*>(tensor_ptrs[i]) + offset,
tensor_bytes[i] - offset};
}
offset -= tensor_bytes[i];
}
return {nullptr, 0};
}
private:
void* tensor_ptrs[MAX_P2P_SEND_TENSOR_NUM];
int8_t _padding[40];
};
void shm_send_tensor_list_impl(ThreadSHMContext* ctx,
const std::vector<torch::Tensor>& tensor_list) {
CPU_KERNEL_GUARD_IN(shm_send_tensor_list_impl)
std::vector<torch::Tensor> tensor_list_with_metadata;
tensor_list_with_metadata.reserve(1 + tensor_list.size());
auto options = torch::TensorOptions().dtype(torch::kInt8).device(torch::kCPU);
tensor_list_with_metadata.emplace_back(
torch::empty({sizeof(TensorListMeta)}, options));
tensor_list_with_metadata.insert(tensor_list_with_metadata.end(),
tensor_list.begin(), tensor_list.end());
torch::Tensor& metadata_tensor = tensor_list_with_metadata[0];
TORCH_CHECK_EQ(metadata_tensor.nbytes(), sizeof(TensorListMeta));
TensorListMeta* metadata = new (metadata_tensor.data_ptr()) TensorListMeta();
metadata->bind_tensor_list(tensor_list_with_metadata);
shm_cc_ops::shm_cc_loop<int8_t>(
ctx, metadata->total_bytes,
[&](ThreadSHMContext* thread_ctx, int64_t data_offset,
int64_t data_elem_num) {
int rank = thread_ctx->rank;
// Wait until the receiver set the stat to DONE
thread_ctx->wait_for_one(rank, ThreadSHMStat::SHM_DATA_READY);
int64_t curr_shm_offset = 0;
while (curr_shm_offset < data_elem_num) {
MemPiece frag = metadata->get_data(data_offset + curr_shm_offset);
frag.size = std::min(frag.size, data_elem_num - curr_shm_offset);
shm_cc_ops::memcpy(
thread_ctx->get_thread_shm_ptr<int8_t>(rank) + curr_shm_offset,
frag.ptr, frag.size);
curr_shm_offset += frag.size;
}
thread_ctx->set_thread_stat(rank, ThreadSHMStat::SHM_DATA_READY);
});
}
std::vector<torch::Tensor> shm_recv_tensor_list_impl(ThreadSHMContext* ctx,
int64_t src) {
CPU_KERNEL_GUARD_IN(shm_recv_tensor_list_impl)
auto options = torch::TensorOptions().dtype(torch::kInt8).device(torch::kCPU);
torch::Tensor metadata_tensor =
torch::empty({sizeof(TensorListMeta)}, options);
// Wait until the sender set the stat of the thread 0 to SHM_DATA_READY
ctx->wait_for_one(src, ThreadSHMStat::DONE);
shm_cc_ops::memcpy(metadata_tensor.data_ptr(),
ctx->get_thread_shm_ptr<void>(src),
sizeof(TensorListMeta));
TensorListMeta* src_metadata =
reinterpret_cast<TensorListMeta*>(metadata_tensor.data_ptr());
std::vector<torch::Tensor> tensor_list_with_metadata =
src_metadata->generate_tensor_list();
TensorListMeta metadata;
metadata.bind_tensor_list(tensor_list_with_metadata);
TORCH_CHECK_EQ(metadata.tensor_num, src_metadata->tensor_num);
TORCH_CHECK_EQ(metadata.total_bytes, src_metadata->total_bytes);
shm_cc_ops::shm_cc_loop<int8_t>(
ctx, metadata.total_bytes,
[&](ThreadSHMContext* thread_ctx, int64_t data_offset,
int64_t data_elem_num) {
// Wait until the sender set the stat to SHM_DATA_READY
thread_ctx->wait_for_one(src, ThreadSHMStat::DONE);
int64_t curr_shm_offset = 0;
while (curr_shm_offset < data_elem_num) {
MemPiece frag = metadata.get_data(data_offset + curr_shm_offset);
frag.size = std::min(frag.size, data_elem_num - curr_shm_offset);
shm_cc_ops::memcpy(
frag.ptr,
thread_ctx->get_thread_shm_ptr<int8_t>(src) + curr_shm_offset,
frag.size);
curr_shm_offset += frag.size;
}
thread_ctx->set_thread_stat(src, ThreadSHMStat::DONE);
});
std::vector<torch::Tensor> tensor_list;
tensor_list.reserve(metadata.tensor_num - 1);
tensor_list.insert(tensor_list.begin(), tensor_list_with_metadata.begin() + 1,
tensor_list_with_metadata.end());
return tensor_list;
}
} // namespace
void shm_gather(int64_t handle, torch::Tensor& data,
const std::optional<std::vector<torch::Tensor>>& outputs,
int64_t dst) {
TORCH_CHECK(data.is_contiguous())
VLLM_DISPATCH_FLOATING_TYPES(data.scalar_type(), "shm_gather_impl", [&] {
CPU_KERNEL_GUARD_IN(shm_gather_impl)
if (outputs.has_value()) {
TORCH_CHECK_LE(outputs->size(), MAX_SHM_RANK_NUM);
scalar_t* output_ptrs[MAX_SHM_RANK_NUM] = {nullptr};
for (int i = 0; i < outputs->size(); ++i) {
output_ptrs[i] = outputs->at(i).data_ptr<scalar_t>();
}
shm_gather_impl(SHMManager::get_singleton_instance(handle)->get_shm_ctx(),
data.data_ptr<scalar_t>(), data.numel(), output_ptrs,
dst);
} else {
shm_gather_impl(SHMManager::get_singleton_instance(handle)->get_shm_ctx(),
data.data_ptr<scalar_t>(), data.numel(), (scalar_t**)(0),
dst);
}
CPU_KERNEL_GUARD_OUT(shm_gather_impl)
});
}
void shm_all_gather(int64_t handle, const torch::Tensor& data,
torch::Tensor& output) {
TORCH_CHECK(data.is_contiguous())
TORCH_CHECK(output.is_contiguous())
const int64_t input_elem_num = data.numel();
const int64_t output_elem_num = output.numel();
TORCH_CHECK_EQ(output_elem_num % input_elem_num, 0);
const int world_size = output_elem_num / input_elem_num;
VLLM_DISPATCH_FLOATING_TYPES(data.scalar_type(), "shm_all_gather_impl", [&] {
CPU_KERNEL_GUARD_IN(shm_all_gather_impl)
auto ctx = SHMManager::get_singleton_instance(handle)->get_shm_ctx();
TORCH_CHECK_EQ(ctx->group_size, world_size);
scalar_t* output_ptrs[MAX_SHM_RANK_NUM] = {nullptr};
for (int i = 0; i < world_size; ++i) {
output_ptrs[i] = output.data_ptr<scalar_t>() + i * input_elem_num;
}
shm_gather_impl(ctx, data.data_ptr<scalar_t>(), data.numel(), output_ptrs,
ctx->rank);
CPU_KERNEL_GUARD_OUT(shm_all_gather_impl)
});
}
void shm_allreduce(int64_t handle, torch::Tensor& data) {
TORCH_CHECK(data.is_contiguous())
VLLM_DISPATCH_FLOATING_TYPES(data.scalar_type(), "shm_allreduce_sum", [&] {
CPU_KERNEL_GUARD_IN(shm_allreduce_sum)
shm_allreduce_sum(SHMManager::get_singleton_instance(handle)->get_shm_ctx(),
data.data_ptr<scalar_t>(), data.numel());
CPU_KERNEL_GUARD_OUT(shm_allreduce_sum)
});
}
void shm_send_tensor_list(int64_t handle,
const std::vector<torch::Tensor>& tensor_list,
int64_t dst) {
CPU_KERNEL_GUARD_IN(shm_send_tensor_list)
shm_send_tensor_list_impl(
SHMManager::get_singleton_instance(handle)->get_shm_ctx(), tensor_list);
CPU_KERNEL_GUARD_OUT(shm_send_tensor_list)
}
std::vector<torch::Tensor> shm_recv_tensor_list(int64_t handle, int64_t src) {
CPU_KERNEL_GUARD_IN(shm_recv_tensor_list)
auto tensor_list = shm_recv_tensor_list_impl(
SHMManager::get_singleton_instance(handle)->get_shm_ctx(), src);
CPU_KERNEL_GUARD_OUT(shm_recv_tensor_list)
return tensor_list;
}
int64_t init_shm_manager(const std::string& name, const int64_t group_size,
const int64_t rank) {
return SHMManager::create_singleton_instance(name, group_size, rank);
}
std::string join_shm_manager(int64_t handle, const std::string& name) {
auto shm_manager = SHMManager::get_singleton_instance(handle);
TORCH_CHECK(shm_manager);
shm_manager->join(name);
return shm_manager->get_shm_ctx()->to_string();
}

View File

@ -22,6 +22,26 @@ void mla_decode_kvcache(torch::Tensor& out, torch::Tensor& query,
torch::Tensor& kv_cache, double scale, torch::Tensor& kv_cache, double scale,
torch::Tensor& block_tables, torch::Tensor& seq_lens); torch::Tensor& block_tables, torch::Tensor& seq_lens);
int64_t init_shm_manager(const std::string& name, const int64_t group_size,
const int64_t rank);
std::string join_shm_manager(int64_t handle, const std::string& name);
void shm_allreduce(int64_t handle, torch::Tensor& data);
void shm_gather(int64_t handle, torch::Tensor& data,
const std::optional<std::vector<torch::Tensor>>& outputs,
int64_t dst);
void shm_all_gather(int64_t handle, const torch::Tensor& data,
torch::Tensor& output);
void shm_send_tensor_list(int64_t handle,
const std::vector<torch::Tensor>& tensor_list,
int64_t dst);
std::vector<torch::Tensor> shm_recv_tensor_list(int64_t handle, int64_t src);
TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
// vLLM custom ops // vLLM custom ops
@ -131,6 +151,29 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
" Tensor? azp, Tensor? bias) -> ()"); " Tensor? azp, Tensor? bias) -> ()");
ops.impl("cutlass_scaled_mm_azp", torch::kCPU, &int8_scaled_mm_azp); ops.impl("cutlass_scaled_mm_azp", torch::kCPU, &int8_scaled_mm_azp);
#endif #endif
// SHM CCL
#ifdef __AVX512F__
ops.def("init_shm_manager(str name, int group_size, int rank) -> int",
&init_shm_manager);
ops.def("join_shm_manager(int handle, str name) -> str", &join_shm_manager);
ops.def("shm_allreduce(int handle, Tensor! data) -> ()");
ops.impl("shm_allreduce", torch::kCPU, &shm_allreduce);
ops.def(
"shm_gather(int handle, Tensor data, Tensor[](a!)? outputs, int dst) -> "
"()");
ops.impl("shm_gather", torch::kCPU, &shm_gather);
ops.def(
"shm_all_gather(int handle, Tensor data, Tensor! output) -> "
"()");
ops.impl("shm_all_gather", torch::kCPU, &shm_all_gather);
ops.def(
"shm_send_tensor_list(int handle, Tensor[](a) tensor_list, int dst) -> "
"()");
ops.impl("shm_send_tensor_list", torch::kCPU, &shm_send_tensor_list);
ops.def("shm_recv_tensor_list(int handle, int src) -> Tensor[](a)",
&shm_recv_tensor_list);
#endif
} }
TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) { TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) {

View File

@ -18,7 +18,7 @@ std::string init_cpu_threads_env(const std::string& cpu_ids) {
#ifndef VLLM_NUMA_DISABLED #ifndef VLLM_NUMA_DISABLED
std::string init_cpu_threads_env(const std::string& cpu_ids) { std::string init_cpu_threads_env(const std::string& cpu_ids) {
bitmask* omp_cpu_mask = numa_parse_cpustring(cpu_ids.c_str()); bitmask* omp_cpu_mask = numa_parse_cpustring_all(cpu_ids.c_str());
TORCH_CHECK(omp_cpu_mask->size > 0); TORCH_CHECK(omp_cpu_mask->size > 0);
std::vector<int> omp_cpu_ids; std::vector<int> omp_cpu_ids;
omp_cpu_ids.reserve(omp_cpu_mask->size); omp_cpu_ids.reserve(omp_cpu_mask->size);

View File

@ -272,12 +272,14 @@ $ python examples/offline_inference/basic/basic.py
- Decouple the HTTP serving components from the inference components. In a GPU backend configuration, the HTTP serving and tokenization tasks operate on the CPU, while inference runs on the GPU, which typically does not pose a problem. However, in a CPU-based setup, the HTTP serving and tokenization can cause significant context switching and reduced cache efficiency. Therefore, it is strongly recommended to segregate these two components for improved performance. - Decouple the HTTP serving components from the inference components. In a GPU backend configuration, the HTTP serving and tokenization tasks operate on the CPU, while inference runs on the GPU, which typically does not pose a problem. However, in a CPU-based setup, the HTTP serving and tokenization can cause significant context switching and reduced cache efficiency. Therefore, it is strongly recommended to segregate these two components for improved performance.
- On CPU based setup with NUMA enabled, the memory access performance may be largely impacted by the [topology](https://github.com/intel/intel-extension-for-pytorch/blob/main/docs/tutorials/performance_tuning/tuning_guide.inc.md#non-uniform-memory-access-numa). For NUMA architecture, two optimizations are to recommended: Tensor Parallel or Data Parallel. - On CPU based setup with NUMA enabled, the memory access performance may be largely impacted by the [topology](https://github.com/intel/intel-extension-for-pytorch/blob/main/docs/tutorials/performance_tuning/tuning_guide.inc.md#non-uniform-memory-access-numa). For NUMA architecture, Tensor Parallel is a option for better performance.
- Using Tensor Parallel for a latency constraints deployment: following GPU backend design, a Megatron-LM's parallel algorithm will be used to shard the model, based on the number of NUMA nodes (e.g. TP = 2 for a two NUMA node system). With [TP feature on CPU](gh-pr:6125) merged, Tensor Parallel is supported for serving and offline inferencing. In general each NUMA node is treated as one GPU card. Below is the example script to enable Tensor Parallel = 2 for serving: - Tensor Parallel is supported for serving and offline inferencing. In general each NUMA node is treated as one GPU card. Below is the example script to enable Tensor Parallel = 2 for serving:
```console ```console
VLLM_CPU_KVCACHE_SPACE=40 VLLM_CPU_OMP_THREADS_BIND="0-31|32-63" vllm serve meta-llama/Llama-2-7b-chat-hf -tp=2 --distributed-executor-backend mp VLLM_CPU_KVCACHE_SPACE=40 VLLM_CPU_OMP_THREADS_BIND="0-31|32-63" vllm serve meta-llama/Llama-2-7b-chat-hf -tp=2 --distributed-executor-backend mp
``` ```
- Using Data Parallel for maximum throughput: to launch an LLM serving endpoint on each NUMA node along with one additional load balancer to dispatch the requests to those endpoints. Common solutions like [Nginx](#nginxloadbalancer) or HAProxy are recommended. Anyscale Ray project provides the feature on LLM [serving](https://docs.ray.io/en/latest/serve/index.html). Here is the example to setup a scalable LLM serving with [Ray Serve](https://github.com/intel/llm-on-ray/blob/main/docs/setup.inc.md). - For each thread id list in `VLLM_CPU_OMP_THREADS_BIND`, users should guarantee threads in the list belong to a same NUMA node.
- Meanwhile, users should also take care of memory capacity of each NUMA node. The memory usage of each TP rank is the sum of `weight shard size` and `VLLM_CPU_KVCACHE_SPACE`, if it exceeds the capacity of a single NUMA node, TP worker will be killed due to out-of-memory.

View File

@ -1,10 +1,14 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
from typing import Optional import os
from typing import List, Optional
import torch import torch
from torch.distributed import ProcessGroup from torch.distributed import ProcessGroup
from vllm.platforms import current_platform
from vllm.platforms.interface import CpuArchEnum
from .base_device_communicator import DeviceCommunicatorBase from .base_device_communicator import DeviceCommunicatorBase
@ -16,19 +20,120 @@ class CpuCommunicator(DeviceCommunicatorBase):
device_group: Optional[ProcessGroup] = None, device_group: Optional[ProcessGroup] = None,
unique_name: str = ""): unique_name: str = ""):
super().__init__(cpu_group, device, device_group, unique_name) super().__init__(cpu_group, device, device_group, unique_name)
self.ipex_available = False
self.dist_module = torch.distributed self.dist_module = torch.distributed
try:
import intel_extension_for_pytorch as ipex if current_platform.get_cpu_architecture() == CpuArchEnum.X86:
self.ipex_available = True self.dist_module = _CPUSHMDistributed(self)
self.dist_module = ipex.distributed
except ImportError:
"""
Intel IPEX not found. Falling back to PyTorch native
all_reduce for CPU (e.g. MacOS)
"""
pass
def all_reduce(self, input_): def all_reduce(self, input_):
self.dist_module.all_reduce(input_, group=self.device_group) self.dist_module.all_reduce(input_, group=self.device_group)
return input_ return input_
def gather(self,
input_: torch.Tensor,
dst: int = 0,
dim: int = -1) -> Optional[torch.Tensor]:
"""
NOTE: We assume that the input tensor is on the same device across
all the ranks.
NOTE: `dst` is the local rank of the destination rank.
"""
world_size = self.world_size
assert -input_.dim() <= dim < input_.dim(), (
f"Invalid dim ({dim}) for input tensor with shape {input_.size()}")
if dim < 0:
# Convert negative dim to positive.
dim += input_.dim()
# Allocate output tensor.
if self.rank_in_group == dst:
gather_list = [torch.empty_like(input_) for _ in range(world_size)]
else:
gather_list = None
# Gather.
self.dist_module.gather(input_,
gather_list,
dst=self.ranks[dst],
group=self.device_group)
if self.rank_in_group == dst:
output_tensor = torch.cat(gather_list, dim=dim)
else:
output_tensor = None
return output_tensor
def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor:
if dim < 0:
# Convert negative dim to positive.
dim += input_.dim()
input_size = input_.size()
# NOTE: we have to use concat-style all-gather here,
# stack-style all-gather has compatibility issues with
# torch.compile . see https://github.com/pytorch/pytorch/issues/138795
output_size = (input_size[0] * self.world_size, ) + input_size[1:]
# Allocate output tensor.
output_tensor = torch.empty(output_size,
dtype=input_.dtype,
device=input_.device)
# All-gather.
self.dist_module.all_gather_into_tensor(output_tensor,
input_,
group=self.device_group)
# Reshape
output_tensor = output_tensor.reshape((self.world_size, ) + input_size)
output_tensor = output_tensor.movedim(0, dim)
output_tensor = output_tensor.reshape(input_size[:dim] +
(self.world_size *
input_size[dim], ) +
input_size[dim + 1:])
return output_tensor
class _CPUSHMDistributed:
def __init__(self, communicator: CpuCommunicator):
instance_identifier = os.environ["VLLM_DIST_IDENT"]
self.communicator = communicator
group_ranks = [str(rank) for rank in self.communicator.ranks]
shm_group_identifier = f"[{'-'.join(group_ranks)}]"
self.group_name = f"{instance_identifier}-{shm_group_identifier}-cpushm"
self.handle = self._init_cpu_shm()
def _init_cpu_shm(self) -> int:
handle = torch.ops._C.init_shm_manager(
self.group_name,
self.communicator.world_size,
self.communicator.rank,
)
torch.distributed.barrier(self.communicator.device_group)
torch.ops._C.join_shm_manager(
handle,
self.group_name,
)
torch.distributed.barrier(self.communicator.device_group)
return handle
def all_reduce(self,
input: torch.Tensor,
group: Optional[ProcessGroup] = None) -> None:
torch.ops._C.shm_allreduce(self.handle, input)
def gather(self,
input: torch.Tensor,
gather_list: Optional[List[torch.Tensor]],
dst: int = -1,
group: Optional[ProcessGroup] = None) -> None:
# Note: different from the torch gather, here we use local dst rank.
torch.ops._C.shm_gather(self.handle, input, gather_list,
torch.distributed.get_group_rank(group, dst))
def all_gather_into_tensor(self,
output: torch.Tensor,
input: torch.Tensor,
group: Optional[ProcessGroup] = None) -> None:
torch.ops._C.shm_all_gather(self.handle, input, output)

View File

@ -1,5 +1,6 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
"""A CPU worker class.""" """A CPU worker class."""
import os
from typing import Dict, List, Optional, Set, Tuple, Type from typing import Dict, List, Optional, Set, Tuple, Type
import torch import torch
@ -139,6 +140,8 @@ class CPUWorker(LocalOrDistributedWorkerBase):
self.local_rank = local_rank self.local_rank = local_rank
self.rank = rank self.rank = rank
vllm_config.parallel_config.rank = rank
self.distributed_init_method = distributed_init_method self.distributed_init_method = distributed_init_method
self.is_driver_worker = is_driver_worker self.is_driver_worker = is_driver_worker
@ -217,6 +220,10 @@ class CPUWorker(LocalOrDistributedWorkerBase):
ret = torch.ops._C_utils.init_cpu_threads_env(self.local_omp_cpuid) ret = torch.ops._C_utils.init_cpu_threads_env(self.local_omp_cpuid)
if ret: if ret:
logger.info(ret) logger.info(ret)
# Note: unique identifier for creating allreduce shared memory
os.environ["VLLM_DIST_IDENT"] = self.distributed_init_method.split(
":")[-1]
self.device = torch.device("cpu") self.device = torch.device("cpu")
self.init_distributed_environment() self.init_distributed_environment()
# Set random seed. # Set random seed.