diff --git a/Dockerfile.s390x b/Dockerfile.s390x new file mode 100644 index 00000000..b499d4cb --- /dev/null +++ b/Dockerfile.s390x @@ -0,0 +1,152 @@ +# Base UBI image for s390x architecture +ARG BASE_UBI_IMAGE_TAG=9.5-1736404155 +ARG PYTHON_VERSION=3.12 +FROM registry.access.redhat.com/ubi9/ubi-minimal:${BASE_UBI_IMAGE_TAG} AS base + +# Install basic dependencies +ARG PYTHON_VERSION +ENV PYTHON_VERSION=${PYTHON_VERSION} + +WORKDIR /workspace + +ENV LANG=C.UTF-8 \ + LC_ALL=C.UTF-8 + +# Install development utilities +RUN microdnf install -y \ + which procps findutils tar vim git gcc gcc-gfortran g++ make patch zlib-devel \ + libjpeg-turbo-devel libtiff-devel libpng-devel libwebp-devel freetype-devel harfbuzz-devel \ + openssl-devel openblas openblas-devel autoconf automake libtool cmake && \ + microdnf clean all + +# Python Installation +FROM base AS python-install +ARG PYTHON_VERSION + +ENV VIRTUAL_ENV=/opt/vllm +ENV PATH="$VIRTUAL_ENV/bin:$PATH" +ENV PYTHON_VERSION=${PYTHON_VERSION} +RUN microdnf install -y \ + python${PYTHON_VERSION}-devel python${PYTHON_VERSION}-pip python${PYTHON_VERSION}-wheel && \ + python${PYTHON_VERSION} -m venv $VIRTUAL_ENV && pip install --no-cache -U pip wheel uv && microdnf clean all + +FROM python-install AS pyarrow + +# Build Apache Arrow +WORKDIR /tmp +RUN --mount=type=cache,target=/root/.cache/uv \ + git clone https://github.com/apache/arrow.git && \ + cd arrow/cpp && \ + mkdir release && cd release && \ + cmake -DCMAKE_BUILD_TYPE=Release \ + -DCMAKE_INSTALL_PREFIX=/usr/local \ + -DARROW_PYTHON=ON \ + -DARROW_PARQUET=ON \ + -DARROW_ORC=ON \ + -DARROW_FILESYSTEM=ON \ + -DARROW_WITH_LZ4=ON \ + -DARROW_WITH_ZSTD=ON \ + -DARROW_WITH_SNAPPY=ON \ + -DARROW_JSON=ON \ + -DARROW_CSV=ON \ + -DARROW_DATASET=ON \ + -DPROTOBUF_PROTOC_EXECUTABLE=/usr/bin/protoc \ + -DARROW_DEPENDENCY_SOURCE=BUNDLED \ + .. && \ + make -j$(nproc) && \ + make install && \ + cd ../../python && \ + export PYARROW_PARALLEL=4 && \ + export ARROW_BUILD_TYPE=release && \ + uv pip install -r requirements-build.txt && \ + python setup.py build_ext --build-type=$ARROW_BUILD_TYPE --bundle-arrow-cpp bdist_wheel + +FROM python-install AS numa-build +# Install numactl (needed for numa.h dependency) +WORKDIR /tmp +RUN curl -LO https://github.com/numactl/numactl/archive/refs/tags/v2.0.16.tar.gz && \ + tar -xvzf v2.0.16.tar.gz && \ + cd numactl-2.0.16 && \ + ./autogen.sh && \ + ./configure && \ + make + +# Set include path +ENV C_INCLUDE_PATH="/usr/local/include:$C_INCLUDE_PATH" + +FROM python-install AS rust +ENV CARGO_HOME=/root/.cargo +ENV RUSTUP_HOME=/root/.rustup +ENV PATH="$CARGO_HOME/bin:$RUSTUP_HOME/bin:$PATH" + +RUN curl https://sh.rustup.rs -sSf | sh -s -- -y && \ + . "$CARGO_HOME/env" && \ + rustup default stable && \ + rustup show + +FROM python-install AS torch-vision +# Install torchvision +ARG TORCH_VERSION=2.7.0.dev20250304 +ARG TORCH_VISION_VERSION=v0.20.1 +WORKDIR /tmp +RUN --mount=type=cache,target=/root/.cache/uv \ + git clone https://github.com/pytorch/vision.git && \ + cd vision && \ + git checkout $TORCH_VISION_VERSION && \ + uv pip install -v torch==${TORCH_VERSION} --extra-index-url https://download.pytorch.org/whl/nightly/cpu && \ + python setup.py bdist_wheel + +# Final build stage +FROM python-install AS vllm-cpu +ARG PYTHON_VERSION + +# Set correct library path for torch and numactl +ENV LD_LIBRARY_PATH="/opt/vllm/lib64/python${PYTHON_VERSION}/site-packages/torch/lib:/usr/local/lib:$LD_LIBRARY_PATH" +ENV C_INCLUDE_PATH="/usr/local/include:$C_INCLUDE_PATH" +ENV UV_LINK_MODE=copy +ENV CARGO_HOME=/root/.cargo +ENV RUSTUP_HOME=/root/.rustup +ENV PATH="$CARGO_HOME/bin:$RUSTUP_HOME/bin:$PATH" + +COPY . /workspace/vllm +WORKDIR /workspace/vllm + +RUN --mount=type=bind,from=numa-build,src=/tmp/numactl-2.0.16,target=/numactl \ + make -C /numactl install + +# Install dependencies, including PyTorch and Apache Arrow +RUN --mount=type=cache,target=/root/.cache/uv \ + --mount=type=bind,from=rust,source=/root/.cargo,target=/root/.cargo,rw \ + --mount=type=bind,from=rust,source=/root/.rustup,target=/root/.rustup,rw \ + --mount=type=bind,from=pyarrow,source=/tmp/arrow/python/dist,target=/tmp/arrow-wheels \ + --mount=type=bind,from=torch-vision,source=/tmp/vision/dist,target=/tmp/vision-wheels/ \ + sed -i '/^torch/d' requirements-build.txt && \ + ARROW_WHL_FILE=$(ls /tmp/arrow-wheels/pyarrow-*.whl | head -n 1) && \ + VISION_WHL_FILE=$(ls /tmp/vision-wheels/*.whl | head -n 1) && \ + uv pip install -v \ + $ARROW_WHL_FILE \ + $VISION_WHL_FILE \ + --extra-index-url https://download.pytorch.org/whl/nightly/cpu \ + --index-strategy unsafe-best-match \ + -r requirements-build.txt \ + -r requirements-cpu.txt + +# Build and install vllm +RUN --mount=type=cache,target=/root/.cache/uv \ + VLLM_TARGET_DEVICE=cpu python setup.py bdist_wheel && \ + uv pip install "$(echo dist/*.whl)[tensorizer]" + +# setup non-root user for vllm +RUN umask 002 && \ + useradd --uid 2000 --gid 0 vllm && \ + mkdir -p /home/vllm && \ + chmod g+rwx /home/vllm + +COPY LICENSE /licenses/vllm.md +COPY examples/*.jinja /app/data/template/ + +USER 2000 +WORKDIR /home/vllm + +# Set the default entrypoint +ENTRYPOINT ["python", "-m", "vllm.entrypoints.openai.api_server"] \ No newline at end of file diff --git a/cmake/cpu_extension.cmake b/cmake/cpu_extension.cmake index 714abca2..ca2ffb1b 100644 --- a/cmake/cpu_extension.cmake +++ b/cmake/cpu_extension.cmake @@ -81,6 +81,7 @@ else() find_isa(${CPUINFO} "POWER9" POWER9_FOUND) find_isa(${CPUINFO} "asimd" ASIMD_FOUND) # Check for ARM NEON support find_isa(${CPUINFO} "bf16" ARM_BF16_FOUND) # Check for ARM BF16 support + find_isa(${CPUINFO} "S390" S390_FOUND) endif() @@ -129,8 +130,16 @@ elseif (ASIMD_FOUND) elseif(APPLE_SILICON_FOUND) message(STATUS "Apple Silicon Detected") set(ENABLE_NUMA OFF) +elseif (S390_FOUND) + message(STATUS "S390 detected") + # Check for S390 VXE support + list(APPEND CXX_COMPILE_FLAGS + "-mvx" + "-mzvector" + "-march=native" + "-mtune=native") else() - message(FATAL_ERROR "vLLM CPU backend requires AVX512, AVX2, Power9+ ISA or ARMv8 support.") + message(FATAL_ERROR "vLLM CPU backend requires AVX512, AVX2, Power9+ ISA, S390X ISA or ARMv8 support.") endif() # diff --git a/csrc/cpu/attention.cpp b/csrc/cpu/attention.cpp index b9764056..0257d8ff 100644 --- a/csrc/cpu/attention.cpp +++ b/csrc/cpu/attention.cpp @@ -24,8 +24,8 @@ struct KernelVecType { template <> struct KernelVecType { -#ifdef __powerpc64__ - // Power architecture-specific vector types +#if defined(__powerpc64__) || defined(__s390x__) + // Power and s390x architecture-specific vector types using q_load_vec_type = vec_op::FP32Vec8; using k_load_vec_type = vec_op::FP32Vec16; using v_load_vec_type = vec_op::FP32Vec16; diff --git a/csrc/cpu/cpu_types.hpp b/csrc/cpu/cpu_types.hpp index a7181510..17bbe04e 100644 --- a/csrc/cpu/cpu_types.hpp +++ b/csrc/cpu/cpu_types.hpp @@ -7,6 +7,9 @@ #elif defined(__POWER9_VECTOR__) // ppc implementation #include "cpu_types_vsx.hpp" +#elif defined(__s390x__) + // s390 implementation + #include "cpu_types_vxe.hpp" #elif defined(__aarch64__) // arm implementation #include "cpu_types_arm.hpp" diff --git a/csrc/cpu/cpu_types_vxe.hpp b/csrc/cpu/cpu_types_vxe.hpp new file mode 100644 index 00000000..ab8cbbbf --- /dev/null +++ b/csrc/cpu/cpu_types_vxe.hpp @@ -0,0 +1,480 @@ + +#ifndef CPU_TYPES_VXE_HPP +#define CPU_TYPES_VXE_HPP + +#include +#include +#include +namespace vec_op { + +#define vec_neg(a) (-(a)) +#define vec_add(a, b) ((a) + (b)) +#define vec_sub(a, b) ((a) - (b)) +#define vec_mul(a, b) ((a) * (b)) +#define vec_div(a, b) ((a) / (b)) +#define vec_sr(a, b) ((a) >> (b)) // Vector Shift Right Algebaic +#define vec_sl(a, b) ((a) << (b)) // Vector Shift Left + +// FIXME: FP16 is not fully supported in Torch-CPU +#define VLLM_DISPATCH_CASE_FLOATING_TYPES(...) \ + AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) + +#define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \ + AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__)) + +#ifndef CPU_OP_GUARD + #define CPU_KERNEL_GUARD_IN(NAME) + #define CPU_KERNEL_GUARD_OUT(NAME) +#else + #define CPU_KERNEL_GUARD_IN(NAME) \ + std::cout << #NAME << " invoked." << std::endl; + #define CPU_KERNEL_GUARD_OUT(NAME) \ + std::cout << #NAME << " exit." << std::endl; +#endif + +#define FORCE_INLINE __attribute__((always_inline)) inline + +namespace { +template +constexpr void unroll_loop_item(std::integer_sequence, F&& f) { + (f(std::integral_constant{}), ...); +} +}; // namespace + +template >> +constexpr void unroll_loop(F&& f) { + unroll_loop_item(std::make_integer_sequence{}, std::forward(f)); +} + +template +struct Vec { + constexpr static int get_elem_num() { return T::VEC_ELEM_NUM; } +}; + +typedef struct ss16x8x2_t { + __vector signed short val[2]; +} ss16x8x2_t; + +typedef struct ss16x8x4_t { + __vector signed short val[4]; +} ss16x8x4_t; + +typedef struct f32x4x2_t { + __vector float val[2]; +} f32x4x2_t; + +typedef struct f32x4x4_t { + __vector float val[4]; +} f32x4x4_t; + +struct FP32Vec8; +struct FP32Vec16; + +struct BF16Vec8 : public Vec { + constexpr static int VEC_ELEM_NUM = 8; + + __vector signed short reg; + + explicit BF16Vec8(const void* ptr) : reg(*(__vector signed short*)ptr) {} + explicit BF16Vec8(const FP32Vec8&); + + void save(void* ptr) const { + *reinterpret_cast<__vector signed short*>(ptr) = reg; + } +}; + +struct BF16Vec16 : public Vec { + constexpr static int VEC_ELEM_NUM = 16; + + ss16x8x2_t reg; + + explicit BF16Vec16(const void* ptr) { + // Load 256 bits in two parts + reg.val[0] = (__vector signed short)vec_xl(0, (signed short*)ptr); + reg.val[1] = (__vector signed short)vec_xl(16, (signed short*)ptr); + } + + explicit BF16Vec16(const FP32Vec16&); + + void save(void* ptr) const { + // Save 256 bits in two parts + vec_xst(reg.val[0], 0, (signed short*)ptr); + vec_xst(reg.val[1], 16, (signed short*)ptr); + } +}; + +const static __vector signed short zero = vec_splats((signed short)0); + +struct BF16Vec32 : public Vec { + constexpr static int VEC_ELEM_NUM = 32; + + ss16x8x4_t reg; + explicit BF16Vec32(const void* ptr) + : reg(*reinterpret_cast(ptr)) {} + + explicit BF16Vec32(ss16x8x4_t data) : reg(data) {} + + explicit BF16Vec32(const BF16Vec8& vec8_data) + : reg({vec8_data.reg, vec8_data.reg, vec8_data.reg, vec8_data.reg}) {} + + void save(void* ptr) const { *reinterpret_cast(ptr) = reg; } +}; + +struct FP32Vec4 : public Vec { + constexpr static int VEC_ELEM_NUM = 4; + union AliasReg { + __vector float reg; + float values[VEC_ELEM_NUM]; + }; + + __vector float reg; + + explicit FP32Vec4(float v) : reg(vec_splats(v)) {} + + explicit FP32Vec4() : reg(vec_splats(0.0f)) {} + + explicit FP32Vec4(const float* ptr) : reg(vec_xl(0, ptr)) {} + + explicit FP32Vec4(__vector float data) : reg(data) {} + + explicit FP32Vec4(const FP32Vec4& data) : reg(data.reg) {} +}; + +struct FP32Vec8 : public Vec { + constexpr static int VEC_ELEM_NUM = 8; + union AliasReg { + f32x4x2_t reg; + float values[VEC_ELEM_NUM]; + }; + + f32x4x2_t reg; + + explicit FP32Vec8(float v) { + reg.val[0] = vec_splats(v); + reg.val[1] = vec_splats(v); + } + + explicit FP32Vec8() { + reg.val[0] = vec_splats(0.0f); + reg.val[1] = vec_splats(0.0f); + } + + explicit FP32Vec8(const float* ptr) { + reg.val[0] = vec_xl(0, ptr); + reg.val[1] = vec_xl(16, ptr); + } + + explicit FP32Vec8(f32x4x2_t data) : reg(data) {} + + explicit FP32Vec8(const FP32Vec8& data) { + reg.val[0] = data.reg.val[0]; + reg.val[1] = data.reg.val[1]; + } + + explicit FP32Vec8(const BF16Vec8& v) { + reg.val[0] = (__vector float)vec_mergeh(zero, v.reg); + reg.val[1] = (__vector float)vec_mergel(zero, v.reg); + } + + float reduce_sum() const { + AliasReg ar; + ar.reg = reg; + float result = 0; + unroll_loop( + [&result, &ar](int i) { result += ar.values[i]; }); + + return result; + } + + FP32Vec8 exp() const { + // TODO: Vectorize this + AliasReg ar; + ar.reg = reg; + f32x4x4_t ret; + ret.val[0][0] = std::exp(ar.values[0]); + ret.val[0][1] = std::exp(ar.values[1]); + ret.val[0][2] = std::exp(ar.values[2]); + ret.val[0][3] = std::exp(ar.values[3]); + ret.val[1][0] = std::exp(ar.values[4]); + ret.val[1][1] = std::exp(ar.values[5]); + ret.val[1][2] = std::exp(ar.values[6]); + ret.val[1][3] = std::exp(ar.values[7]); + return FP32Vec8(f32x4x2_t({ret.val[0], ret.val[1]})); + } + + FP32Vec8 tanh() const { + // TODO: Vectorize this + AliasReg ar; + ar.reg = reg; + f32x4x4_t ret; + ret.val[0][0] = std::tanh(ar.values[0]); + ret.val[0][1] = std::tanh(ar.values[1]); + ret.val[0][2] = std::tanh(ar.values[2]); + ret.val[0][3] = std::tanh(ar.values[3]); + ret.val[1][0] = std::tanh(ar.values[4]); + ret.val[1][1] = std::tanh(ar.values[5]); + ret.val[1][2] = std::tanh(ar.values[6]); + ret.val[1][3] = std::tanh(ar.values[7]); + return FP32Vec8(f32x4x2_t({ret.val[0], ret.val[1]})); + } + + FP32Vec8 er() const { + // TODO: Vectorize this + AliasReg ar; + ar.reg = reg; + f32x4x4_t ret; + ret.val[0][0] = std::erf(ar.values[0]); + ret.val[0][1] = std::erf(ar.values[1]); + ret.val[0][2] = std::erf(ar.values[2]); + ret.val[0][3] = std::erf(ar.values[3]); + ret.val[1][0] = std::erf(ar.values[4]); + ret.val[1][1] = std::erf(ar.values[5]); + ret.val[1][2] = std::erf(ar.values[6]); + ret.val[1][3] = std::erf(ar.values[7]); + return FP32Vec8(f32x4x2_t({ret.val[0], ret.val[1]})); + } + + FP32Vec8 operator*(const FP32Vec8& b) const { + return FP32Vec8( + {vec_mul(reg.val[0], b.reg.val[0]), vec_mul(reg.val[1], b.reg.val[1])}); + } + + FP32Vec8 operator+(const FP32Vec8& b) const { + return FP32Vec8( + {vec_add(reg.val[0], b.reg.val[0]), vec_add(reg.val[1], b.reg.val[1])}); + } + + FP32Vec8 operator-(const FP32Vec8& b) const { + return FP32Vec8( + {vec_sub(reg.val[0], b.reg.val[0]), vec_sub(reg.val[1], b.reg.val[1])}); + } + + FP32Vec8 operator/(const FP32Vec8& b) const { + return FP32Vec8( + {vec_div(reg.val[0], b.reg.val[0]), vec_div(reg.val[1], b.reg.val[1])}); + } + + void save(float* ptr) const { + vec_xst(reg.val[0], 0, ptr); + vec_xst(reg.val[1], 16, ptr); + } +}; + +struct FP32Vec16 : public Vec { + constexpr static int VEC_ELEM_NUM = 16; + union AliasReg { + f32x4x4_t reg; + float values[VEC_ELEM_NUM]; + }; + + f32x4x4_t reg; + + explicit FP32Vec16(float v) { + reg.val[0] = vec_splats(v); + reg.val[1] = vec_splats(v); + reg.val[2] = vec_splats(v); + reg.val[3] = vec_splats(v); + } + + explicit FP32Vec16() { + reg.val[0] = vec_splats(0.0f); + reg.val[1] = vec_splats(0.0f); + reg.val[2] = vec_splats(0.0f); + reg.val[3] = vec_splats(0.0f); + } + + explicit FP32Vec16(const float* ptr) { + reg.val[0] = vec_xl(0, ptr); + reg.val[1] = vec_xl(16, ptr); + reg.val[2] = vec_xl(32, ptr); + reg.val[3] = vec_xl(48, ptr); + } + + explicit FP32Vec16(f32x4x4_t data) : reg(data) {} + + explicit FP32Vec16(const FP32Vec16& data) { + reg.val[0] = data.reg.val[0]; + reg.val[1] = data.reg.val[1]; + reg.val[2] = data.reg.val[2]; + reg.val[3] = data.reg.val[3]; + } + + explicit FP32Vec16(const FP32Vec4& data) { + reg.val[0] = data.reg; + reg.val[1] = data.reg; + reg.val[2] = data.reg; + reg.val[3] = data.reg; + } + + explicit FP32Vec16(const FP32Vec8& data) { + reg.val[0] = data.reg.val[0]; + reg.val[1] = data.reg.val[1]; + reg.val[2] = data.reg.val[0]; + reg.val[3] = data.reg.val[1]; + } + + explicit FP32Vec16(const BF16Vec16& v) { + reg.val[0] = (__vector float)vec_mergeh(zero, v.reg.val[0]); + reg.val[1] = (__vector float)vec_mergel(zero, v.reg.val[0]); + reg.val[2] = (__vector float)vec_mergeh(zero, v.reg.val[1]); + reg.val[3] = (__vector float)vec_mergel(zero, v.reg.val[1]); + } + + explicit FP32Vec16(const BF16Vec8& v) : FP32Vec16(FP32Vec8(v)) {} + + FP32Vec16 operator*(const FP32Vec16& b) const { + return FP32Vec16(f32x4x4_t({vec_mul(reg.val[0], b.reg.val[0]), + vec_mul(reg.val[1], b.reg.val[1]), + vec_mul(reg.val[2], b.reg.val[2]), + vec_mul(reg.val[3], b.reg.val[3])})); + } + + FP32Vec16 operator+(const FP32Vec16& b) const { + return FP32Vec16(f32x4x4_t({vec_add(reg.val[0], b.reg.val[0]), + vec_add(reg.val[1], b.reg.val[1]), + vec_add(reg.val[2], b.reg.val[2]), + vec_add(reg.val[3], b.reg.val[3])})); + } + + FP32Vec16 operator-(const FP32Vec16& b) const { + return FP32Vec16(f32x4x4_t({vec_sub(reg.val[0], b.reg.val[0]), + vec_sub(reg.val[1], b.reg.val[1]), + vec_sub(reg.val[2], b.reg.val[2]), + vec_sub(reg.val[3], b.reg.val[3])})); + } + + FP32Vec16 operator/(const FP32Vec16& b) const { + return FP32Vec16(f32x4x4_t({vec_div(reg.val[0], b.reg.val[0]), + vec_div(reg.val[1], b.reg.val[1]), + vec_div(reg.val[2], b.reg.val[2]), + vec_div(reg.val[3], b.reg.val[3])})); + } + + float reduce_sum() const { + AliasReg ar; + ar.reg = reg; + float result = 0; + unroll_loop( + [&result, &ar](int i) { result += ar.values[i]; }); + + return result; + } + + template + float reduce_sub_sum(int idx) { + static_assert(VEC_ELEM_NUM % group_size == 0); + + AliasReg ar; + ar.reg = reg; + float result = 0; + const int start = idx * group_size; + unroll_loop( + [&result, &start, ar](int i) { result += ar.values[start + i]; }); + + return result; + } + + void save(float* ptr) const { + vec_xst(reg.val[0], 0, ptr); + vec_xst(reg.val[1], 16, ptr); + vec_xst(reg.val[2], 32, ptr); + vec_xst(reg.val[3], 48, ptr); + } +}; + +template +struct VecType { + using vec_type = void; +}; + +template +using vec_t = typename VecType::vec_type; + +template <> +struct VecType { + using vec_type = FP32Vec8; +}; + +template <> +struct VecType { + using vec_type = BF16Vec8; +}; + +template +void storeFP32(float v, T* ptr) { + *ptr = v; +} + +inline void fma(FP32Vec16& acc, FP32Vec16& a, FP32Vec16& b) { + acc = acc + a * b; +} + +namespace c10 { +struct BFloat16 { + uint16_t value; // Assume BFloat16 is defined as a struct containing a 16-bit + // value. +}; +} // namespace c10 + +template <> +inline void storeFP32(float v, c10::BFloat16* ptr) { + c10::BFloat16 __attribute__((__may_alias__))* v_ptr = + reinterpret_cast(&v); + *ptr = *(v_ptr + 1); +} + +#ifndef __VEC_CLASS_FP_NAN + #define __VEC_CLASS_FP_NAN (1 << 6) +#endif + +const static __vector unsigned char omask = {2, 3, 6, 7, 10, 11, 14, 15, + 18, 19, 22, 23, 26, 27, 30, 31}; +const static __vector unsigned int bias = {0x00007fff, 0x00007fff, 0x00007fff, + 0x00007fff}; +const static __vector unsigned int nan = {0x7fc00000, 0x7fc00000, 0x7fc00000, + 0x7fc00000}; +const static __vector unsigned int sh16 = {16, 16, 16, 16}; +const static __vector unsigned int one = {1, 1, 1, 1}; + +inline BF16Vec8::BF16Vec8(const FP32Vec8& v) { + __vector unsigned int inp0 = (__vector unsigned int)(v.reg.val[0]); + __vector unsigned int inp1 = (__vector unsigned int)(v.reg.val[1]); + int cc; + __vector __bool int sel0 = + vec_fp_test_data_class(v.reg.val[0], __VEC_CLASS_FP_NAN, &cc); + __vector __bool int sel1 = + vec_fp_test_data_class(v.reg.val[1], __VEC_CLASS_FP_NAN, &cc); + inp0 = vec_sel(inp0, nan, sel0) >> sh16; + inp1 = vec_sel(inp1, nan, sel1) >> sh16; + reg = (__vector signed short)vec_perm(inp0, inp1, omask); +} + +inline BF16Vec16::BF16Vec16(const FP32Vec16& v) { + __vector unsigned int inp0 = (__vector unsigned int)(v.reg.val[0]); + __vector unsigned int inp1 = (__vector unsigned int)(v.reg.val[1]); + __vector unsigned int inp2 = (__vector unsigned int)(v.reg.val[2]); + __vector unsigned int inp3 = (__vector unsigned int)(v.reg.val[3]); + int cc; + __vector __bool int sel0 = + vec_fp_test_data_class(v.reg.val[0], __VEC_CLASS_FP_NAN, &cc); + __vector __bool int sel1 = + vec_fp_test_data_class(v.reg.val[1], __VEC_CLASS_FP_NAN, &cc); + __vector __bool int sel2 = + vec_fp_test_data_class(v.reg.val[2], __VEC_CLASS_FP_NAN, &cc); + __vector __bool int sel3 = + vec_fp_test_data_class(v.reg.val[3], __VEC_CLASS_FP_NAN, &cc); + inp0 = vec_sel(inp0, nan, sel0) >> sh16; + inp1 = vec_sel(inp1, nan, sel1) >> sh16; + inp2 = vec_sel(inp2, nan, sel2) >> sh16; + inp3 = vec_sel(inp3, nan, sel3) >> sh16; + reg.val[0] = (__vector signed short)vec_perm(inp0, inp1, omask); + reg.val[1] = (__vector signed short)vec_perm(inp2, inp3, omask); +} + +inline void prefetch(const void* addr) { void __dcbt(const void* addr); } + +}; // namespace vec_op + +#endif \ No newline at end of file diff --git a/csrc/cpu/quant.cpp b/csrc/cpu/quant.cpp index 33b16378..6751e7e5 100644 --- a/csrc/cpu/quant.cpp +++ b/csrc/cpu/quant.cpp @@ -25,7 +25,7 @@ struct KernelVecType { template <> struct KernelVecType { -#ifdef __powerpc64__ +#if defined(__powerpc64__) || defined(__s390x__) // Power architecture-specific vector type using load_vec_type = vec_op::FP32Vec16; #else diff --git a/requirements-cpu.txt b/requirements-cpu.txt index ecfa822e..9491e27d 100644 --- a/requirements-cpu.txt +++ b/requirements-cpu.txt @@ -2,14 +2,15 @@ -r requirements-common.txt # Dependencies for CPUs -torch==2.5.1+cpu; platform_machine != "ppc64le" and platform_machine != "aarch64" and platform_system != "Darwin" -torch==2.5.1; platform_machine == "ppc64le" or platform_machine == "aarch64" or platform_system == "Darwin" +torch==2.5.1+cpu; platform_machine != "ppc64le" and platform_machine != "aarch64" and platform_system != "Darwin" and platform_machine != "s390x" +torch==2.5.1; platform_machine == "ppc64le" or platform_machine == "aarch64" or platform_system == "Darwin" +torch==2.7.0.dev20250304; platform_machine == "s390x" # required for the image processor of minicpm-o-2_6, this must be updated alongside torch -torchaudio; platform_machine != "ppc64le" +torchaudio; platform_machine != "ppc64le" and platform_machine != "s390x" torchaudio==2.5.1; platform_machine == "ppc64le" # required for the image processor of phi3v, this must be updated alongside torch -torchvision; platform_machine != "ppc64le" +torchvision; platform_machine != "ppc64le" and platform_machine != "s390x" torchvision==0.20.1; platform_machine == "ppc64le" datasets # for benchmark scripts