[Hardware][Intel] Support CPU inference with AVX2 ISA (#5452)
This commit is contained in:
parent
50eed24d25
commit
cd9c0d65d9
@ -33,6 +33,7 @@ function (find_isa CPUINFO TARGET OUT)
|
||||
endif()
|
||||
endfunction()
|
||||
|
||||
find_isa(${CPUINFO} "avx2" AVX2_FOUND)
|
||||
find_isa(${CPUINFO} "avx512f" AVX512_FOUND)
|
||||
|
||||
if (AVX512_FOUND)
|
||||
@ -53,8 +54,11 @@ if (AVX512_FOUND)
|
||||
else()
|
||||
message(WARNING "Disable AVX512-BF16 ISA support, no avx512_bf16 found in local CPU flags." " If cross-compilation is required, please set env VLLM_CPU_AVX512BF16=1.")
|
||||
endif()
|
||||
elseif (AVX2_FOUND)
|
||||
list(APPEND CXX_COMPILE_FLAGS "-mavx2")
|
||||
message(WARNING "vLLM CPU backend using AVX2 ISA")
|
||||
else()
|
||||
message(FATAL_ERROR "vLLM CPU backend requires AVX512 ISA support.")
|
||||
message(FATAL_ERROR "vLLM CPU backend requires AVX512 or AVX2 ISA support.")
|
||||
endif()
|
||||
|
||||
message(STATUS "CPU extension compile flags: ${CXX_COMPILE_FLAGS}")
|
||||
|
@ -5,6 +5,10 @@
|
||||
#include <immintrin.h>
|
||||
#include <torch/all.h>
|
||||
|
||||
#ifndef __AVX2__
|
||||
static_assert(false, "AVX2 must be supported for the current implementation.");
|
||||
#endif
|
||||
|
||||
namespace vec_op {
|
||||
|
||||
// FIXME: FP16 is not fully supported in Torch-CPU
|
||||
@ -104,6 +108,7 @@ struct BF16Vec16 : public Vec<BF16Vec16> {
|
||||
void save(void *ptr) const { *reinterpret_cast<__m256i *>(ptr) = reg; }
|
||||
};
|
||||
|
||||
#ifdef __AVX512F__
|
||||
struct BF16Vec32 : public Vec<BF16Vec32> {
|
||||
constexpr static int VEC_ELEM_NUM = 32;
|
||||
|
||||
@ -123,6 +128,34 @@ struct BF16Vec32 : public Vec<BF16Vec32> {
|
||||
|
||||
void save(void *ptr) const { *reinterpret_cast<__m512i *>(ptr) = reg; }
|
||||
};
|
||||
#else
|
||||
struct BF16Vec32 : public Vec<BF16Vec32> {
|
||||
constexpr static int VEC_ELEM_NUM = 32;
|
||||
|
||||
__m256i reg_low;
|
||||
__m256i reg_high;
|
||||
|
||||
explicit BF16Vec32(const void *ptr)
|
||||
: reg_low(_mm256_loadu_si256((__m256i const *)ptr)),
|
||||
reg_high(_mm256_loadu_si256((__m256i const *)ptr + 1)) {}
|
||||
|
||||
explicit BF16Vec32(__m256i low, __m256i high) : reg_low(low),
|
||||
reg_high(high) {}
|
||||
|
||||
explicit BF16Vec32(BF16Vec8 &vec8_data)
|
||||
: reg_low((__m256i)_mm256_inserti32x4(
|
||||
_mm256_castsi128_si256((__m128i)vec8_data.reg),
|
||||
(__m128i)vec8_data.reg, 1)),
|
||||
reg_high((__m256i)_mm256_inserti32x4(
|
||||
_mm256_castsi128_si256((__m128i)vec8_data.reg),
|
||||
(__m128i)vec8_data.reg, 1)) {}
|
||||
|
||||
void save(void *ptr) const {
|
||||
*reinterpret_cast<__m256i *>(ptr) = reg_low;
|
||||
*reinterpret_cast<__m256i *>((__m256i *)ptr + 1) = reg_high;
|
||||
}
|
||||
};
|
||||
#endif
|
||||
|
||||
struct FP32Vec4 : public Vec<FP32Vec4> {
|
||||
constexpr static int VEC_ELEM_NUM = 4;
|
||||
@ -226,6 +259,7 @@ struct FP32Vec8 : public Vec<FP32Vec8> {
|
||||
void save(float *ptr) const { _mm256_storeu_ps(ptr, reg); }
|
||||
};
|
||||
|
||||
#ifdef __AVX512F__
|
||||
struct FP32Vec16 : public Vec<FP32Vec16> {
|
||||
constexpr static int VEC_ELEM_NUM = 16;
|
||||
union AliasReg {
|
||||
@ -290,6 +324,114 @@ struct FP32Vec16 : public Vec<FP32Vec16> {
|
||||
|
||||
void save(float *ptr) const { _mm512_storeu_ps(ptr, reg); }
|
||||
};
|
||||
#else
|
||||
struct FP32Vec16 : public Vec<FP32Vec16> {
|
||||
constexpr static int VEC_ELEM_NUM = 16;
|
||||
|
||||
union AliasReg {
|
||||
__m256 reg;
|
||||
float values[8];
|
||||
};
|
||||
|
||||
__m256 reg_low;
|
||||
__m256 reg_high;
|
||||
|
||||
explicit FP32Vec16(float v) : reg_low(_mm256_set1_ps(v)),
|
||||
reg_high(_mm256_set1_ps(v)) {}
|
||||
|
||||
explicit FP32Vec16() : reg_low(_mm256_set1_ps(0.0)),
|
||||
reg_high(_mm256_set1_ps(0.0)) {}
|
||||
|
||||
explicit FP32Vec16(const float *ptr) : reg_low(_mm256_loadu_ps(ptr)),
|
||||
reg_high(_mm256_loadu_ps(ptr + 8)) {}
|
||||
|
||||
explicit FP32Vec16(__m256 low, __m256 high) : reg_low(low), reg_high(high) {}
|
||||
|
||||
explicit FP32Vec16(const FP32Vec16 &data) : reg_low(data.reg_low),
|
||||
reg_high(data.reg_high) {}
|
||||
|
||||
explicit FP32Vec16(const FP32Vec4 &data)
|
||||
: reg_low((__m256)_mm256_inserti128_si256(
|
||||
_mm256_castsi128_si256((__m128i)data.reg),
|
||||
(__m128i)data.reg, 1)),
|
||||
reg_high((__m256)_mm256_inserti128_si256(
|
||||
_mm256_castsi128_si256((__m128i)data.reg),
|
||||
(__m128i)data.reg, 1)) {}
|
||||
|
||||
explicit FP32Vec16(const FP32Vec8 &data)
|
||||
: reg_low(data.reg), reg_high(data.reg) {}
|
||||
|
||||
explicit FP32Vec16(const BF16Vec16 &v) {
|
||||
__m128i low = _mm256_extractf128_si256(v.reg, 0);
|
||||
__m128i high = _mm256_extractf128_si256(v.reg, 1);
|
||||
|
||||
__m256i v_low_epi32 = _mm256_cvtepu16_epi32(low);
|
||||
__m256i v_high_epi32 = _mm256_cvtepu16_epi32(high);
|
||||
|
||||
__m256i v_low_shifted = _mm256_bslli_epi128(v_low_epi32, 2);
|
||||
__m256i v_high_shifted = _mm256_bslli_epi128(v_high_epi32, 2);
|
||||
|
||||
reg_low = _mm256_castsi256_ps(v_low_shifted);
|
||||
reg_high = _mm256_castsi256_ps(v_high_shifted);
|
||||
}
|
||||
|
||||
explicit FP32Vec16(const BF16Vec8 &v) : FP32Vec16(FP32Vec8(v)) {}
|
||||
|
||||
FP32Vec16 operator*(const FP32Vec16 &b) const {
|
||||
return FP32Vec16(_mm256_mul_ps(reg_low, b.reg_low),
|
||||
_mm256_mul_ps(reg_high, b.reg_high));
|
||||
}
|
||||
|
||||
FP32Vec16 operator+(const FP32Vec16 &b) const {
|
||||
return FP32Vec16(_mm256_add_ps(reg_low, b.reg_low),
|
||||
_mm256_add_ps(reg_high, b.reg_high));
|
||||
}
|
||||
|
||||
FP32Vec16 operator-(const FP32Vec16 &b) const {
|
||||
return FP32Vec16(_mm256_sub_ps(reg_low, b.reg_low),
|
||||
_mm256_sub_ps(reg_high, b.reg_high));
|
||||
}
|
||||
|
||||
FP32Vec16 operator/(const FP32Vec16 &b) const {
|
||||
return FP32Vec16(_mm256_div_ps(reg_low, b.reg_low),
|
||||
_mm256_div_ps(reg_high, b.reg_high));
|
||||
}
|
||||
|
||||
float reduce_sum() const {
|
||||
FP32Vec8 low = FP32Vec8(reg_low);
|
||||
FP32Vec8 high = FP32Vec8(reg_high);
|
||||
return low.reduce_sum() + high.reduce_sum();
|
||||
}
|
||||
|
||||
template <int group_size> float reduce_sub_sum(int idx) {
|
||||
float sum = 0.0;
|
||||
static_assert(VEC_ELEM_NUM % group_size == 0);
|
||||
constexpr uint32_t base_mask = (0xFFFF >> (16 - group_size));
|
||||
uint32_t mask = base_mask << (idx * group_size);
|
||||
|
||||
AliasReg ar;
|
||||
|
||||
auto func = [&sum, &mask, &ar](int i) {
|
||||
int flag = mask & 0x1;
|
||||
mask = mask >> 1;
|
||||
if (flag != 0) sum += ar.values[i];
|
||||
};
|
||||
|
||||
ar.reg = reg_low;
|
||||
unroll_loop<int, 8>(func);
|
||||
|
||||
ar.reg = reg_high;
|
||||
unroll_loop<int, 8>(func);
|
||||
|
||||
return sum;
|
||||
}
|
||||
|
||||
void save(float *ptr) const {
|
||||
_mm256_storeu_ps(ptr, reg_low);
|
||||
_mm256_storeu_ps(ptr + 8, reg_high);
|
||||
}
|
||||
};
|
||||
#endif
|
||||
|
||||
template <typename T> struct VecType { using vec_type = void; };
|
||||
|
||||
@ -336,6 +478,7 @@ template <> inline void storeFP32<c10::BFloat16>(float v, c10::BFloat16 *ptr) {
|
||||
*ptr = *(v_ptr + 1);
|
||||
}
|
||||
|
||||
#ifdef __AVX512F__
|
||||
inline BF16Vec8::BF16Vec8(const FP32Vec8 &v)
|
||||
: reg(_mm256_cvtepi32_epi16(
|
||||
_mm256_bsrli_epi128(_mm256_castps_si256(v.reg), 2))) {}
|
||||
@ -343,7 +486,27 @@ inline BF16Vec8::BF16Vec8(const FP32Vec8 &v)
|
||||
inline BF16Vec16::BF16Vec16(const FP32Vec16 &v)
|
||||
: reg(_mm512_cvtepi32_epi16(
|
||||
_mm512_bsrli_epi128(_mm512_castps_si512(v.reg), 2))) {}
|
||||
#endif
|
||||
#else
|
||||
namespace{
|
||||
__m128i FP32Vec8_to_BF16Vec8_avx2(__m256 a) {
|
||||
__m256i ai = _mm256_castps_si256(a);
|
||||
ai = _mm256_srli_epi32(ai, 16);
|
||||
ai = _mm256_packus_epi32(ai, ai);
|
||||
ai = _mm256_permute4x64_epi64(ai, 0b00111001);
|
||||
return _mm256_extracti128_si256(ai, 0);
|
||||
}
|
||||
}
|
||||
|
||||
inline BF16Vec8::BF16Vec8(const FP32Vec8 &v)
|
||||
: reg(FP32Vec8_to_BF16Vec8_avx2(v.reg)) {}
|
||||
|
||||
inline BF16Vec16::BF16Vec16(const FP32Vec16 &v) {
|
||||
BF16Vec8 low = BF16Vec8(FP32Vec8(v.reg_low));
|
||||
BF16Vec8 high = BF16Vec8(FP32Vec8(v.reg_high));
|
||||
reg = _mm256_insertf128_si256(_mm256_castsi128_si256(low.reg), high.reg, 1);
|
||||
}
|
||||
#endif // __AVX512F__
|
||||
#endif // __AVX512BF16__
|
||||
|
||||
inline void prefetch(const void *addr) { _mm_prefetch(addr, _MM_HINT_T1); }
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user