[Hardware][CPU] Update torch 2.5 (#9911)
Signed-off-by: jiang1.li <jiang1.li@intel.com>
This commit is contained in:
parent
29862b884b
commit
a4b3e0c1e9
@ -46,7 +46,7 @@ docker exec cpu-test bash -c "
|
|||||||
docker exec cpu-test bash -c "
|
docker exec cpu-test bash -c "
|
||||||
export VLLM_CPU_KVCACHE_SPACE=10
|
export VLLM_CPU_KVCACHE_SPACE=10
|
||||||
export VLLM_CPU_OMP_THREADS_BIND=48-92
|
export VLLM_CPU_OMP_THREADS_BIND=48-92
|
||||||
python3 -m vllm.entrypoints.openai.api_server --model facebook/opt-125m &
|
python3 -m vllm.entrypoints.openai.api_server --model facebook/opt-125m --dtype half &
|
||||||
timeout 600 bash -c 'until curl localhost:8000/v1/models; do sleep 1; done' || exit 1
|
timeout 600 bash -c 'until curl localhost:8000/v1/models; do sleep 1; done' || exit 1
|
||||||
python3 benchmarks/benchmark_serving.py \
|
python3 benchmarks/benchmark_serving.py \
|
||||||
--backend vllm \
|
--backend vllm \
|
||||||
|
@ -22,7 +22,7 @@ ENV LD_PRELOAD="/usr/lib/x86_64-linux-gnu/libtcmalloc_minimal.so.4:/usr/local/li
|
|||||||
|
|
||||||
RUN echo 'ulimit -c 0' >> ~/.bashrc
|
RUN echo 'ulimit -c 0' >> ~/.bashrc
|
||||||
|
|
||||||
RUN pip install intel_extension_for_pytorch==2.4.0
|
RUN pip install intel_extension_for_pytorch==2.5.0
|
||||||
|
|
||||||
WORKDIR /workspace
|
WORKDIR /workspace
|
||||||
|
|
||||||
|
@ -18,6 +18,7 @@ include_directories("${CMAKE_SOURCE_DIR}/csrc")
|
|||||||
#
|
#
|
||||||
list(APPEND CXX_COMPILE_FLAGS
|
list(APPEND CXX_COMPILE_FLAGS
|
||||||
"-fopenmp"
|
"-fopenmp"
|
||||||
|
"-mf16c"
|
||||||
"-DVLLM_CPU_EXTENSION")
|
"-DVLLM_CPU_EXTENSION")
|
||||||
|
|
||||||
execute_process(COMMAND cat /proc/cpuinfo
|
execute_process(COMMAND cat /proc/cpuinfo
|
||||||
|
@ -22,6 +22,16 @@ struct KernelVecType<float> {
|
|||||||
using v_load_vec_type = vec_op::FP32Vec16;
|
using v_load_vec_type = vec_op::FP32Vec16;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
template <>
|
||||||
|
struct KernelVecType<c10::Half> {
|
||||||
|
using q_load_vec_type = vec_op::FP16Vec8;
|
||||||
|
using q_vec_type = vec_op::FP32Vec16;
|
||||||
|
using k_load_vec_type = vec_op::FP16Vec16;
|
||||||
|
using k_vec_type = vec_op::FP32Vec16;
|
||||||
|
using qk_acc_vec_type = vec_op::FP32Vec16;
|
||||||
|
using v_load_vec_type = vec_op::FP16Vec16;
|
||||||
|
};
|
||||||
|
|
||||||
#ifdef __AVX512BF16__
|
#ifdef __AVX512BF16__
|
||||||
template <>
|
template <>
|
||||||
struct KernelVecType<c10::BFloat16> {
|
struct KernelVecType<c10::BFloat16> {
|
||||||
|
@ -11,10 +11,10 @@ static_assert(false, "AVX2 must be supported for the current implementation.");
|
|||||||
|
|
||||||
namespace vec_op {
|
namespace vec_op {
|
||||||
|
|
||||||
// FIXME: FP16 is not fully supported in Torch-CPU
|
|
||||||
#define VLLM_DISPATCH_CASE_FLOATING_TYPES(...) \
|
#define VLLM_DISPATCH_CASE_FLOATING_TYPES(...) \
|
||||||
AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \
|
AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \
|
||||||
AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__)
|
AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) \
|
||||||
|
AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__)
|
||||||
|
|
||||||
#define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \
|
#define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \
|
||||||
AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__))
|
AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__))
|
||||||
@ -50,37 +50,37 @@ template <typename T> struct Vec {
|
|||||||
struct FP32Vec8;
|
struct FP32Vec8;
|
||||||
struct FP32Vec16;
|
struct FP32Vec16;
|
||||||
|
|
||||||
#ifdef __AVX512FP16__
|
|
||||||
struct FP16Vec8 : public Vec<FP16Vec8> {
|
struct FP16Vec8 : public Vec<FP16Vec8> {
|
||||||
constexpr static int VEC_ELEM_NUM = 8;
|
constexpr static int VEC_ELEM_NUM = 8;
|
||||||
|
|
||||||
__m128h reg;
|
__m128i reg;
|
||||||
|
|
||||||
explicit FP16Vec8(_Float16 v) : reg(_mm_set1_ph(v)) {}
|
explicit FP16Vec8(const void *ptr)
|
||||||
|
: reg((__m128i)_mm_loadu_si128((__m128i *)ptr)) {}
|
||||||
|
|
||||||
explicit FP16Vec8(const void *ptr) : reg(_mm_loadu_ph(ptr)) {}
|
explicit FP16Vec8(const FP32Vec8 &);
|
||||||
|
|
||||||
explicit FP16Vec8(__m128h data) : reg(data) {}
|
void save(void *ptr) const { *reinterpret_cast<__m128i *>(ptr) = reg; }
|
||||||
|
};
|
||||||
FP16Vec8 operator*(const FP16Vec8 &b) const {
|
|
||||||
return FP16Vec8(_mm_mul_ph(reg, b.reg));
|
struct FP16Vec16 : public Vec<FP16Vec16> {
|
||||||
}
|
constexpr static int VEC_ELEM_NUM = 16;
|
||||||
|
|
||||||
FP16Vec8 operator+(const FP16Vec8 &b) const {
|
__m256i reg;
|
||||||
return FP16Vec8(_mm_add_ph(reg, b.reg));
|
|
||||||
}
|
explicit FP16Vec16(const void *ptr)
|
||||||
|
: reg((__m256i)_mm256_loadu_si256((__m256i *)ptr)) {}
|
||||||
FP16Vec8 operator-(const FP16Vec8 &b) const {
|
|
||||||
return FP16Vec8(_mm_sub_ph(reg, b.reg));
|
explicit FP16Vec16(const FP32Vec16 &);
|
||||||
}
|
|
||||||
|
void save(void *ptr) const { *reinterpret_cast<__m256i *>(ptr) = reg; }
|
||||||
FP16Vec8 operator/(const FP16Vec8 &b) const {
|
|
||||||
return FP16Vec8(_mm_div_ph(reg, b.reg));
|
void save(void* ptr, const int elem_num) const {
|
||||||
}
|
constexpr uint32_t M = 0xFFFFFFFF;
|
||||||
|
__mmask16 mask = _cvtu32_mask16(M >> (32 - elem_num));
|
||||||
void save(void *ptr) const { _mm_storeu_ph(ptr, reg); }
|
_mm256_mask_storeu_epi16(ptr, mask, reg);
|
||||||
|
}
|
||||||
};
|
};
|
||||||
#endif
|
|
||||||
|
|
||||||
struct BF16Vec8 : public Vec<BF16Vec8> {
|
struct BF16Vec8 : public Vec<BF16Vec8> {
|
||||||
constexpr static int VEC_ELEM_NUM = 8;
|
constexpr static int VEC_ELEM_NUM = 8;
|
||||||
@ -202,9 +202,7 @@ struct FP32Vec8 : public Vec<FP32Vec8> {
|
|||||||
|
|
||||||
explicit FP32Vec8(const FP32Vec8 &data) : reg(data.reg) {}
|
explicit FP32Vec8(const FP32Vec8 &data) : reg(data.reg) {}
|
||||||
|
|
||||||
#ifdef __AVX512FP16__
|
explicit FP32Vec8(const FP16Vec8 &v) : reg(_mm256_cvtph_ps(v.reg)) {}
|
||||||
explicit FP32Vec8(__m128h v) : reg(_mm256_cvtph_ps(_mm_castph_si128(v))) {}
|
|
||||||
#endif
|
|
||||||
|
|
||||||
explicit FP32Vec8(const BF16Vec8 &v)
|
explicit FP32Vec8(const BF16Vec8 &v)
|
||||||
: reg(_mm256_castsi256_ps(
|
: reg(_mm256_castsi256_ps(
|
||||||
@ -323,6 +321,10 @@ struct FP32Vec16 : public Vec<FP32Vec16> {
|
|||||||
: reg(_mm512_castsi512_ps(
|
: reg(_mm512_castsi512_ps(
|
||||||
_mm512_bslli_epi128(_mm512_cvtepu16_epi32(v.reg), 2))) {}
|
_mm512_bslli_epi128(_mm512_cvtepu16_epi32(v.reg), 2))) {}
|
||||||
|
|
||||||
|
explicit FP32Vec16(const FP16Vec16 &v) : reg(_mm512_cvtph_ps(v.reg)) {}
|
||||||
|
|
||||||
|
explicit FP32Vec16(const FP16Vec8 &v) : FP32Vec16(FP32Vec8(v)) {}
|
||||||
|
|
||||||
explicit FP32Vec16(const BF16Vec8 &v) : FP32Vec16(FP32Vec8(v)) {}
|
explicit FP32Vec16(const BF16Vec8 &v) : FP32Vec16(FP32Vec8(v)) {}
|
||||||
|
|
||||||
explicit FP32Vec16(const INT32Vec16 &v)
|
explicit FP32Vec16(const INT32Vec16 &v)
|
||||||
@ -534,24 +536,34 @@ template <typename T> using vec_t = typename VecType<T>::vec_type;
|
|||||||
|
|
||||||
template <> struct VecType<float> { using vec_type = FP32Vec8; };
|
template <> struct VecType<float> { using vec_type = FP32Vec8; };
|
||||||
|
|
||||||
#ifdef __AVX512FP16__
|
template <> struct VecType<c10::Half> { using vec_type = FP16Vec8; };
|
||||||
template <> struct VecType<c10::Half> { using vec_type = FP16Vec16; };
|
|
||||||
#endif
|
|
||||||
|
|
||||||
template <> struct VecType<c10::BFloat16> { using vec_type = BF16Vec8; };
|
template <> struct VecType<c10::BFloat16> { using vec_type = BF16Vec8; };
|
||||||
|
|
||||||
template <typename T> void storeFP32(float v, T *ptr) { *ptr = v; }
|
template <typename T> void storeFP32(float v, T *ptr) { *ptr = v; }
|
||||||
|
|
||||||
#ifdef __AVX512FP16__
|
|
||||||
template <> inline void storeFP32<c10::Half>(float v, c10::Half *ptr) {
|
|
||||||
*reinterpret_cast<_Float16 *>(ptr) = v;
|
|
||||||
}
|
|
||||||
#endif
|
|
||||||
|
|
||||||
inline void fma(FP32Vec16 &acc, FP32Vec16 &a, FP32Vec16 &b) {
|
inline void fma(FP32Vec16 &acc, FP32Vec16 &a, FP32Vec16 &b) {
|
||||||
acc = acc + a * b;
|
acc = acc + a * b;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <> inline void storeFP32<c10::Half>(float v, c10::Half *ptr) {
|
||||||
|
*reinterpret_cast<unsigned short *>(ptr) =
|
||||||
|
_cvtss_sh(v, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC);
|
||||||
|
}
|
||||||
|
|
||||||
|
inline FP16Vec8::FP16Vec8(const FP32Vec8 &v)
|
||||||
|
: reg(_mm256_cvtps_ph(v.reg,
|
||||||
|
_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)) {}
|
||||||
|
|
||||||
|
#ifdef __AVX512F__
|
||||||
|
inline FP16Vec16::FP16Vec16(const FP32Vec16 &v)
|
||||||
|
: reg(_mm512_cvtps_ph(v.reg,
|
||||||
|
_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)) {}
|
||||||
|
#else
|
||||||
|
// Non-AVX512 fallback: narrow a 16-lane FP32 vector to 16 FP16 lanes.
// Each 8-lane half is converted to FP16 on its own (FP32Vec8 -> FP16Vec8,
// yielding one 128-bit register) and the two results are concatenated into a
// single 256-bit register: the low half goes into bits 0-127, the high half
// is inserted at index 1 (bits 128-255).
// The previous version converted `reg_low` for both halves, so output lanes
// 8-15 duplicated lanes 0-7 and the upper half of the input was discarded.
// NOTE(review): assumes the AVX2 FP32Vec16 exposes its upper 8 lanes as
// `reg_high` next to `reg_low` - confirm against the struct definition.
inline FP16Vec16::FP16Vec16(const FP32Vec16 &v)
    : reg(_mm256_insertf128_si256(
          _mm256_castsi128_si256(FP16Vec8(FP32Vec8(v.reg_low)).reg),
          FP16Vec8(FP32Vec8(v.reg_high)).reg, 1)) {}
|
||||||
|
#endif
|
||||||
|
|
||||||
#ifdef __AVX512BF16__
|
#ifdef __AVX512BF16__
|
||||||
template <> inline void storeFP32<c10::BFloat16>(float v, c10::BFloat16 *ptr) {
|
template <> inline void storeFP32<c10::BFloat16>(float v, c10::BFloat16 *ptr) {
|
||||||
*reinterpret_cast<__bfloat16 *>(ptr) = _mm_cvtness_sbh(v);
|
*reinterpret_cast<__bfloat16 *>(ptr) = _mm_cvtness_sbh(v);
|
||||||
|
@ -2,6 +2,7 @@
|
|||||||
#define DNNL_HELPER_HPP
|
#define DNNL_HELPER_HPP
|
||||||
|
|
||||||
#include <c10/util/BFloat16.h>
|
#include <c10/util/BFloat16.h>
|
||||||
|
#include <c10/util/Half.h>
|
||||||
|
|
||||||
#include "oneapi/dnnl/dnnl.hpp"
|
#include "oneapi/dnnl/dnnl.hpp"
|
||||||
|
|
||||||
@ -32,6 +33,11 @@ struct DNNLType<c10::BFloat16> {
|
|||||||
static constexpr dnnl::memory::data_type type = dnnl::memory::data_type::bf16;
|
static constexpr dnnl::memory::data_type type = dnnl::memory::data_type::bf16;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
template <>
|
||||||
|
struct DNNLType<c10::Half> {
|
||||||
|
static constexpr dnnl::memory::data_type type = dnnl::memory::data_type::f16;
|
||||||
|
};
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
constexpr inline dnnl::memory::data_type get_dnnl_type() {
|
constexpr inline dnnl::memory::data_type get_dnnl_type() {
|
||||||
return DNNLType<std::decay_t<T>>::type;
|
return DNNLType<std::decay_t<T>>::type;
|
||||||
|
@ -23,6 +23,13 @@ struct KernelVecType<c10::BFloat16> {
|
|||||||
using cvt_vec_type = vec_op::FP32Vec16;
|
using cvt_vec_type = vec_op::FP32Vec16;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
template <>
|
||||||
|
struct KernelVecType<c10::Half> {
|
||||||
|
using load_vec_type = vec_op::FP16Vec16;
|
||||||
|
using azp_adj_load_vec_type = vec_op::INT32Vec16;
|
||||||
|
using cvt_vec_type = vec_op::FP32Vec16;
|
||||||
|
};
|
||||||
|
|
||||||
#ifdef __AVX512F__
|
#ifdef __AVX512F__
|
||||||
template <bool AZP, typename scalar_t>
|
template <bool AZP, typename scalar_t>
|
||||||
void static_scaled_int8_quant_impl(const scalar_t* input, int8_t* output,
|
void static_scaled_int8_quant_impl(const scalar_t* input, int8_t* output,
|
||||||
|
@ -3,13 +3,13 @@
|
|||||||
Installation with CPU
|
Installation with CPU
|
||||||
========================
|
========================
|
||||||
|
|
||||||
vLLM initially supports basic model inferencing and serving on x86 CPU platform, with data types FP32 and BF16. vLLM CPU backend supports the following vLLM features:
|
vLLM initially supports basic model inferencing and serving on x86 CPU platform, with data types FP32, FP16 and BF16. vLLM CPU backend supports the following vLLM features:
|
||||||
|
|
||||||
- Tensor Parallel (``-tp = N``)
|
- Tensor Parallel (``-tp = N``)
|
||||||
- Quantization (``INT8 W8A8, AWQ``)
|
- Quantization (``INT8 W8A8, AWQ``)
|
||||||
|
|
||||||
.. note::
|
.. note::
|
||||||
FP16 data type and more advanced features on `chunked-prefill`, `prefix-caching` and `FP8 KV cache` are under development and will be available soon.
|
More advanced features on `chunked-prefill`, `prefix-caching` and `FP8 KV cache` are under development and will be available soon.
|
||||||
|
|
||||||
Table of contents:
|
Table of contents:
|
||||||
|
|
||||||
@ -72,8 +72,6 @@ Build from source
|
|||||||
$ VLLM_TARGET_DEVICE=cpu python setup.py install
|
$ VLLM_TARGET_DEVICE=cpu python setup.py install
|
||||||
|
|
||||||
.. note::
|
.. note::
|
||||||
- BF16 is the default data type in the current CPU backend (that means the backend will cast FP16 to BF16), and is compatible will all CPUs with AVX512 ISA support.
|
|
||||||
|
|
||||||
- AVX512_BF16 is an ISA extension that provides native BF16 data type conversion and vector product instructions, which brings some performance improvement compared with pure AVX512. The CPU backend build script will check the host CPU flags to determine whether to enable AVX512_BF16.
|
- AVX512_BF16 is an ISA extension that provides native BF16 data type conversion and vector product instructions, which brings some performance improvement compared with pure AVX512. The CPU backend build script will check the host CPU flags to determine whether to enable AVX512_BF16.
|
||||||
|
|
||||||
- If you want to force-enable AVX512_BF16 for cross-compilation, please set the environment variable VLLM_CPU_AVX512BF16=1 before building.
|
- If you want to force-enable AVX512_BF16 for cross-compilation, please set the environment variable VLLM_CPU_AVX512BF16=1 before building.
|
||||||
|
@ -2,5 +2,5 @@
|
|||||||
-r requirements-common.txt
|
-r requirements-common.txt
|
||||||
|
|
||||||
# Dependencies for x86_64 CPUs
|
# Dependencies for x86_64 CPUs
|
||||||
torch == 2.4.0+cpu; platform_machine != "ppc64le"
|
torch == 2.5.1+cpu; platform_machine != "ppc64le"
|
||||||
torchvision; platform_machine != "ppc64le" # required for the image processor of phi3v, this must be updated alongside torch
|
torchvision; platform_machine != "ppc64le" # required for the image processor of phi3v, this must be updated alongside torch
|
||||||
|
@ -32,8 +32,7 @@ if not current_platform.is_cpu():
|
|||||||
"openbmb/MiniCPM3-4B",
|
"openbmb/MiniCPM3-4B",
|
||||||
]
|
]
|
||||||
|
|
||||||
# TODO: remove this after CPU float16 support ready
|
target_dtype = "half"
|
||||||
target_dtype = "float" if current_platform.is_cpu() else "half"
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("model", MODELS)
|
@pytest.mark.parametrize("model", MODELS)
|
||||||
|
@ -2,8 +2,6 @@ import os
|
|||||||
from functools import partial
|
from functools import partial
|
||||||
from typing import Any, Awaitable, List, Optional, Set, Tuple, Union
|
from typing import Any, Awaitable, List, Optional, Set, Tuple, Union
|
||||||
|
|
||||||
import torch
|
|
||||||
|
|
||||||
import vllm.envs as envs
|
import vllm.envs as envs
|
||||||
from vllm.config import (CacheConfig, ModelConfig, ParallelConfig,
|
from vllm.config import (CacheConfig, ModelConfig, ParallelConfig,
|
||||||
SchedulerConfig)
|
SchedulerConfig)
|
||||||
@ -316,9 +314,6 @@ class CPUExecutorAsync(CPUExecutor, ExecutorAsyncBase):
|
|||||||
|
|
||||||
|
|
||||||
def _verify_and_get_model_config(config: ModelConfig) -> ModelConfig:
|
def _verify_and_get_model_config(config: ModelConfig) -> ModelConfig:
|
||||||
if config.dtype == torch.float16:
|
|
||||||
logger.warning("float16 is not supported on CPU, casting to bfloat16.")
|
|
||||||
config.dtype = torch.bfloat16
|
|
||||||
# Reminder: Please update docs/source/serving/compatibility_matrix.rst
|
# Reminder: Please update docs/source/serving/compatibility_matrix.rst
|
||||||
# If the feature combo become valid
|
# If the feature combo become valid
|
||||||
if not config.enforce_eager:
|
if not config.enforce_eager:
|
||||||
|
@ -54,7 +54,7 @@ class IPEXConfig(QuantizationConfig):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_supported_act_dtypes(cls) -> List[torch.dtype]:
|
def get_supported_act_dtypes(cls) -> List[torch.dtype]:
|
||||||
return [torch.bfloat16]
|
return [torch.bfloat16, torch.float16]
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_min_capability(cls) -> int:
|
def get_min_capability(cls) -> int:
|
||||||
|
Loading…
x
Reference in New Issue
Block a user