[Hardware][CPU] Update torch 2.5 (#9911)

Signed-off-by: jiang1.li <jiang1.li@intel.com>
Li, Jiang 2024-11-07 12:43:08 +08:00 committed by GitHub
parent 29862b884b
commit a4b3e0c1e9
12 changed files with 81 additions and 53 deletions

View File

@@ -46,7 +46,7 @@ docker exec cpu-test bash -c "
docker exec cpu-test bash -c "
export VLLM_CPU_KVCACHE_SPACE=10
export VLLM_CPU_OMP_THREADS_BIND=48-92
python3 -m vllm.entrypoints.openai.api_server --model facebook/opt-125m &
python3 -m vllm.entrypoints.openai.api_server --model facebook/opt-125m --dtype half &
timeout 600 bash -c 'until curl localhost:8000/v1/models; do sleep 1; done' || exit 1
python3 benchmarks/benchmark_serving.py \
--backend vllm \

View File

@@ -22,7 +22,7 @@ ENV LD_PRELOAD="/usr/lib/x86_64-linux-gnu/libtcmalloc_minimal.so.4:/usr/local/li
RUN echo 'ulimit -c 0' >> ~/.bashrc
RUN pip install intel_extension_for_pytorch==2.4.0
RUN pip install intel_extension_for_pytorch==2.5.0
WORKDIR /workspace

View File

@@ -18,6 +18,7 @@ include_directories("${CMAKE_SOURCE_DIR}/csrc")
#
list(APPEND CXX_COMPILE_FLAGS
"-fopenmp"
"-mf16c"
"-DVLLM_CPU_EXTENSION")
execute_process(COMMAND cat /proc/cpuinfo
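For context, the new "-mf16c" flag lets the compiler emit F16C half-precision conversion instructions, which the FP16 paths in this commit rely on when AVX512-FP16 is unavailable. A minimal standalone sketch, not part of this diff and with illustrative values only (compile with -mf16c):

#include <immintrin.h>
#include <cstdio>

int main() {
  float x = 3.14159f;
  // FP32 -> FP16 bits (round to nearest even), then back to FP32.
  unsigned short h = _cvtss_sh(x, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC);
  float y = _cvtsh_ss(h);
  std::printf("%f -> 0x%04x -> %f\n", x, (unsigned)h, y);
  return 0;
}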

View File

@@ -22,6 +22,16 @@ struct KernelVecType<float> {
using v_load_vec_type = vec_op::FP32Vec16;
};
template <>
struct KernelVecType<c10::Half> {
using q_load_vec_type = vec_op::FP16Vec8;
using q_vec_type = vec_op::FP32Vec16;
using k_load_vec_type = vec_op::FP16Vec16;
using k_vec_type = vec_op::FP32Vec16;
using qk_acc_vec_type = vec_op::FP32Vec16;
using v_load_vec_type = vec_op::FP16Vec16;
};
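As a rough illustration of the pattern this specialization encodes (a hedged sketch in plain intrinsics, not vLLM's vec_op wrappers; assumes AVX512F): query/key values are loaded in FP16, but every multiply-accumulate happens in FP32, matching qk_acc_vec_type above.

#include <immintrin.h>
#include <cstdint>

// Dot product of two 16-element fp16 vectors with an fp32 accumulator.
static float dot16_fp16(const uint16_t* q, const uint16_t* k) {
  __m512 qf = _mm512_cvtph_ps(_mm256_loadu_si256(reinterpret_cast<const __m256i*>(q)));
  __m512 kf = _mm512_cvtph_ps(_mm256_loadu_si256(reinterpret_cast<const __m256i*>(k)));
  return _mm512_reduce_add_ps(_mm512_mul_ps(qf, kf));  // accumulate in FP32
}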
#ifdef __AVX512BF16__
template <>
struct KernelVecType<c10::BFloat16> {

View File

@@ -11,10 +11,10 @@ static_assert(false, "AVX2 must be supported for the current implementation.");
namespace vec_op {
// FIXME: FP16 is not fully supported in Torch-CPU
#define VLLM_DISPATCH_CASE_FLOATING_TYPES(...) \
AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__)
AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__)
#define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \
AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__))
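With at::ScalarType::Half added to the dispatch list, any CPU kernel entry point built on this macro now also instantiates a c10::Half branch. A hedged usage sketch (the function name is hypothetical; assumes the macro above is in scope):

#include <ATen/Dispatch.h>
#include <ATen/core/Tensor.h>

void example_entry(const at::Tensor& x) {
  VLLM_DISPATCH_FLOATING_TYPES(x.scalar_type(), "example_entry", [&] {
    // scalar_t is float, c10::BFloat16, or (newly) c10::Half here.
    const scalar_t* ptr = x.data_ptr<scalar_t>();
    (void)ptr;
  });
}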
@@ -50,37 +50,37 @@ template <typename T> struct Vec {
struct FP32Vec8;
struct FP32Vec16;
#ifdef __AVX512FP16__
struct FP16Vec8 : public Vec<FP16Vec8> {
constexpr static int VEC_ELEM_NUM = 8;
__m128h reg;
__m128i reg;
explicit FP16Vec8(_Float16 v) : reg(_mm_set1_ph(v)) {}
explicit FP16Vec8(const void *ptr)
: reg((__m128i)_mm_loadu_si128((__m128i *)ptr)) {}
explicit FP16Vec8(const void *ptr) : reg(_mm_loadu_ph(ptr)) {}
explicit FP16Vec8(const FP32Vec8 &);
explicit FP16Vec8(__m128h data) : reg(data) {}
FP16Vec8 operator*(const FP16Vec8 &b) const {
return FP16Vec8(_mm_mul_ph(reg, b.reg));
}
FP16Vec8 operator+(const FP16Vec8 &b) const {
return FP16Vec8(_mm_add_ph(reg, b.reg));
}
FP16Vec8 operator-(const FP16Vec8 &b) const {
return FP16Vec8(_mm_sub_ph(reg, b.reg));
}
FP16Vec8 operator/(const FP16Vec8 &b) const {
return FP16Vec8(_mm_div_ph(reg, b.reg));
}
void save(void *ptr) const { _mm_storeu_ph(ptr, reg); }
void save(void *ptr) const { *reinterpret_cast<__m128i *>(ptr) = reg; }
};
struct FP16Vec16 : public Vec<FP16Vec16> {
constexpr static int VEC_ELEM_NUM = 16;
__m256i reg;
explicit FP16Vec16(const void *ptr)
: reg((__m256i)_mm256_loadu_si256((__m256i *)ptr)) {}
explicit FP16Vec16(const FP32Vec16 &);
void save(void *ptr) const { *reinterpret_cast<__m256i *>(ptr) = reg; }
void save(void* ptr, const int elem_num) const {
constexpr uint32_t M = 0xFFFFFFFF;
__mmask16 mask = _cvtu32_mask16(M >> (32 - elem_num));
_mm256_mask_storeu_epi16(ptr, mask, reg);
}
};
#endif
struct BF16Vec8 : public Vec<BF16Vec8> {
constexpr static int VEC_ELEM_NUM = 8;
@@ -202,9 +202,7 @@ struct FP32Vec8 : public Vec<FP32Vec8> {
explicit FP32Vec8(const FP32Vec8 &data) : reg(data.reg) {}
#ifdef __AVX512FP16__
explicit FP32Vec8(__m128h v) : reg(_mm256_cvtph_ps(_mm_castph_si128(v))) {}
#endif
explicit FP32Vec8(const FP16Vec8 &v) : reg(_mm256_cvtph_ps(v.reg)) {}
explicit FP32Vec8(const BF16Vec8 &v)
: reg(_mm256_castsi256_ps(
@@ -323,6 +321,10 @@ struct FP32Vec16 : public Vec<FP32Vec16> {
: reg(_mm512_castsi512_ps(
_mm512_bslli_epi128(_mm512_cvtepu16_epi32(v.reg), 2))) {}
explicit FP32Vec16(const FP16Vec16 &v) : reg(_mm512_cvtph_ps(v.reg)) {}
explicit FP32Vec16(const FP16Vec8 &v) : FP32Vec16(FP32Vec8(v)) {}
explicit FP32Vec16(const BF16Vec8 &v) : FP32Vec16(FP32Vec8(v)) {}
explicit FP32Vec16(const INT32Vec16 &v)
@@ -534,24 +536,34 @@ template <typename T> using vec_t = typename VecType<T>::vec_type;
template <> struct VecType<float> { using vec_type = FP32Vec8; };
#ifdef __AVX512FP16__
template <> struct VecType<c10::Half> { using vec_type = FP16Vec16; };
#endif
template <> struct VecType<c10::Half> { using vec_type = FP16Vec8; };
template <> struct VecType<c10::BFloat16> { using vec_type = BF16Vec8; };
template <typename T> void storeFP32(float v, T *ptr) { *ptr = v; }
#ifdef __AVX512FP16__
template <> inline void storeFP32<c10::Half>(float v, c10::Half *ptr) {
*reinterpret_cast<_Float16 *>(ptr) = v;
}
#endif
inline void fma(FP32Vec16 &acc, FP32Vec16 &a, FP32Vec16 &b) {
acc = acc + a * b;
}
template <> inline void storeFP32<c10::Half>(float v, c10::Half *ptr) {
*reinterpret_cast<unsigned short *>(ptr) =
_cvtss_sh(v, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC);
}
inline FP16Vec8::FP16Vec8(const FP32Vec8 &v)
: reg(_mm256_cvtps_ph(v.reg,
_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)) {}
#ifdef __AVX512F__
inline FP16Vec16::FP16Vec16(const FP32Vec16 &v)
: reg(_mm512_cvtps_ph(v.reg,
_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)) {}
#else
inline FP16Vec16::FP16Vec16(const FP32Vec16 &v)
    : reg(_mm256_insertf128_si256(
          _mm256_castsi128_si256(FP16Vec8(FP32Vec8(v.reg_low)).reg),
          FP16Vec8(FP32Vec8(v.reg_high)).reg, 1)) {}
#endif
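The fallback path above composes a 16-wide FP16 vector from two 8-wide F16C conversions when AVX512F is unavailable. A hedged, self-contained restatement of that trick (illustrative only, not part of this diff):

#include <immintrin.h>

// Convert two __m256 halves (16 floats total) into one __m256i of 16 fp16 values.
static __m256i pack_fp32x16_to_fp16(__m256 lo, __m256 hi) {
  __m128i lo_h = _mm256_cvtps_ph(lo, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC);
  __m128i hi_h = _mm256_cvtps_ph(hi, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC);
  return _mm256_insertf128_si256(_mm256_castsi128_si256(lo_h), hi_h, 1);
}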
#ifdef __AVX512BF16__
template <> inline void storeFP32<c10::BFloat16>(float v, c10::BFloat16 *ptr) {
*reinterpret_cast<__bfloat16 *>(ptr) = _mm_cvtness_sbh(v);

View File

@@ -2,6 +2,7 @@
#define DNNL_HELPER_HPP
#include <c10/util/BFloat16.h>
#include <c10/util/Half.h>
#include "oneapi/dnnl/dnnl.hpp"
@@ -32,6 +33,11 @@ struct DNNLType<c10::BFloat16> {
static constexpr dnnl::memory::data_type type = dnnl::memory::data_type::bf16;
};
template <>
struct DNNLType<c10::Half> {
static constexpr dnnl::memory::data_type type = dnnl::memory::data_type::f16;
};
template <typename T>
constexpr inline dnnl::memory::data_type get_dnnl_type() {
return DNNLType<std::decay_t<T>>::type;
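To illustrate how the new specialization is consumed, a hedged sketch of choosing a oneDNN memory descriptor's data type from the element type at compile time; make_2d_desc is a hypothetical name, and it assumes the header above (with get_dnnl_type) is included. With the added specialization, T = c10::Half maps to dnnl f16.

#include <cstdint>
#include <c10/util/Half.h>
#include "oneapi/dnnl/dnnl.hpp"

// Build a 2-D row-major descriptor whose data type follows the C++ element type.
template <typename T>
dnnl::memory::desc make_2d_desc(int64_t rows, int64_t cols) {
  return dnnl::memory::desc({rows, cols}, get_dnnl_type<T>(),
                            dnnl::memory::format_tag::ab);
}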

View File

@@ -23,6 +23,13 @@ struct KernelVecType<c10::BFloat16> {
using cvt_vec_type = vec_op::FP32Vec16;
};
template <>
struct KernelVecType<c10::Half> {
using load_vec_type = vec_op::FP16Vec16;
using azp_adj_load_vec_type = vec_op::INT32Vec16;
using cvt_vec_type = vec_op::FP32Vec16;
};
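A hedged sketch (plain intrinsics rather than vLLM's vec_op wrappers, assumes AVX512F) of what this specialization enables in the int8 quantization path: fp16 activations are widened to fp32, scaled, then saturated down to int8.

#include <immintrin.h>
#include <cstdint>

// Statically scale 16 fp16 inputs and store them as saturated int8.
static void quant16_fp16_to_int8(const uint16_t* in, int8_t* out, float inv_scale) {
  __m512 x = _mm512_cvtph_ps(_mm256_loadu_si256(reinterpret_cast<const __m256i*>(in)));
  __m512i q = _mm512_cvtps_epi32(_mm512_mul_ps(x, _mm512_set1_ps(inv_scale)));
  _mm_storeu_si128(reinterpret_cast<__m128i*>(out), _mm512_cvtsepi32_epi8(q));
}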
#ifdef __AVX512F__
template <bool AZP, typename scalar_t>
void static_scaled_int8_quant_impl(const scalar_t* input, int8_t* output,

View File

@@ -3,13 +3,13 @@
Installation with CPU
========================
vLLM initially supports basic model inference and serving on the x86 CPU platform, with data types FP32 and BF16. The vLLM CPU backend supports the following vLLM features:
vLLM initially supports basic model inference and serving on the x86 CPU platform, with data types FP32, FP16 and BF16. The vLLM CPU backend supports the following vLLM features:
- Tensor Parallel (``-tp = N``)
- Quantization (``INT8 W8A8, AWQ``)
.. note::
FP16 data type and more advanced features such as `chunked-prefill`, `prefix-caching` and `FP8 KV cache` are under development and will be available soon.
More advanced features such as `chunked-prefill`, `prefix-caching` and `FP8 KV cache` are under development and will be available soon.
Table of contents:
@@ -72,8 +72,6 @@ Build from source
$ VLLM_TARGET_DEVICE=cpu python setup.py install
.. note::
- BF16 is the default data type in the current CPU backend (that means the backend will cast FP16 to BF16), and it is compatible with all CPUs that have AVX512 ISA support.
- AVX512_BF16 is an extension ISA that provides native BF16 data type conversion and vector product instructions, which brings some performance improvement compared with pure AVX512. The CPU backend build script will check the host CPU flags to determine whether to enable AVX512_BF16.
- If you want to force enable AVX512_BF16 for cross-compilation, please set the environment variable VLLM_CPU_AVX512BF16=1 before building.

View File

@@ -2,5 +2,5 @@
-r requirements-common.txt
# Dependencies for x86_64 CPUs
torch == 2.4.0+cpu; platform_machine != "ppc64le"
torch == 2.5.1+cpu; platform_machine != "ppc64le"
torchvision; platform_machine != "ppc64le" # required for the image processor of phi3v, this must be updated alongside torch

View File

@@ -32,8 +32,7 @@ if not current_platform.is_cpu():
"openbmb/MiniCPM3-4B",
]
# TODO: remove this after CPU float16 support is ready
target_dtype = "float" if current_platform.is_cpu() else "half"
target_dtype = "half"
@pytest.mark.parametrize("model", MODELS)

View File

@@ -2,8 +2,6 @@ import os
from functools import partial
from typing import Any, Awaitable, List, Optional, Set, Tuple, Union
import torch
import vllm.envs as envs
from vllm.config import (CacheConfig, ModelConfig, ParallelConfig,
SchedulerConfig)
@@ -316,9 +314,6 @@ class CPUExecutorAsync(CPUExecutor, ExecutorAsyncBase):
def _verify_and_get_model_config(config: ModelConfig) -> ModelConfig:
if config.dtype == torch.float16:
logger.warning("float16 is not supported on CPU, casting to bfloat16.")
config.dtype = torch.bfloat16
# Reminder: Please update docs/source/serving/compatibility_matrix.rst
# If the feature combo become valid
if not config.enforce_eager:

View File

@@ -54,7 +54,7 @@ class IPEXConfig(QuantizationConfig):
@classmethod
def get_supported_act_dtypes(cls) -> List[torch.dtype]:
return [torch.bfloat16]
return [torch.bfloat16, torch.float16]
@classmethod
def get_min_capability(cls) -> int: