Move linting to pre-commit (#11975)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Harry Mellor 2025-01-20 06:58:01 +00:00 committed by GitHub
parent 51ef828f10
commit 3ea7b94523
26 changed files with 724 additions and 1286 deletions

@@ -43,7 +43,7 @@ main() {
# The figures should be generated by a separate process outside the CI/CD pipeline
# # generate figures
# python3 -m pip install tabulate pandas matplotlib

@@ -1,40 +0,0 @@
name: Lint GitHub Actions workflows
on:
push:
branches:
- "main"
paths:
- '.github/workflows/*.ya?ml'
- '.github/workflows/actionlint.*'
- '.github/workflows/matchers/actionlint.json'
pull_request:
branches:
- "main"
paths:
- '.github/workflows/*.ya?ml'
- '.github/workflows/actionlint.*'
- '.github/workflows/matchers/actionlint.json'
env:
LC_ALL: en_US.UTF-8
defaults:
run:
shell: bash
permissions:
contents: read
jobs:
actionlint:
runs-on: ubuntu-latest
steps:
- name: "Checkout"
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
with:
fetch-depth: 0
- name: "Run actionlint"
run: |
echo "::add-matcher::.github/workflows/matchers/actionlint.json"
tools/actionlint.sh -color

.github/workflows/clang-format.yml (deleted)
@@ -1,53 +0,0 @@
name: clang-format
on:
# Trigger the workflow on push or pull request,
# but only for the main branch
push:
branches:
- main
paths:
- '**/*.h'
- '**/*.cpp'
- '**/*.cu'
- '**/*.cuh'
- '.github/workflows/clang-format.yml'
pull_request:
branches:
- main
paths:
- '**/*.h'
- '**/*.cpp'
- '**/*.cu'
- '**/*.cuh'
- '.github/workflows/clang-format.yml'
jobs:
clang-format:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ["3.11"]
steps:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install clang-format==18.1.5
- name: Running clang-format
run: |
EXCLUDES=(
'csrc/moe/topk_softmax_kernels.cu'
'csrc/quantization/gguf/ggml-common.h'
'csrc/quantization/gguf/dequantize.cuh'
'csrc/quantization/gguf/vecdotq.cuh'
'csrc/quantization/gguf/mmq.cuh'
'csrc/quantization/gguf/mmvq.cuh'
)
find csrc/ \( -name '*.h' -o -name '*.cpp' -o -name '*.cu' -o -name '*.cuh' \) -print \
| grep -vFf <(printf "%s\n" "${EXCLUDES[@]}") \
| xargs clang-format --dry-run --Werror

.github/workflows/codespell.yml (deleted)
@@ -1,45 +0,0 @@
name: codespell
on:
# Trigger the workflow on push or pull request,
# but only for the main branch
push:
branches:
- main
paths:
- "**/*.py"
- "**/*.md"
- "**/*.rst"
- pyproject.toml
- requirements-lint.txt
- .github/workflows/codespell.yml
pull_request:
branches:
- main
paths:
- "**/*.py"
- "**/*.md"
- "**/*.rst"
- pyproject.toml
- requirements-lint.txt
- .github/workflows/codespell.yml
jobs:
codespell:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ["3.12"]
steps:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install -r requirements-lint.txt
- name: Spelling check with codespell
run: |
codespell --toml pyproject.toml

@@ -1,32 +0,0 @@
name: Lint documentation
on:
push:
branches:
- main
paths:
- "docs/**"
pull_request:
branches:
- main
paths:
- "docs/**"
jobs:
doc-lint:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ["3.12"]
steps:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install -r requirements-lint.txt
- name: Linting docs
run: tools/doc-lint.sh

.github/workflows/dummy.yml (new file)
@@ -0,0 +1,20 @@
name: dummy-checks
on:
pull_request:
jobs:
mypy:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ["3.12"]
steps:
- run: echo "This is a dummy step that always passes"
ruff:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ["3.12"]
steps:
- run: echo "This is a dummy step that always passes"

.github/workflows/matchers/ruff.json (deleted)
@@ -1,17 +0,0 @@
{
"problemMatcher": [
{
"owner": "ruff",
"pattern": [
{
"regexp": "^(.+?):(\\d+):(\\d+): (\\w+): (.+)$",
"file": 1,
"line": 2,
"column": 3,
"code": 4,
"message": 5
}
]
}
]
}
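
For context, a problem matcher like the one above turns matching log lines into GitHub annotations once it has been registered with the `::add-matcher::` workflow command. A minimal sketch of the mechanism (the sample diagnostic line is hypothetical):

echo "::add-matcher::.github/workflows/matchers/ruff.json"
# Any later log line shaped like "<file>:<line>:<col>: <code>: <message>"
# now becomes an inline annotation, e.g.:
echo "vllm/example.py:10:1: F401: 'os' imported but unused"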

.github/workflows/mypy.yaml (deleted)
@@ -1,51 +0,0 @@
name: mypy
on:
# Trigger the workflow on push or pull request,
# but only for the main branch
push:
branches:
- main
paths:
- '**/*.py'
- '.github/workflows/mypy.yaml'
- 'tools/mypy.sh'
- 'pyproject.toml'
pull_request:
branches:
- main
# This workflow is only relevant when one of the following files changes.
# However, we have GitHub configured to expect and require this workflow
# to run and pass before GitHub will auto-merge a pull request. Until GitHub
# allows a more flexible auto-merge policy, we can just run this on every PR.
# It doesn't take that long to run, anyway.
#paths:
# - '**/*.py'
# - '.github/workflows/mypy.yaml'
# - 'tools/mypy.sh'
# - 'pyproject.toml'
jobs:
mypy:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ["3.9", "3.10", "3.11", "3.12"]
steps:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install mypy==1.11.1
pip install types-setuptools
pip install types-PyYAML
pip install types-requests
- name: Mypy
run: |
echo "::add-matcher::.github/workflows/matchers/mypy.json"
tools/mypy.sh 1 ${{ matrix.python-version }}

.github/workflows/png-lint.yml (deleted)
@@ -1,37 +0,0 @@
name: Lint PNG exports from excalidraw
on:
push:
branches:
- "main"
paths:
- '*.excalidraw.png'
- '.github/workflows/png-lint.yml'
pull_request:
branches:
- "main"
paths:
- '*.excalidraw.png'
- '.github/workflows/png-lint.yml'
env:
LC_ALL: en_US.UTF-8
defaults:
run:
shell: bash
permissions:
contents: read
jobs:
actionlint:
runs-on: ubuntu-latest
steps:
- name: "Checkout"
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
with:
fetch-depth: 0
- name: "Run png-lint.sh to check excalidraw exported images"
run: |
tools/png-lint.sh

.github/workflows/pre-commit.yml (new file)
@@ -0,0 +1,17 @@
name: pre-commit
on:
pull_request:
push:
branches: [main]
jobs:
pre-commit:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
- uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0
with:
python-version: "3.12"
- run: echo "::add-matcher::.github/workflows/matchers/actionlint.json"
- uses: pre-commit/action@2c7b3805fd2a0fd8c1884dcaebf91fc102a13ecd # v3.0.1

.github/workflows/ruff.yml (deleted)
@@ -1,52 +0,0 @@
name: ruff
on:
# Trigger the workflow on push or pull request,
# but only for the main branch
push:
branches:
- main
paths:
- "**/*.py"
- pyproject.toml
- requirements-lint.txt
- .github/workflows/matchers/ruff.json
- .github/workflows/ruff.yml
pull_request:
branches:
- main
# This workflow is only relevant when one of the following files changes.
# However, we have GitHub configured to expect and require this workflow
# to run and pass before GitHub will auto-merge a pull request. Until GitHub
# allows a more flexible auto-merge policy, we can just run this on every PR.
# It doesn't take that long to run, anyway.
#paths:
# - "**/*.py"
# - pyproject.toml
# - requirements-lint.txt
# - .github/workflows/matchers/ruff.json
# - .github/workflows/ruff.yml
jobs:
ruff:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ["3.12"]
steps:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install -r requirements-lint.txt
- name: Analysing the code with ruff
run: |
echo "::add-matcher::.github/workflows/matchers/ruff.json"
ruff check --output-format github .
- name: Run isort
run: |
isort . --check-only

.github/workflows/shellcheck.yml (deleted)
@@ -1,37 +0,0 @@
name: Lint shell scripts
on:
push:
branches:
- "main"
paths:
- '**/*.sh'
- '.github/workflows/shellcheck.yml'
pull_request:
branches:
- "main"
paths:
- '**/*.sh'
- '.github/workflows/shellcheck.yml'
env:
LC_ALL: en_US.UTF-8
defaults:
run:
shell: bash
permissions:
contents: read
jobs:
shellcheck:
runs-on: ubuntu-latest
steps:
- name: "Checkout"
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
with:
fetch-depth: 0
- name: "Check shell scripts"
run: |
tools/shellcheck.sh

.github/workflows/yapf.yml (deleted)
@@ -1,38 +0,0 @@
name: yapf
on:
# Trigger the workflow on push or pull request,
# but only for the main branch
push:
branches:
- main
paths:
- "**/*.py"
- .github/workflows/yapf.yml
pull_request:
branches:
- main
paths:
- "**/*.py"
- .github/workflows/yapf.yml
jobs:
yapf:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ["3.12"]
steps:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install yapf==0.32.0
pip install toml==0.10.2
- name: Running yapf
run: |
yapf --diff --recursive .

.pre-commit-config.yaml (new file)
@@ -0,0 +1,73 @@
repos:
- repo: https://github.com/google/yapf
rev: v0.32.0
hooks:
- id: yapf
args: [--in-place, --verbose]
additional_dependencies: [toml] # TODO: Remove when yapf is upgraded
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.6.5
hooks:
- id: ruff
args: [--output-format, github]
- repo: https://github.com/codespell-project/codespell
rev: v2.3.0
hooks:
- id: codespell
exclude: 'benchmarks/sonnet.txt|(build|tests/(lora/data|models/fixtures|prompts))/.*'
- repo: https://github.com/PyCQA/isort
rev: 5.13.2
hooks:
- id: isort
- repo: https://github.com/pre-commit/mirrors-clang-format
rev: v18.1.5
hooks:
- id: clang-format
exclude: 'csrc/(moe/topk_softmax_kernels.cu|quantization/gguf/(ggml-common.h|dequantize.cuh|vecdotq.cuh|mmq.cuh|mmvq.cuh))'
types_or: [c++, cuda]
args: [--style=file, --verbose]
- repo: https://github.com/jackdewinter/pymarkdown
rev: v0.9.27
hooks:
- id: pymarkdown
files: docs/.*
- repo: local
hooks:
- id: mypy-3.9 # TODO: Use https://github.com/pre-commit/mirrors-mypy when mypy setup is less awkward
name: Run mypy for Python 3.9
entry: tools/mypy.sh 1 "3.9"
language: python
types: [python]
additional_dependencies: &mypy_deps [mypy==1.11.1, types-setuptools, types-PyYAML, types-requests]
- id: mypy-3.10 # TODO: Use https://github.com/pre-commit/mirrors-mypy when mypy setup is less awkward
name: Run mypy for Python 3.10
entry: tools/mypy.sh 1 "3.10"
language: python
types: [python]
additional_dependencies: *mypy_deps
- id: mypy-3.11 # TODO: Use https://github.com/pre-commit/mirrors-mypy when mypy setup is less awkward
name: Run mypy for Python 3.11
entry: tools/mypy.sh 1 "3.11"
language: python
types: [python]
additional_dependencies: *mypy_deps
- id: mypy-3.12 # TODO: Use https://github.com/pre-commit/mirrors-mypy when mypy setup is less awkward
name: Run mypy for Python 3.12
entry: tools/mypy.sh 1 "3.12"
language: python
types: [python]
additional_dependencies: *mypy_deps
- id: shellcheck
name: Lint shell scripts
entry: tools/shellcheck.sh
language: script
types: [shell]
- id: png-lint
name: Lint PNG exports from excalidraw
entry: tools/png-lint.sh
language: script
types: [png]
- repo: https://github.com/rhysd/actionlint
rev: v1.7.6
hooks:
- id: actionlint
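
As a usage note, these hooks run locally through the standard pre-commit CLI; a minimal sketch, assuming pre-commit is not yet installed:

pip install pre-commit
pre-commit install                    # run the hooks on every git commit
pre-commit run --all-files            # or lint the entire repo once
pre-commit run mypy-3.12 --all-files  # a single hook can be selected by id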

@@ -32,7 +32,7 @@ class ScalarType {
        signed_(signed_),
        bias(bias),
        finite_values_only(finite_values_only),
        nan_repr(nan_repr) {};

  static constexpr ScalarType int_(uint8_t size_bits, int32_t bias = 0) {
    return ScalarType(0, size_bits - 1, true, bias);

cpu_types.hpp
@@ -2,13 +2,13 @@
#define CPU_TYPES_HPP

#if defined(__x86_64__)
// x86 implementation
#include "cpu_types_x86.hpp"
#elif defined(__POWER9_VECTOR__)
// ppc implementation
#include "cpu_types_vsx.hpp"
#elif defined(__aarch64__)
// arm implementation
#include "cpu_types_arm.hpp"
#else
#warning "unsupported vLLM cpu implementation"

cpu_types_arm.hpp
@@ -1,48 +1,50 @@
#include <arm_neon.h>
#include <torch/all.h>
#include <cmath>

namespace vec_op {

#ifdef ARM_BF16_SUPPORT
#define VLLM_DISPATCH_CASE_FLOATING_TYPES(...)         \
  AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \
  AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__)  \
  AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__)
#else
#define VLLM_DISPATCH_CASE_FLOATING_TYPES(...)         \
  AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \
  AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__)
#endif

#define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \
  AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__))

#ifndef CPU_OP_GUARD
#define CPU_KERNEL_GUARD_IN(NAME)
#define CPU_KERNEL_GUARD_OUT(NAME)
#else
#define CPU_KERNEL_GUARD_IN(NAME) \
  std::cout << #NAME << " invoked." << std::endl;
#define CPU_KERNEL_GUARD_OUT(NAME) \
  std::cout << #NAME << " exit." << std::endl;
#endif

#define FORCE_INLINE __attribute__((always_inline)) inline

namespace {
template <typename T, T... indexes, typename F>
constexpr void unroll_loop_item(std::integer_sequence<T, indexes...>, F&& f) {
  (f(std::integral_constant<T, indexes>{}), ...);
};
};  // namespace

template <typename T, T count, typename F,
          typename = std::enable_if_t<std::is_invocable_v<F, T>>>
constexpr void unroll_loop(F&& f) {
  unroll_loop_item(std::make_integer_sequence<T, count>{}, std::forward<F>(f));
}

template <typename T>
struct Vec {
  constexpr static int get_elem_num() { return T::VEC_ELEM_NUM; };
};
@@ -54,127 +56,124 @@ struct FP16Vec8 : public Vec<FP16Vec8> {
  float16x8_t reg;

  explicit FP16Vec8(const void* ptr)
      : reg(vld1q_f16(static_cast<const __fp16*>(ptr))) {};

  explicit FP16Vec8(const FP32Vec8&);

  void save(void* ptr) const { vst1q_f16(static_cast<__fp16*>(ptr), reg); }
};

struct FP16Vec16 : public Vec<FP16Vec16> {
  constexpr static int VEC_ELEM_NUM = 16;

  float16x8x2_t reg;

  explicit FP16Vec16(const void* ptr) {
    reg.val[0] = vld1q_f16(reinterpret_cast<const __fp16*>(ptr));
    reg.val[1] = vld1q_f16(reinterpret_cast<const __fp16*>(ptr) + 8);
  }

  explicit FP16Vec16(const FP32Vec16& vec);

  void save(void* ptr) const {
    vst1q_f16(reinterpret_cast<__fp16*>(ptr), reg.val[0]);
    vst1q_f16(reinterpret_cast<__fp16*>(ptr) + 8, reg.val[1]);
  }

  void save(void* ptr, const int elem_num) const {
    int full_blocks = elem_num / 8;
    int remainder = elem_num % 8;

    if (full_blocks > 0) {
      vst1q_f16(reinterpret_cast<__fp16*>(ptr), reg.val[0]);
      if (full_blocks > 1) {
        vst1q_f16(reinterpret_cast<__fp16*>(ptr) + 8, reg.val[1]);
      }
    }

    // Note: below is the unrolled version of the following code:
    //
    // for (int i = 0; i < remainder; ++i) {
    //   reinterpret_cast<__fp16*>(ptr)[full_blocks * 8 + i] =
    //       vgetq_lane_f16(temp, i);
    // }
    //
    // For macOS build (Clang), the arm/neon intrinsics function
    // `vgetq_lane_f16` needs the parameter `i` to be constant at compile
    // time.
    if (remainder > 0) {
      float16x8_t temp = reg.val[full_blocks];
      __fp16* fp16_ptr = reinterpret_cast<__fp16*>(ptr);
      switch (remainder) {
        case 1:
          fp16_ptr[full_blocks * 8 + 0] = vgetq_lane_f16(temp, 0);
          break;
        case 2:
          fp16_ptr[full_blocks * 8 + 0] = vgetq_lane_f16(temp, 0);
          fp16_ptr[full_blocks * 8 + 1] = vgetq_lane_f16(temp, 1);
          break;
        case 3:
          fp16_ptr[full_blocks * 8 + 0] = vgetq_lane_f16(temp, 0);
          fp16_ptr[full_blocks * 8 + 1] = vgetq_lane_f16(temp, 1);
          fp16_ptr[full_blocks * 8 + 2] = vgetq_lane_f16(temp, 2);
          break;
        case 4:
          fp16_ptr[full_blocks * 8 + 0] = vgetq_lane_f16(temp, 0);
          fp16_ptr[full_blocks * 8 + 1] = vgetq_lane_f16(temp, 1);
          fp16_ptr[full_blocks * 8 + 2] = vgetq_lane_f16(temp, 2);
          fp16_ptr[full_blocks * 8 + 3] = vgetq_lane_f16(temp, 3);
          break;
        case 5:
          fp16_ptr[full_blocks * 8 + 0] = vgetq_lane_f16(temp, 0);
          fp16_ptr[full_blocks * 8 + 1] = vgetq_lane_f16(temp, 1);
          fp16_ptr[full_blocks * 8 + 2] = vgetq_lane_f16(temp, 2);
          fp16_ptr[full_blocks * 8 + 3] = vgetq_lane_f16(temp, 3);
          fp16_ptr[full_blocks * 8 + 4] = vgetq_lane_f16(temp, 4);
          break;
        case 6:
          fp16_ptr[full_blocks * 8 + 0] = vgetq_lane_f16(temp, 0);
          fp16_ptr[full_blocks * 8 + 1] = vgetq_lane_f16(temp, 1);
          fp16_ptr[full_blocks * 8 + 2] = vgetq_lane_f16(temp, 2);
          fp16_ptr[full_blocks * 8 + 3] = vgetq_lane_f16(temp, 3);
          fp16_ptr[full_blocks * 8 + 4] = vgetq_lane_f16(temp, 4);
          fp16_ptr[full_blocks * 8 + 5] = vgetq_lane_f16(temp, 5);
          break;
        case 7:
          fp16_ptr[full_blocks * 8 + 0] = vgetq_lane_f16(temp, 0);
          fp16_ptr[full_blocks * 8 + 1] = vgetq_lane_f16(temp, 1);
          fp16_ptr[full_blocks * 8 + 2] = vgetq_lane_f16(temp, 2);
          fp16_ptr[full_blocks * 8 + 3] = vgetq_lane_f16(temp, 3);
          fp16_ptr[full_blocks * 8 + 4] = vgetq_lane_f16(temp, 4);
          fp16_ptr[full_blocks * 8 + 5] = vgetq_lane_f16(temp, 5);
          fp16_ptr[full_blocks * 8 + 6] = vgetq_lane_f16(temp, 6);
          break;
        default:
          break;
      }
    }
  }
};

#ifdef ARM_BF16_SUPPORT
struct BF16Vec8 : public Vec<BF16Vec8> {
  constexpr static int VEC_ELEM_NUM = 8;

  bfloat16x8_t reg;

  explicit BF16Vec8(const void* ptr)
      : reg(*reinterpret_cast<const bfloat16x8_t*>(ptr)) {};

  explicit BF16Vec8(bfloat16x8_t data) : reg(data) {};

  explicit BF16Vec8(const FP32Vec8&);

  explicit BF16Vec8(float32x4x2_t v)
      : reg(vcvtq_high_bf16_f32(vcvtq_low_bf16_f32(v.val[0]), v.val[1])) {};

  void save(void* ptr) const { *reinterpret_cast<bfloat16x8_t*>(ptr) = reg; }
};

struct BF16Vec16 : public Vec<BF16Vec16> {
@@ -182,19 +181,18 @@ struct BF16Vec16 : public Vec<BF16Vec16> {
  bfloat16x8x2_t reg;

  explicit BF16Vec16(const void* ptr)
      : reg(*reinterpret_cast<const bfloat16x8x2_t*>(ptr)) {};

  explicit BF16Vec16(bfloat16x8x2_t data) : reg(data) {};

  explicit BF16Vec16(const FP32Vec16&);

  explicit BF16Vec16(float32x4x4_t v)
      : reg({vcvtq_high_bf16_f32(vcvtq_low_bf16_f32(v.val[0]), v.val[1]),
             vcvtq_high_bf16_f32(vcvtq_low_bf16_f32(v.val[2]), v.val[3])}) {};

  void save(void* ptr) const { *reinterpret_cast<bfloat16x8x2_t*>(ptr) = reg; };
};

struct BF16Vec32 : public Vec<BF16Vec32> {
@@ -202,19 +200,15 @@ struct BF16Vec32 : public Vec<BF16Vec32> {
  bfloat16x8x4_t reg;

  explicit BF16Vec32(const void* ptr)
      : reg(*reinterpret_cast<const bfloat16x8x4_t*>(ptr)) {};

  explicit BF16Vec32(bfloat16x8x4_t data) : reg(data) {};

  explicit BF16Vec32(const BF16Vec8& vec8_data)
      : reg({vec8_data.reg, vec8_data.reg, vec8_data.reg, vec8_data.reg}) {};

  void save(void* ptr) const { *reinterpret_cast<bfloat16x8x4_t*>(ptr) = reg; };
};

#endif
@@ -232,11 +226,11 @@ struct FP32Vec4 : public Vec<FP32Vec4> {
  explicit FP32Vec4() : reg(vdupq_n_f32(0.0f)) {};

  explicit FP32Vec4(const float* ptr) : reg(vld1q_f32(ptr)) {};

  explicit FP32Vec4(float32x4_t data) : reg(data) {};

  explicit FP32Vec4(const FP32Vec4& data) : reg(data.reg) {};
};

struct FP32Vec8 : public Vec<FP32Vec8> {
@@ -252,32 +246,37 @@ struct FP32Vec8 : public Vec<FP32Vec8> {
  explicit FP32Vec8() : reg({vmovq_n_f32(0.0), vmovq_n_f32(0.0)}) {};

  explicit FP32Vec8(const float* ptr)
      : reg({vld1q_f32(ptr), vld1q_f32(ptr + 4)}) {};

  explicit FP32Vec8(float32x4x2_t data) : reg(data) {};

  explicit FP32Vec8(const FP32Vec8& data) : reg(data.reg) {};

  explicit FP32Vec8(const FP16Vec8& v) {
    reg.val[0] = vcvt_f32_f16(vget_low_f16(v.reg));
    reg.val[1] = vcvt_f32_f16(vget_high_f16(v.reg));
  };

  explicit FP32Vec8(float16x8_t v)
      : reg({vcvt_f32_f16(vget_low_f16(v)), vcvt_f32_f16(vget_high_f16(v))}) {};

#ifdef ARM_BF16_SUPPORT

  explicit FP32Vec8(bfloat16x8_t v)
      : reg({vcvtq_low_f32_bf16(v), vcvtq_high_f32_bf16(v)}) {};

  explicit FP32Vec8(const BF16Vec8& v)
      : reg({vcvtq_low_f32_bf16(v.reg), vcvtq_high_f32_bf16(v.reg)}) {};

#endif

  float reduce_sum() const {
    AliasReg ar;
    ar.reg = reg;
    float answer = 0;
    unroll_loop<int, VEC_ELEM_NUM>(
        [&answer, &ar](int i) { answer += ar.values[i]; });

    return answer;
  }
@@ -324,10 +323,14 @@ struct FP32Vec8 : public Vec<FP32Vec8> {
    AliasReg ar;
    ar.reg = reg;
    float32x2_t er_vec0 = {static_cast<float32_t>(erf(ar.values[0])),
                           static_cast<float32_t>(erf(ar.values[1]))};
    float32x2_t er_vec1 = {static_cast<float32_t>(erf(ar.values[2])),
                           static_cast<float32_t>(erf(ar.values[3]))};
    float32x2_t er_vec2 = {static_cast<float32_t>(erf(ar.values[4])),
                           static_cast<float32_t>(erf(ar.values[5]))};
    float32x2_t er_vec3 = {static_cast<float32_t>(erf(ar.values[6])),
                           static_cast<float32_t>(erf(ar.values[7]))};

    float32x4_t result0 = vcombine_f32(er_vec0, er_vec1);
    float32x4_t result1 = vcombine_f32(er_vec2, er_vec3);
@@ -337,25 +340,29 @@ struct FP32Vec8 : public Vec<FP32Vec8> {
    result.val[1] = result1;

    return FP32Vec8(result);
  }

  FP32Vec8 operator*(const FP32Vec8& b) const {
    return FP32Vec8(float32x4x2_t({vmulq_f32(reg.val[0], b.reg.val[0]),
                                   vmulq_f32(reg.val[1], b.reg.val[1])}));
  }

  FP32Vec8 operator+(const FP32Vec8& b) const {
    return FP32Vec8(float32x4x2_t({vaddq_f32(reg.val[0], b.reg.val[0]),
                                   vaddq_f32(reg.val[1], b.reg.val[1])}));
  }

  FP32Vec8 operator-(const FP32Vec8& b) const {
    return FP32Vec8(float32x4x2_t({vsubq_f32(reg.val[0], b.reg.val[0]),
                                   vsubq_f32(reg.val[1], b.reg.val[1])}));
  }

  FP32Vec8 operator/(const FP32Vec8& b) const {
    return FP32Vec8(float32x4x2_t({vdivq_f32(reg.val[0], b.reg.val[0]),
                                   vdivq_f32(reg.val[1], b.reg.val[1])}));
  }

  void save(float* ptr) const {
    vst1q_f32(ptr, reg.val[0]);
    vst1q_f32(ptr + 4, reg.val[1]);
  }
@@ -370,103 +377,100 @@ struct FP32Vec16 : public Vec<FP32Vec16> {
  float32x4x4_t reg;

  explicit FP32Vec16(float v)
      : reg({vmovq_n_f32(v), vmovq_n_f32(v), vmovq_n_f32(v), vmovq_n_f32(v)}) {}

  explicit FP32Vec16()
      : reg({vmovq_n_f32(0.0), vmovq_n_f32(0.0), vmovq_n_f32(0.0),
             vmovq_n_f32(0.0)}) {}

  explicit FP32Vec16(const float* ptr)
      : reg({vld1q_f32(ptr), vld1q_f32(ptr + 4), vld1q_f32(ptr + 8),
             vld1q_f32(ptr + 12)}) {}

  explicit FP32Vec16(float32x4x4_t data) : reg(data) {}

  explicit FP32Vec16(const FP32Vec8& data) {
    reg.val[0] = data.reg.val[0];
    reg.val[1] = data.reg.val[1];
    reg.val[2] = data.reg.val[0];
    reg.val[3] = data.reg.val[1];
  }

  explicit FP32Vec16(const FP32Vec16& data) : reg(data.reg) {}

  explicit FP32Vec16(const FP16Vec8& v) : FP32Vec16(FP32Vec8(v.reg)) {}

#ifdef ARM_BF16_SUPPORT
  explicit FP32Vec16(bfloat16x8x2_t v)
      : reg({vcvtq_low_f32_bf16(v.val[0]), vcvtq_high_f32_bf16(v.val[0]),
             vcvtq_low_f32_bf16(v.val[1]), vcvtq_high_f32_bf16(v.val[1])}) {};
#endif

  explicit FP32Vec16(const FP32Vec4& data) {
    reg.val[0] = data.reg;
    reg.val[1] = data.reg;
    reg.val[2] = data.reg;
    reg.val[3] = data.reg;
  };

#ifdef ARM_BF16_SUPPORT
  explicit FP32Vec16(const BF16Vec16& v)
      : reg({vcvtq_low_f32_bf16(v.reg.val[0]),
             vcvtq_high_f32_bf16(v.reg.val[0]),
             vcvtq_low_f32_bf16(v.reg.val[1]),
             vcvtq_high_f32_bf16(v.reg.val[1])}) {};

  explicit FP32Vec16(const BF16Vec8& v) : FP32Vec16(FP32Vec8(v)) {};
#endif

  explicit FP32Vec16(const FP16Vec16& v) {
    reg.val[0] = vcvt_f32_f16(vget_low_f16(v.reg.val[0]));
    reg.val[1] = vcvt_f32_f16(vget_high_f16(v.reg.val[0]));
    reg.val[2] = vcvt_f32_f16(vget_low_f16(v.reg.val[1]));
    reg.val[3] = vcvt_f32_f16(vget_high_f16(v.reg.val[1]));
  };

  FP32Vec16 operator+(const FP32Vec16& b) const {
    return FP32Vec16(float32x4x4_t({vaddq_f32(reg.val[0], b.reg.val[0]),
                                    vaddq_f32(reg.val[1], b.reg.val[1]),
                                    vaddq_f32(reg.val[2], b.reg.val[2]),
                                    vaddq_f32(reg.val[3], b.reg.val[3])}));
  };

  FP32Vec16 operator*(const FP32Vec16& b) const {
    return FP32Vec16(float32x4x4_t({vmulq_f32(reg.val[0], b.reg.val[0]),
                                    vmulq_f32(reg.val[1], b.reg.val[1]),
                                    vmulq_f32(reg.val[2], b.reg.val[2]),
                                    vmulq_f32(reg.val[3], b.reg.val[3])}));
  };

  FP32Vec16 operator-(const FP32Vec16& b) const {
    return FP32Vec16(float32x4x4_t({vsubq_f32(reg.val[0], b.reg.val[0]),
                                    vsubq_f32(reg.val[1], b.reg.val[1]),
                                    vsubq_f32(reg.val[2], b.reg.val[2]),
                                    vsubq_f32(reg.val[3], b.reg.val[3])}));
  };

  FP32Vec16 operator/(const FP32Vec16& b) const {
    return FP32Vec16(float32x4x4_t({vdivq_f32(reg.val[0], b.reg.val[0]),
                                    vdivq_f32(reg.val[1], b.reg.val[1]),
                                    vdivq_f32(reg.val[2], b.reg.val[2]),
                                    vdivq_f32(reg.val[3], b.reg.val[3])}));
  };

  float reduce_sum() const {
    AliasReg ar;
    ar.reg = reg;
    float answer = 0;
    unroll_loop<int, VEC_ELEM_NUM>(
        [&answer, &ar](int i) { answer += ar.values[i]; });

    return answer;
  };

  template <int group_size>
  float reduce_sub_sum(int idx) {
    static_assert(VEC_ELEM_NUM % group_size == 0);

    AliasReg ar;
@@ -479,7 +483,7 @@ struct FP32Vec16 : public Vec<FP32Vec16> {
    return answer;
  };

  void save(float* ptr) const {
    vst1q_f32(ptr, reg.val[0]);
    vst1q_f32(ptr + 4, reg.val[1]);
    vst1q_f32(ptr + 8, reg.val[2]);
@@ -487,43 +491,59 @@ struct FP32Vec16 : public Vec<FP32Vec16> {
  };
};

template <typename T>
struct VecType {
  using vec_type = void;
};

template <typename T>
using vec_t = typename VecType<T>::vec_type;

template <>
struct VecType<float> {
  using vec_type = FP32Vec8;
};

template <>
struct VecType<c10::Half> {
  using vec_type = FP16Vec8;
};

#ifdef ARM_BF16_SUPPORT
template <>
struct VecType<c10::BFloat16> {
  using vec_type = BF16Vec8;
};
#endif

template <typename T>
void storeFP32(float v, T* ptr) {
  *ptr = v;
}

template <>
inline void storeFP32<c10::Half>(float v, c10::Half* ptr) {
  *reinterpret_cast<__fp16*>(ptr) = v;
}

inline FP16Vec16::FP16Vec16(const FP32Vec16& v) {
  float16x4_t low_0 = vcvt_f16_f32(v.reg.val[0]);
  float16x4_t high_0 = vcvt_f16_f32(v.reg.val[1]);
  float16x4_t low_1 = vcvt_f16_f32(v.reg.val[2]);
  float16x4_t high_1 = vcvt_f16_f32(v.reg.val[3]);

  reg.val[0] = vcombine_f16(low_0, high_0);
  reg.val[1] = vcombine_f16(low_1, high_1);
};

inline FP16Vec8 ::FP16Vec8(const FP32Vec8& v) {
  float16x4_t lower_half = vcvt_f16_f32(v.reg.val[0]);
  float16x4_t upper_half = vcvt_f16_f32(v.reg.val[1]);

  reg = vcombine_f16(lower_half, upper_half);
};

inline void fma(FP32Vec16& acc, FP32Vec16& a, FP32Vec16& b) {
  acc.reg.val[0] = vfmaq_f32(acc.reg.val[0], a.reg.val[0], b.reg.val[0]);
  acc.reg.val[1] = vfmaq_f32(acc.reg.val[1], a.reg.val[1], b.reg.val[1]);
  acc.reg.val[2] = vfmaq_f32(acc.reg.val[2], a.reg.val[2], b.reg.val[2]);
@@ -531,8 +551,7 @@ inline void fma(FP32Vec16 &acc, FP32Vec16 &a, FP32Vec16 &b) {
};

#ifdef ARM_BF16_SUPPORT
inline void fma(FP32Vec16& acc, BF16Vec32& a, BF16Vec32& b) {
  float32x4_t a0_low = vcvt_f32_bf16(vget_low_bf16(a.reg.val[0]));
  float32x4_t a0_high = vcvt_f32_bf16(vget_high_bf16(a.reg.val[0]));
  float32x4_t a1_low = vcvt_f32_bf16(vget_low_bf16(a.reg.val[1]));
@@ -551,22 +570,22 @@
#endif

#ifdef ARM_BF16_SUPPORT
inline BF16Vec8::BF16Vec8(const FP32Vec8& v)
    : reg(vcvtq_high_bf16_f32(vcvtq_low_bf16_f32(v.reg.val[0]), v.reg.val[1])) {
};

inline BF16Vec16::BF16Vec16(const FP32Vec16& v)
    : reg({vcvtq_high_bf16_f32(vcvtq_low_bf16_f32(v.reg.val[0]), v.reg.val[1]),
           vcvtq_high_bf16_f32(vcvtq_low_bf16_f32(v.reg.val[2]),
                               v.reg.val[3])}) {};
#endif

inline void prefetch(const void* addr) { __builtin_prefetch(addr, 0, 1); };

#ifdef ARM_BF16_SUPPORT
template <>
inline void storeFP32<c10::BFloat16>(float v, c10::BFloat16* ptr) {
  *reinterpret_cast<__bf16*>(ptr) = vcvth_bf16_f32(v);
};
#endif
};  // namespace vec_op

cpu_types_vsx.hpp
@@ -9,38 +9,40 @@
namespace vec_op {

// FIXME: FP16 is not fully supported in Torch-CPU
#define VLLM_DISPATCH_CASE_FLOATING_TYPES(...)         \
  AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \
  AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__)

#define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \
  AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__))

#ifndef CPU_OP_GUARD
#define CPU_KERNEL_GUARD_IN(NAME)
#define CPU_KERNEL_GUARD_OUT(NAME)
#else
#define CPU_KERNEL_GUARD_IN(NAME) \
  std::cout << #NAME << " invoked." << std::endl;
#define CPU_KERNEL_GUARD_OUT(NAME) \
  std::cout << #NAME << " exit." << std::endl;
#endif

#define FORCE_INLINE __attribute__((always_inline)) inline

namespace {
template <typename T, T... indexes, typename F>
constexpr void unroll_loop_item(std::integer_sequence<T, indexes...>, F&& f) {
  (f(std::integral_constant<T, indexes>{}), ...);
}
};  // namespace

template <typename T, T count, typename F,
          typename = std::enable_if_t<std::is_invocable_v<F, T>>>
constexpr void unroll_loop(F&& f) {
  unroll_loop_item(std::make_integer_sequence<T, count>{}, std::forward<F>(f));
}

template <typename T>
struct Vec {
  constexpr static int get_elem_num() { return T::VEC_ELEM_NUM; }
};
@@ -68,12 +70,14 @@ struct BF16Vec8 : public Vec<BF16Vec8> {
  __vector signed short reg;

  explicit BF16Vec8(const void* ptr)
      : reg((__vector signed short)vec_xl(0, (__vector signed short*)ptr)) {}

  explicit BF16Vec8(const FP32Vec8&);

  void save(void* ptr) const {
    *reinterpret_cast<__vector signed short*>(ptr) = reg;
  }
};

struct BF16Vec16 : public Vec<BF16Vec16> {
@@ -81,18 +85,18 @@ struct BF16Vec16 : public Vec<BF16Vec16> {
  ss16x8x2_t reg;

  explicit BF16Vec16(const void* ptr) {
    // Load 256 bits in two parts
    reg.val[0] = (__vector signed short)vec_xl(0, (signed short*)ptr);
    reg.val[1] = (__vector signed short)vec_xl(16, (signed short*)ptr);
  }

  explicit BF16Vec16(const FP32Vec16&);

  void save(void* ptr) const {
    // Save 256 bits in two parts
    vec_xst(reg.val[0], 0, (signed short*)ptr);
    vec_xst(reg.val[1], 16, (signed short*)ptr);
  }
};
@@ -102,19 +106,15 @@ struct BF16Vec32 : public Vec<BF16Vec32> {
  constexpr static int VEC_ELEM_NUM = 32;

  ss16x8x4_t reg;

  explicit BF16Vec32(const void* ptr)
      : reg(*reinterpret_cast<const ss16x8x4_t*>(ptr)) {}

  explicit BF16Vec32(ss16x8x4_t data) : reg(data) {}

  explicit BF16Vec32(const BF16Vec8& vec8_data)
      : reg({vec8_data.reg, vec8_data.reg, vec8_data.reg, vec8_data.reg}) {}

  void save(void* ptr) const { *reinterpret_cast<ss16x8x4_t*>(ptr) = reg; }
};

struct FP32Vec4 : public Vec<FP32Vec4> {
@@ -130,11 +130,11 @@ struct FP32Vec4 : public Vec<FP32Vec4> {
  explicit FP32Vec4() : reg(vec_splats(0.0f)) {}

  explicit FP32Vec4(const float* ptr) : reg(vec_xl(0, ptr)) {}

  explicit FP32Vec4(__vector float data) : reg(data) {}

  explicit FP32Vec4(const FP32Vec4& data) : reg(data.reg) {}
};

struct FP32Vec8 : public Vec<FP32Vec8> {
@@ -156,19 +156,19 @@ struct FP32Vec8 : public Vec<FP32Vec8> {
    reg.val[1] = vec_splats(0.0f);
  }

  explicit FP32Vec8(const float* ptr) {
    reg.val[0] = vec_xl(0, ptr);
    reg.val[1] = vec_xl(16, ptr);
  }

  explicit FP32Vec8(f32x4x2_t data) : reg(data) {}

  explicit FP32Vec8(const FP32Vec8& data) {
    reg.val[0] = data.reg.val[0];
    reg.val[1] = data.reg.val[1];
  }

  explicit FP32Vec8(const BF16Vec8& v) {
    reg.val[0] = (__vector float)vec_mergeh(zero, v.reg);
    reg.val[1] = (__vector float)vec_mergel(zero, v.reg);
  }
@@ -177,7 +177,8 @@ struct FP32Vec8 : public Vec<FP32Vec8> {
    AliasReg ar;
    ar.reg = reg;
    float result = 0;
    unroll_loop<int, VEC_ELEM_NUM>(
        [&result, &ar](int i) { result += ar.values[i]; });

    return result;
  }
@@ -230,23 +231,27 @@ struct FP32Vec8 : public Vec<FP32Vec8> {
    return FP32Vec8(f32x4x2_t({ret.val[0], ret.val[1]}));
  }

  FP32Vec8 operator*(const FP32Vec8& b) const {
    return FP32Vec8(
        {vec_mul(reg.val[0], b.reg.val[0]), vec_mul(reg.val[1], b.reg.val[1])});
  }

  FP32Vec8 operator+(const FP32Vec8& b) const {
    return FP32Vec8(
        {vec_add(reg.val[0], b.reg.val[0]), vec_add(reg.val[1], b.reg.val[1])});
  }

  FP32Vec8 operator-(const FP32Vec8& b) const {
    return FP32Vec8(
        {vec_sub(reg.val[0], b.reg.val[0]), vec_sub(reg.val[1], b.reg.val[1])});
  }

  FP32Vec8 operator/(const FP32Vec8& b) const {
    return FP32Vec8(
        {vec_div(reg.val[0], b.reg.val[0]), vec_div(reg.val[1], b.reg.val[1])});
  }

  void save(float* ptr) const {
    vec_xst(reg.val[0], 0, ptr);
    vec_xst(reg.val[1], 16, ptr);
  }
@@ -275,7 +280,7 @@ struct FP32Vec16 : public Vec<FP32Vec16> {
    reg.val[3] = vec_splats(0.0f);
  }

  explicit FP32Vec16(const float* ptr) {
    reg.val[0] = vec_xl(0, ptr);
    reg.val[1] = vec_xl(16, ptr);
    reg.val[2] = vec_xl(32, ptr);
@@ -284,78 +289,76 @@ struct FP32Vec16 : public Vec<FP32Vec16> {
  explicit FP32Vec16(f32x4x4_t data) : reg(data) {}

  explicit FP32Vec16(const FP32Vec16& data) {
    reg.val[0] = data.reg.val[0];
    reg.val[1] = data.reg.val[1];
    reg.val[2] = data.reg.val[2];
    reg.val[3] = data.reg.val[3];
  }

  explicit FP32Vec16(const FP32Vec4& data) {
    reg.val[0] = data.reg;
    reg.val[1] = data.reg;
    reg.val[2] = data.reg;
    reg.val[3] = data.reg;
  }

  explicit FP32Vec16(const FP32Vec8& data) {
    reg.val[0] = data.reg.val[0];
    reg.val[1] = data.reg.val[1];
    reg.val[2] = data.reg.val[0];
    reg.val[3] = data.reg.val[1];
  }

  explicit FP32Vec16(const BF16Vec16& v) {
    reg.val[0] = (__vector float)vec_mergeh(zero, v.reg.val[0]);
    reg.val[1] = (__vector float)vec_mergel(zero, v.reg.val[0]);
    reg.val[2] = (__vector float)vec_mergeh(zero, v.reg.val[1]);
    reg.val[3] = (__vector float)vec_mergel(zero, v.reg.val[1]);
  }

  explicit FP32Vec16(const BF16Vec8& v) : FP32Vec16(FP32Vec8(v)) {}

  FP32Vec16 operator*(const FP32Vec16& b) const {
    return FP32Vec16(f32x4x4_t({vec_mul(reg.val[0], b.reg.val[0]),
                                vec_mul(reg.val[1], b.reg.val[1]),
                                vec_mul(reg.val[2], b.reg.val[2]),
                                vec_mul(reg.val[3], b.reg.val[3])}));
  }

  FP32Vec16 operator+(const FP32Vec16& b) const {
    return FP32Vec16(f32x4x4_t({vec_add(reg.val[0], b.reg.val[0]),
                                vec_add(reg.val[1], b.reg.val[1]),
                                vec_add(reg.val[2], b.reg.val[2]),
                                vec_add(reg.val[3], b.reg.val[3])}));
  }

  FP32Vec16 operator-(const FP32Vec16& b) const {
    return FP32Vec16(f32x4x4_t({vec_sub(reg.val[0], b.reg.val[0]),
                                vec_sub(reg.val[1], b.reg.val[1]),
                                vec_sub(reg.val[2], b.reg.val[2]),
                                vec_sub(reg.val[3], b.reg.val[3])}));
  }

  FP32Vec16 operator/(const FP32Vec16& b) const {
    return FP32Vec16(f32x4x4_t({vec_div(reg.val[0], b.reg.val[0]),
                                vec_div(reg.val[1], b.reg.val[1]),
                                vec_div(reg.val[2], b.reg.val[2]),
                                vec_div(reg.val[3], b.reg.val[3])}));
  }

  float reduce_sum() const {
    AliasReg ar;
    ar.reg = reg;
    float result = 0;
    unroll_loop<int, VEC_ELEM_NUM>(
        [&result, &ar](int i) { result += ar.values[i]; });

    return result;
  }

  template <int group_size>
  float reduce_sub_sum(int idx) {
    static_assert(VEC_ELEM_NUM % group_size == 0);

    AliasReg ar;
@@ -368,7 +371,7 @@ struct FP32Vec16 : public Vec<FP32Vec16> {
    return result;
  }

  void save(float* ptr) const {
    vec_xst(reg.val[0], 0, ptr);
    vec_xst(reg.val[1], 16, ptr);
    vec_xst(reg.val[2], 32, ptr);
@@ -376,43 +379,62 @@
  }
};

template <typename T>
struct VecType {
  using vec_type = void;
};

template <typename T>
using vec_t = typename VecType<T>::vec_type;

template <>
struct VecType<float> {
  using vec_type = FP32Vec8;
};

template <>
struct VecType<c10::BFloat16> {
  using vec_type = BF16Vec8;
};

template <typename T>
void storeFP32(float v, T* ptr) {
  *ptr = v;
}

inline void fma(FP32Vec16& acc, FP32Vec16& a, FP32Vec16& b) {
  acc = acc + a * b;
}

template <>
inline void storeFP32<c10::BFloat16>(float v, c10::BFloat16* ptr) {
  c10::BFloat16 __attribute__((__may_alias__))* v_ptr =
      reinterpret_cast<c10::BFloat16*>(&v);
  *ptr = *(v_ptr + 1);
}

#ifndef __VEC_CLASS_FP_NAN
#define __VEC_CLASS_FP_NAN (1 << 6)
#endif

const static __vector unsigned char omask = {0,  1,  4,  5,  8,  9,  12, 13,
                                             16, 17, 20, 21, 24, 25, 28, 29};
#ifndef _ARCH_PWR10
const static __vector unsigned int bias = {0x00007fff, 0x00007fff, 0x00007fff,
                                           0x00007fff};
const static __vector unsigned int nan = {0x7fc00000, 0x7fc00000, 0x7fc00000,
                                          0x7fc00000};
const static __vector unsigned int sh16 = {16, 16, 16, 16};
const static __vector unsigned int one = {1, 1, 1, 1};
#endif

inline BF16Vec8::BF16Vec8(const FP32Vec8& v) {
#ifdef _ARCH_PWR10
  __vector signed short ret[2];
  ret[0] = (__vector signed short)__builtin_vsx_xvcvspbf16(
      (__vector unsigned char)v.reg.val[0]);
  ret[1] = (__vector signed short)__builtin_vsx_xvcvspbf16(
      (__vector unsigned char)v.reg.val[1]);
  reg = vec_perm(ret[0], ret[1], omask);
#elif defined(_ARCH_PWR9)
  __vector unsigned int inp0 = (__vector unsigned int)(v.reg.val[0]);
@ -425,8 +447,10 @@ inline BF16Vec8::BF16Vec8(const FP32Vec8 &v) {
__vector unsigned int rnd1 = vec_add(lsb1, bias); __vector unsigned int rnd1 = vec_add(lsb1, bias);
inp0 = vec_add(inp0, rnd0); inp0 = vec_add(inp0, rnd0);
inp1 = vec_add(inp1, rnd1); inp1 = vec_add(inp1, rnd1);
__vector __bool int sel0 = vec_test_data_class(v.reg.val[0], __VEC_CLASS_FP_NAN); __vector __bool int sel0 =
__vector __bool int sel1 = vec_test_data_class(v.reg.val[1], __VEC_CLASS_FP_NAN); vec_test_data_class(v.reg.val[0], __VEC_CLASS_FP_NAN);
__vector __bool int sel1 =
vec_test_data_class(v.reg.val[1], __VEC_CLASS_FP_NAN);
inp0 = vec_sel(inp0, nan, sel0); inp0 = vec_sel(inp0, nan, sel0);
inp1 = vec_sel(inp1, nan, sel1); inp1 = vec_sel(inp1, nan, sel1);
inp0 = vec_sr(inp0, sh16); inp0 = vec_sr(inp0, sh16);
@ -435,13 +459,17 @@ inline BF16Vec8::BF16Vec8(const FP32Vec8 &v) {
#endif #endif
} }
inline BF16Vec16::BF16Vec16(const FP32Vec16& v) {
#ifdef _ARCH_PWR10
  __vector signed short ret[4];
  ret[0] = (__vector signed short)__builtin_vsx_xvcvspbf16(
      (__vector unsigned char)v.reg.val[0]);
  ret[1] = (__vector signed short)__builtin_vsx_xvcvspbf16(
      (__vector unsigned char)v.reg.val[1]);
  ret[2] = (__vector signed short)__builtin_vsx_xvcvspbf16(
      (__vector unsigned char)v.reg.val[2]);
  ret[3] = (__vector signed short)__builtin_vsx_xvcvspbf16(
      (__vector unsigned char)v.reg.val[3]);
  reg.val[0] = vec_perm(ret[0], ret[1], omask);
  reg.val[1] = vec_perm(ret[2], ret[3], omask);
#elif defined(_ARCH_PWR9)

@ -465,10 +493,14 @@ inline BF16Vec16::BF16Vec16(const FP32Vec16 &v) {
  inp1 = vec_add(inp1, rnd1);
  inp2 = vec_add(inp2, rnd2);
  inp3 = vec_add(inp3, rnd3);
  __vector __bool int sel0 =
      vec_test_data_class(v.reg.val[0], __VEC_CLASS_FP_NAN);
  __vector __bool int sel1 =
      vec_test_data_class(v.reg.val[1], __VEC_CLASS_FP_NAN);
  __vector __bool int sel2 =
      vec_test_data_class(v.reg.val[2], __VEC_CLASS_FP_NAN);
  __vector __bool int sel3 =
      vec_test_data_class(v.reg.val[3], __VEC_CLASS_FP_NAN);
  inp0 = vec_sel(inp0, nan, sel0);
  inp1 = vec_sel(inp1, nan, sel1);
  inp2 = vec_sel(inp2, nan, sel2);

@ -482,10 +514,10 @@ inline BF16Vec16::BF16Vec16(const FP32Vec16 &v) {
#endif
}

inline void prefetch(const void* addr) {
  __asm__ __volatile__("dcbt 0, %0" : : "r"(addr) : "memory");
}

};  // namespace vec_op

#endif
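Editor's note: the `_ARCH_PWR9` branches above perform float→bfloat16 narrowing with round-to-nearest-even plus a quiet-NaN fixup, entirely in vector registers. As a reading aid only (not part of the patch; the function name is hypothetical), here is a scalar C++ sketch of the same scheme. The constants mirror the `bias` (0x7fff) and quiet-NaN (0x7fc00000) vectors defined above.

```cpp
#include <cstdint>
#include <cstring>

// Scalar sketch of the round-to-nearest-even float -> bfloat16 narrowing
// that the _ARCH_PWR9 vector path implements (illustrative, not from vLLM).
uint16_t fp32_to_bf16_rne(float f) {
  uint32_t bits;
  std::memcpy(&bits, &f, sizeof(bits));
  if ((bits & 0x7fffffffu) > 0x7f800000u) {  // NaN: emit a quiet-NaN pattern
    return uint16_t(0x7fc00000u >> 16);
  }
  uint32_t lsb = (bits >> 16) & 1u;  // parity of the mantissa bits we keep
  bits += 0x7fffu + lsb;             // bias + tie-to-even correction
  return uint16_t(bits >> 16);       // keep the high 16 bits
}
```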

View File

@ -11,39 +11,40 @@ static_assert(false, "AVX2 must be supported for the current implementation.");
namespace vec_op {

#define VLLM_DISPATCH_CASE_FLOATING_TYPES(...)            \
  AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__)    \
  AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) \
  AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__)

#define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \
  AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__))

#ifndef CPU_OP_GUARD
#define CPU_KERNEL_GUARD_IN(NAME)
#define CPU_KERNEL_GUARD_OUT(NAME)
#else
#define CPU_KERNEL_GUARD_IN(NAME) \
  RECORD_FUNCTION(#NAME, c10::ArrayRef<c10::IValue>({}));
#define CPU_KERNEL_GUARD_OUT(NAME)
#endif

#define FORCE_INLINE __attribute__((always_inline)) inline

namespace {
template <typename T, T... indexes, typename F>
constexpr void unroll_loop_item(std::integer_sequence<T, indexes...>, F&& f) {
  (f(std::integral_constant<T, indexes>{}), ...);
}
};  // namespace

template <typename T, T count, typename F,
          typename = std::enable_if_t<std::is_invocable_v<F, T>>>
constexpr void unroll_loop(F&& f) {
  unroll_loop_item(std::make_integer_sequence<T, count>{}, std::forward<F>(f));
}
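Editor's note: `unroll_loop` expands its lambda once per index at compile time via a fold expression over `std::integer_sequence`. A minimal self-contained usage sketch follows (the values and the `sum8` name are hypothetical, not from the patch):

```cpp
#include <type_traits>
#include <utility>

// Trimmed copies of the helpers above, kept here so the sketch compiles alone.
template <typename T, T... indexes, typename F>
constexpr void unroll_loop_item(std::integer_sequence<T, indexes...>, F&& f) {
  (f(std::integral_constant<T, indexes>{}), ...);  // one call per index
}

template <typename T, T count, typename F,
          typename = std::enable_if_t<std::is_invocable_v<F, T>>>
constexpr void unroll_loop(F&& f) {
  unroll_loop_item(std::make_integer_sequence<T, count>{}, std::forward<F>(f));
}

// Sums eight floats with the loop fully unrolled at compile time.
float sum8(const float (&vals)[8]) {
  float total = 0.0f;
  unroll_loop<int, 8>([&](int i) { total += vals[i]; });
  return total;
}
```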
template <typename T>
struct Vec {
  constexpr static int get_elem_num() { return T::VEC_ELEM_NUM; }
};
@ -55,12 +56,12 @@ struct FP16Vec8 : public Vec<FP16Vec8> {
  __m128i reg;

  explicit FP16Vec8(const void* ptr)
      : reg((__m128i)_mm_loadu_si128((__m128i*)ptr)) {}

  explicit FP16Vec8(const FP32Vec8&);

  void save(void* ptr) const { *reinterpret_cast<__m128i*>(ptr) = reg; }
};

struct FP16Vec16 : public Vec<FP16Vec16> {
@ -68,12 +69,12 @@ struct FP16Vec16 : public Vec<FP16Vec16> {
  __m256i reg;

  explicit FP16Vec16(const void* ptr)
      : reg((__m256i)_mm256_loadu_si256((__m256i*)ptr)) {}

  explicit FP16Vec16(const FP32Vec16&);

  void save(void* ptr) const { *reinterpret_cast<__m256i*>(ptr) = reg; }

  void save(void* ptr, const int elem_num) const {
    constexpr uint32_t M = 0xFFFFFFFF;
@ -87,12 +88,12 @@ struct BF16Vec8 : public Vec<BF16Vec8> {
  __m128i reg;

  explicit BF16Vec8(const void* ptr)
      : reg((__m128i)_mm_loadu_si128((__m128i*)ptr)) {}

  explicit BF16Vec8(const FP32Vec8&);

  void save(void* ptr) const { *reinterpret_cast<__m128i*>(ptr) = reg; }
};

struct BF16Vec16 : public Vec<BF16Vec16> {
@ -100,12 +101,12 @@ struct BF16Vec16 : public Vec<BF16Vec16> {
  __m256i reg;

  explicit BF16Vec16(const void* ptr)
      : reg((__m256i)_mm256_loadu_si256((__m256i*)ptr)) {}

  explicit BF16Vec16(const FP32Vec16&);

  void save(void* ptr) const { *reinterpret_cast<__m256i*>(ptr) = reg; }

  void save(void* ptr, const int elem_num) const {
    constexpr uint32_t M = 0xFFFFFFFF;
@ -120,11 +121,11 @@ struct BF16Vec32 : public Vec<BF16Vec32> {
  __m512i reg;

  explicit BF16Vec32(const void* ptr) : reg((__m512i)_mm512_loadu_si512(ptr)) {}

  explicit BF16Vec32(__m512i data) : reg(data) {}

  explicit BF16Vec32(BF16Vec8& vec8_data)
      : reg((__m512i)_mm512_inserti32x4(
            _mm512_inserti32x4(_mm512_inserti32x4(_mm512_castsi128_si512(
                                                      (__m128i)vec8_data.reg),
@ -132,7 +133,7 @@ struct BF16Vec32 : public Vec<BF16Vec32> {
                               (__m128i)vec8_data.reg, 2),
            (__m128i)vec8_data.reg, 3)) {}

  void save(void* ptr) const { *reinterpret_cast<__m512i*>(ptr) = reg; }
};
#else
struct BF16Vec32 : public Vec<BF16Vec32> {
@ -141,24 +142,24 @@ struct BF16Vec32 : public Vec<BF16Vec32> {
  __m256i reg_low;
  __m256i reg_high;

  explicit BF16Vec32(const void* ptr)
      : reg_low(_mm256_loadu_si256((__m256i const*)ptr)),
        reg_high(_mm256_loadu_si256((__m256i const*)ptr + 1)) {}

  explicit BF16Vec32(__m256i low, __m256i high)
      : reg_low(low), reg_high(high) {}

  explicit BF16Vec32(BF16Vec8& vec8_data)
      : reg_low((__m256i)_mm256_inserti32x4(
            _mm256_castsi128_si256((__m128i)vec8_data.reg),
            (__m128i)vec8_data.reg, 1)),
        reg_high((__m256i)_mm256_inserti32x4(
            _mm256_castsi128_si256((__m128i)vec8_data.reg),
            (__m128i)vec8_data.reg, 1)) {}

  void save(void* ptr) const {
    *reinterpret_cast<__m256i*>(ptr) = reg_low;
    *reinterpret_cast<__m256i*>((__m256i*)ptr + 1) = reg_high;
  }
};
#endif
@ -176,11 +177,11 @@ struct FP32Vec4 : public Vec<FP32Vec4> {
  explicit FP32Vec4() : reg(_mm_set1_ps(0.0)) {}

  explicit FP32Vec4(const float* ptr) : reg(_mm_loadu_ps(ptr)) {}

  explicit FP32Vec4(__m128 data) : reg(data) {}

  explicit FP32Vec4(const FP32Vec4& data) : reg(data.reg) {}
};

struct FP32Vec8 : public Vec<FP32Vec8> {
@ -196,15 +197,15 @@ struct FP32Vec8 : public Vec<FP32Vec8> {
  explicit FP32Vec8() : reg(_mm256_set1_ps(0.0)) {}

  explicit FP32Vec8(const float* ptr) : reg(_mm256_loadu_ps(ptr)) {}

  explicit FP32Vec8(__m256 data) : reg(data) {}

  explicit FP32Vec8(const FP32Vec8& data) : reg(data.reg) {}

  explicit FP32Vec8(const FP16Vec8& v) : reg(_mm256_cvtph_ps(v.reg)) {}

  explicit FP32Vec8(const BF16Vec8& v)
      : reg(_mm256_castsi256_ps(
            _mm256_bslli_epi128(_mm256_cvtepu16_epi32(v.reg), 2))) {}
@ -212,7 +213,8 @@ struct FP32Vec8 : public Vec<FP32Vec8> {
    AliasReg ar;
    ar.reg = reg;
    float result = 0;
    unroll_loop<int, VEC_ELEM_NUM>(
        [&result, &ar](int i) { result += ar.values[i]; });

    return result;
  }
@ -244,27 +246,27 @@ struct FP32Vec8 : public Vec<FP32Vec8> {
                                  erf(ar.values[1]), erf(ar.values[0])));
  }

  FP32Vec8 operator*(const FP32Vec8& b) const {
    return FP32Vec8(_mm256_mul_ps(reg, b.reg));
  }

  FP32Vec8 operator+(const FP32Vec8& b) const {
    return FP32Vec8(_mm256_add_ps(reg, b.reg));
  }

  FP32Vec8 operator-(const FP32Vec8& b) const {
    return FP32Vec8(_mm256_sub_ps(reg, b.reg));
  }

  FP32Vec8 operator/(const FP32Vec8& b) const {
    return FP32Vec8(_mm256_div_ps(reg, b.reg));
  }

  void save(float* ptr) const { _mm256_storeu_ps(ptr, reg); }
};

#ifdef __AVX512F__
struct INT32Vec16 : public Vec<INT32Vec16> {
  constexpr static int VEC_ELEM_NUM = 16;
  union AliasReg {
    __m512i reg;
@ -272,12 +274,11 @@ struct INT32Vec16: public Vec<INT32Vec16> {
  };

  __m512i reg;

  explicit INT32Vec16(const void* data_ptr)
      : reg(_mm512_loadu_epi32(data_ptr)) {}

  void save(int32_t* ptr) const { _mm512_storeu_epi32(ptr, reg); }

  void save(int32_t* ptr, const int elem_num) const {
    constexpr uint32_t M = 0xFFFFFFFF;
@ -301,11 +302,11 @@ struct FP32Vec16 : public Vec<FP32Vec16> {
  explicit FP32Vec16() : reg(_mm512_set1_ps(0.0)) {}

  explicit FP32Vec16(const float* ptr) : reg(_mm512_loadu_ps(ptr)) {}

  explicit FP32Vec16(__m512 data) : reg(data) {}

  explicit FP32Vec16(const FP32Vec4& data)
      : reg((__m512)_mm512_inserti32x4(
            _mm512_inserti32x4(
                _mm512_inserti32x4(_mm512_castsi128_si512((__m128i)data.reg),
@ -313,36 +314,37 @@ struct FP32Vec16 : public Vec<FP32Vec16> {
                                   (__m128i)data.reg, 2),
            (__m128i)data.reg, 3)) {}

  explicit FP32Vec16(const FP32Vec8& data)
      : reg((__m512)_mm512_inserti32x8(
            _mm512_castsi256_si512((__m256i)data.reg), (__m256i)data.reg, 1)) {}

  explicit FP32Vec16(const BF16Vec16& v)
      : reg(_mm512_castsi512_ps(
            _mm512_bslli_epi128(_mm512_cvtepu16_epi32(v.reg), 2))) {}

  explicit FP32Vec16(const FP16Vec16& v) : reg(_mm512_cvtph_ps(v.reg)) {}

  explicit FP32Vec16(const FP16Vec8& v) : FP32Vec16(FP32Vec8(v)) {}

  explicit FP32Vec16(const BF16Vec8& v) : FP32Vec16(FP32Vec8(v)) {}

  explicit FP32Vec16(const INT32Vec16& v)
      : reg(_mm512_cvt_roundepi32_ps(
            v.reg, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)) {}

  FP32Vec16 operator*(const FP32Vec16& b) const {
    return FP32Vec16(_mm512_mul_ps(reg, b.reg));
  }

  FP32Vec16 operator+(const FP32Vec16& b) const {
    return FP32Vec16(_mm512_add_ps(reg, b.reg));
  }

  FP32Vec16 operator-(const FP32Vec16& b) const {
    return FP32Vec16(_mm512_sub_ps(reg, b.reg));
  }

  FP32Vec16 operator/(const FP32Vec16& b) const {
    return FP32Vec16(_mm512_div_ps(reg, b.reg));
  }
@ -370,9 +372,7 @@ struct FP32Vec16 : public Vec<FP32Vec16> {
    return FP32Vec16(_mm512_mask_min_ps(reg, mask, reg, b.reg));
  }

  FP32Vec16 abs() const { return FP32Vec16(_mm512_abs_ps(reg)); }

  float reduce_sum() const { return _mm512_reduce_add_ps(reg); }
@ -380,14 +380,15 @@ struct FP32Vec16 : public Vec<FP32Vec16> {
  float reduce_min() const { return _mm512_reduce_min_ps(reg); }

  template <int group_size>
  float reduce_sub_sum(int idx) {
    static_assert(VEC_ELEM_NUM % group_size == 0);
    constexpr uint32_t base_mask = (0xFFFF >> (16 - group_size));
    __mmask16 mask = _cvtu32_mask16(base_mask << (idx * group_size));
    return _mm512_mask_reduce_add_ps(mask, reg);
  }

  void save(float* ptr) const { _mm512_storeu_ps(ptr, reg); }

  void save(float* ptr, const int elem_num) const {
    constexpr uint32_t M = 0xFFFFFFFF;
@ -407,32 +408,30 @@ struct FP32Vec16 : public Vec<FP32Vec16> {
  __m256 reg_low;
  __m256 reg_high;

  explicit FP32Vec16(float v)
      : reg_low(_mm256_set1_ps(v)), reg_high(_mm256_set1_ps(v)) {}

  explicit FP32Vec16()
      : reg_low(_mm256_set1_ps(0.0)), reg_high(_mm256_set1_ps(0.0)) {}

  explicit FP32Vec16(const float* ptr)
      : reg_low(_mm256_loadu_ps(ptr)), reg_high(_mm256_loadu_ps(ptr + 8)) {}

  explicit FP32Vec16(__m256 low, __m256 high) : reg_low(low), reg_high(high) {}

  explicit FP32Vec16(const FP32Vec16& data)
      : reg_low(data.reg_low), reg_high(data.reg_high) {}

  explicit FP32Vec16(const FP32Vec4& data)
      : reg_low((__m256)_mm256_inserti128_si256(
            _mm256_castsi128_si256((__m128i)data.reg), (__m128i)data.reg, 1)),
        reg_high((__m256)_mm256_inserti128_si256(
            _mm256_castsi128_si256((__m128i)data.reg), (__m128i)data.reg, 1)) {}

  explicit FP32Vec16(const FP32Vec8& data)
      : reg_low(data.reg), reg_high(data.reg) {}

  explicit FP32Vec16(const FP16Vec16& v) {
    __m128i low = _mm256_extractf128_si256(v.reg, 0);
    __m128i high = _mm256_extractf128_si256(v.reg, 1);
@ -440,9 +439,9 @@ struct FP32Vec16 : public Vec<FP32Vec16> {
    reg_high = _mm256_cvtph_ps(high);
  }

  explicit FP32Vec16(const FP16Vec8& v) : FP32Vec16(FP32Vec8(v)) {}

  explicit FP32Vec16(const BF16Vec16& v) {
    __m128i low = _mm256_extractf128_si256(v.reg, 0);
    __m128i high = _mm256_extractf128_si256(v.reg, 1);
@ -456,24 +455,24 @@ struct FP32Vec16 : public Vec<FP32Vec16> {
    reg_high = _mm256_castsi256_ps(v_high_shifted);
  }

  explicit FP32Vec16(const BF16Vec8& v) : FP32Vec16(FP32Vec8(v)) {}

  FP32Vec16 operator*(const FP32Vec16& b) const {
    return FP32Vec16(_mm256_mul_ps(reg_low, b.reg_low),
                     _mm256_mul_ps(reg_high, b.reg_high));
  }

  FP32Vec16 operator+(const FP32Vec16& b) const {
    return FP32Vec16(_mm256_add_ps(reg_low, b.reg_low),
                     _mm256_add_ps(reg_high, b.reg_high));
  }

  FP32Vec16 operator-(const FP32Vec16& b) const {
    return FP32Vec16(_mm256_sub_ps(reg_low, b.reg_low),
                     _mm256_sub_ps(reg_high, b.reg_high));
  }

  FP32Vec16 operator/(const FP32Vec16& b) const {
    return FP32Vec16(_mm256_div_ps(reg_low, b.reg_low),
                     _mm256_div_ps(reg_high, b.reg_high));
  }
@ -484,7 +483,8 @@ struct FP32Vec16 : public Vec<FP32Vec16> {
    return low.reduce_sum() + high.reduce_sum();
  }

  template <int group_size>
  float reduce_sub_sum(int idx) {
    float sum = 0.0;
    static_assert(VEC_ELEM_NUM % group_size == 0);
    constexpr uint32_t base_mask = (0xFFFF >> (16 - group_size));
@ -507,7 +507,7 @@ struct FP32Vec16 : public Vec<FP32Vec16> {
    return sum;
  }

  void save(float* ptr) const {
    _mm256_storeu_ps(ptr, reg_low);
    _mm256_storeu_ps(ptr + 8, reg_high);
  }
@ -515,7 +515,7 @@ struct FP32Vec16 : public Vec<FP32Vec16> {
#endif

#ifdef __AVX512F__
struct INT8Vec16 : public Vec<INT8Vec16> {
  constexpr static int VEC_ELEM_NUM = 16;
  union AliasReg {
    __m128i reg;
@ -523,14 +523,12 @@ struct INT8Vec16: public Vec<INT8Vec16> {
  };

  __m128i reg;

  explicit INT8Vec16(const FP32Vec16& vec)
      : reg(_mm512_cvtepi32_epi8(_mm512_cvt_roundps_epi32(
            vec.reg, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC))) {}

  void save(int8_t* ptr) const { _mm_storeu_epi8(ptr, reg); }

  void save(int8_t* ptr, const int elem_num) const {
    constexpr uint32_t M = 0xFFFFFFFF;
@ -540,71 +538,92 @@ struct INT8Vec16: public Vec<INT8Vec16> {
};
#endif
template <typename T>
struct VecType {
  using vec_type = void;
};

template <typename T>
using vec_t = typename VecType<T>::vec_type;

template <>
struct VecType<float> {
  using vec_type = FP32Vec8;
};

template <>
struct VecType<c10::Half> {
  using vec_type = FP16Vec8;
};

template <>
struct VecType<c10::BFloat16> {
  using vec_type = BF16Vec8;
};

template <typename T>
void storeFP32(float v, T* ptr) {
  *ptr = v;
}

inline void fma(FP32Vec16& acc, FP32Vec16& a, FP32Vec16& b) {
  acc = acc + a * b;
}

template <>
inline void storeFP32<c10::Half>(float v, c10::Half* ptr) {
  *reinterpret_cast<unsigned short*>(ptr) =
      _cvtss_sh(v, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC);
}

inline FP16Vec8::FP16Vec8(const FP32Vec8& v)
    : reg(_mm256_cvtps_ph(v.reg,
                          _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)) {}
#ifdef __AVX512F__
inline FP16Vec16::FP16Vec16(const FP32Vec16& v)
    : reg(_mm512_cvtps_ph(v.reg,
                          _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)) {}
#else
inline FP16Vec16::FP16Vec16(const FP32Vec16& v)
    : reg(_mm256_insertf128_si256(
          _mm256_castsi128_si256(FP16Vec8(FP32Vec8(v.reg_low)).reg),
          FP16Vec8(FP32Vec8(v.reg_low)).reg, 1)) {}
#endif
#ifdef __AVX512BF16__
template <>
inline void storeFP32<c10::BFloat16>(float v, c10::BFloat16* ptr) {
  *reinterpret_cast<__bfloat16*>(ptr) = _mm_cvtness_sbh(v);
}

inline BF16Vec8::BF16Vec8(const FP32Vec8& v)
    : reg((__m128i)_mm256_cvtneps_pbh(v.reg)) {}

inline BF16Vec16::BF16Vec16(const FP32Vec16& v)
    : reg((__m256i)_mm512_cvtneps_pbh(v.reg)) {}

inline void fma(FP32Vec16& acc, BF16Vec32& a, BF16Vec32& b) {
  acc.reg = _mm512_dpbf16_ps(acc.reg, (__m512bh)a.reg, (__m512bh)b.reg);
}
#else
template <>
inline void storeFP32<c10::BFloat16>(float v, c10::BFloat16* ptr) {
  c10::BFloat16 __attribute__((__may_alias__))* v_ptr =
      reinterpret_cast<c10::BFloat16*>(&v);
  *ptr = *(v_ptr + 1);
}
#ifdef __AVX512F__
inline BF16Vec8::BF16Vec8(const FP32Vec8& v)
    : reg(_mm256_cvtepi32_epi16(
          _mm256_bsrli_epi128(_mm256_castps_si256(v.reg), 2))) {}

inline BF16Vec16::BF16Vec16(const FP32Vec16& v)
    : reg(_mm512_cvtepi32_epi16(
          _mm512_bsrli_epi128(_mm512_castps_si512(v.reg), 2))) {}
#else
namespace {
__m128i FP32Vec8_to_BF16Vec8_avx2(__m256 a) {
  __m256i ai = _mm256_castps_si256(a);
  ai = _mm256_srli_epi32(ai, 16);
@ -612,21 +631,21 @@ __m128i FP32Vec8_to_BF16Vec8_avx2(__m256 a) {
  ai = _mm256_permute4x64_epi64(ai, 0b00111001);
  return _mm256_extracti128_si256(ai, 0);
}
}  // namespace

inline BF16Vec8::BF16Vec8(const FP32Vec8& v)
    : reg(FP32Vec8_to_BF16Vec8_avx2(v.reg)) {}

inline BF16Vec16::BF16Vec16(const FP32Vec16& v) {
  BF16Vec8 low = BF16Vec8(FP32Vec8(v.reg_low));
  BF16Vec8 high = BF16Vec8(FP32Vec8(v.reg_high));
  reg = _mm256_insertf128_si256(_mm256_castsi128_si256(low.reg), high.reg, 1);
}
#endif  // __AVX512F__
#endif  // __AVX512BF16__

inline void prefetch(const void* addr) { _mm_prefetch(addr, _MM_HINT_T1); }

};  // namespace vec_op

#endif
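Editor's note: one non-obvious idiom in this file is `reduce_sub_sum<group_size>(idx)`, which uses an AVX-512 mask register to sum a single `group_size`-wide slice of the 16 floats. A scalar reference sketch of the same selection logic follows (illustrative only; the `reduce_sub_sum_ref` name is hypothetical, not from the patch):

```cpp
#include <cstdint>

// Scalar equivalent of the AVX512 reduce_sub_sum<group_size>(idx) above:
// sum one group_size-wide group of a 16-element vector, picked by a shifted
// bitmask (what _mm512_mask_reduce_add_ps does in a single instruction).
template <int group_size>
float reduce_sub_sum_ref(const float (&v)[16], int idx) {
  static_assert(16 % group_size == 0, "groups must tile the vector");
  uint32_t base_mask = 0xFFFFu >> (16 - group_size);  // group_size low bits
  uint32_t mask = base_mask << (idx * group_size);    // select group `idx`
  float sum = 0.0f;
  for (int i = 0; i < 16; ++i) {
    if (mask & (1u << i)) sum += v[i];
  }
  return sum;
}
```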

View File

@ -27,8 +27,7 @@
inline int get_cuda_max_shared_memory_per_block_opt_in(int const device) {
  int max_shared_mem_per_block_opt_in = 0;
  cudaDeviceGetAttribute(&max_shared_mem_per_block_opt_in,
                         cudaDevAttrMaxSharedMemoryPerBlockOptin, device);
  return max_shared_mem_per_block_opt_in;
}

View File

@ -25,10 +25,12 @@ Check out the [building from source](#build-from-source) documentation for details.

```bash
pip install -r requirements-dev.txt

# Linting, formatting and static type checking
pre-commit install

# You can manually run pre-commit with
pre-commit run --all-files

# Unit tests
pytest tests/
```
@ -88,7 +90,8 @@ If the PR spans more than one category, please include all relevant prefixes.
The PR needs to meet the following code quality standards:

- We adhere to [Google Python style guide](https://google.github.io/styleguide/pyguide.html) and [Google C++ style guide](https://google.github.io/styleguide/cppguide.html).
- Pass all linter checks. Please use `pre-commit` to format your code. See
  <https://pre-commit.com/#usage> if `pre-commit` is new to you.
- The code needs to be well-documented to ensure future contributors can easily
  understand the code.
- Include sufficient tests to ensure the project stays correct and robust. This

format.sh
View File

@ -1,321 +0,0 @@
#!/usr/bin/env bash
# YAPF formatter, adapted from ray and skypilot.
#
# Usage:
# # Do work and commit your work.
# # Format files that differ from origin/main.
# bash format.sh
# # Commit changed files with message 'Run yapf and ruff'
#
#
# YAPF + Clang formatter (if installed). This script formats all changed files from the last mergebase.
# You are encouraged to run this locally before pushing changes for review.
# Cause the script to exit if a single command fails
set -eo pipefail
# this stops git rev-parse from failing if we run this from the .git directory
builtin cd "$(dirname "${BASH_SOURCE:-$0}")"
ROOT="$(git rev-parse --show-toplevel)"
builtin cd "$ROOT" || exit 1
check_command() {
if ! command -v "$1" &> /dev/null; then
echo "❓❓$1 is not installed, please run \`pip install -r requirements-lint.txt\`"
exit 1
fi
}
check_command yapf
check_command ruff
check_command mypy
check_command codespell
check_command isort
check_command clang-format
YAPF_VERSION=$(yapf --version | awk '{print $2}')
RUFF_VERSION=$(ruff --version | awk '{print $2}')
MYPY_VERSION=$(mypy --version | awk '{print $2}')
CODESPELL_VERSION=$(codespell --version)
ISORT_VERSION=$(isort --vn)
CLANGFORMAT_VERSION=$(clang-format --version | awk '{print $3}')
PYMARKDOWNLNT_VERSION=$(pymarkdownlnt version | awk '{print $1}')
# params: tool name, installed version (the required version is read from requirements-lint.txt)
tool_version_check() {
expected=$(grep "$1" requirements-lint.txt | cut -d'=' -f3)
if [[ "$2" != "$expected" ]]; then
echo "❓❓Wrong $1 version installed: $expected is required, not $2."
exit 1
fi
}
tool_version_check "yapf" "$YAPF_VERSION"
tool_version_check "ruff" "$RUFF_VERSION"
tool_version_check "mypy" "$MYPY_VERSION"
tool_version_check "isort" "$ISORT_VERSION"
tool_version_check "codespell" "$CODESPELL_VERSION"
tool_version_check "clang-format" "$CLANGFORMAT_VERSION"
tool_version_check "pymarkdownlnt" "$PYMARKDOWNLNT_VERSION"
YAPF_FLAGS=(
'--recursive'
'--parallel'
)
YAPF_EXCLUDES=(
'--exclude' 'build/**'
)
# Format specified files
format() {
yapf --in-place "${YAPF_FLAGS[@]}" "$@"
}
# Format files that differ from main branch. Ignores dirs that are not slated
# for autoformat yet.
format_changed() {
# The `if` guard ensures that the list of filenames is not empty, which
# could cause yapf to receive 0 positional arguments, making it hang
# waiting for STDIN.
#
# `diff-filter=ACM` and $MERGEBASE is to ensure we only format files that
# exist on both branches.
MERGEBASE="$(git merge-base origin/main HEAD)"
if ! git diff --diff-filter=ACM --quiet --exit-code "$MERGEBASE" -- '*.py' '*.pyi' &>/dev/null; then
git diff --name-only --diff-filter=ACM "$MERGEBASE" -- '*.py' '*.pyi' | xargs -P 5 \
yapf --in-place "${YAPF_EXCLUDES[@]}" "${YAPF_FLAGS[@]}"
fi
}
# Format all files
format_all() {
yapf --in-place "${YAPF_FLAGS[@]}" "${YAPF_EXCLUDES[@]}" .
}
## This flag formats individual files. --files *must* be the first command line
## arg to use this option.
if [[ "$1" == '--files' ]]; then
format "${@:2}"
# If `--all` is passed, then any further arguments are ignored and the
# entire python directory is formatted.
elif [[ "$1" == '--all' ]]; then
format_all
else
# Format only the files that changed in last commit.
format_changed
fi
echo 'vLLM yapf: Done'
# Run mypy
echo 'vLLM mypy:'
tools/mypy.sh
echo 'vLLM mypy: Done'
# If git diff returns a file that is in the skip list, the file may be checked anyway:
# https://github.com/codespell-project/codespell/issues/1915
# Avoiding the "./" prefix and using "/**" globs for directories appears to solve the problem
CODESPELL_EXCLUDES=(
'--skip' 'tests/prompts/**,./benchmarks/sonnet.txt,*tests/lora/data/**,build/**'
)
# check spelling of specified files
spell_check() {
codespell "$@"
}
spell_check_all(){
codespell --toml pyproject.toml "${CODESPELL_EXCLUDES[@]}"
}
# Spelling check of files that differ from main branch.
spell_check_changed() {
# The `if` guard ensures that the list of filenames is not empty, which
  # could cause codespell to receive 0 positional arguments, making it hang
# waiting for STDIN.
#
# `diff-filter=ACM` and $MERGEBASE is to ensure we only lint files that
# exist on both branches.
MERGEBASE="$(git merge-base origin/main HEAD)"
if ! git diff --diff-filter=ACM --quiet --exit-code "$MERGEBASE" -- '*.py' '*.pyi' &>/dev/null; then
git diff --name-only --diff-filter=ACM "$MERGEBASE" -- '*.py' '*.pyi' | xargs \
codespell "${CODESPELL_EXCLUDES[@]}"
fi
}
# Run Codespell
## This flag runs spell check of individual files. --files *must* be the first command line
## arg to use this option.
if [[ "$1" == '--files' ]]; then
spell_check "${@:2}"
# If `--all` is passed, then any further arguments are ignored and the
# entire python directory is linted.
elif [[ "$1" == '--all' ]]; then
spell_check_all
else
# Check spelling only of the files that changed in last commit.
spell_check_changed
fi
echo 'vLLM codespell: Done'
# Lint specified files
lint() {
ruff check "$@"
}
# Lint files that differ from main branch. Ignores dirs that are not slated
# for autolint yet.
lint_changed() {
# The `if` guard ensures that the list of filenames is not empty, which
# could cause ruff to receive 0 positional arguments, making it hang
# waiting for STDIN.
#
# `diff-filter=ACM` and $MERGEBASE is to ensure we only lint files that
# exist on both branches.
MERGEBASE="$(git merge-base origin/main HEAD)"
if ! git diff --diff-filter=ACM --quiet --exit-code "$MERGEBASE" -- '*.py' '*.pyi' &>/dev/null; then
git diff --name-only --diff-filter=ACM "$MERGEBASE" -- '*.py' '*.pyi' | xargs \
ruff check
fi
}
# Run Ruff
### This flag lints individual files. --files *must* be the first command line
### arg to use this option.
if [[ "$1" == '--files' ]]; then
lint "${@:2}"
# If `--all` is passed, then any further arguments are ignored and the
# entire python directory is linted.
elif [[ "$1" == '--all' ]]; then
lint vllm tests
else
   # Lint only the files that changed in last commit.
lint_changed
fi
echo 'vLLM ruff: Done'
# Sort imports in specified files
isort_check() {
isort "$@"
}
isort_check_all(){
isort .
}
# Import-order check of files that differ from main branch.
isort_check_changed() {
# The `if` guard ensures that the list of filenames is not empty, which
  # could cause isort to receive 0 positional arguments, making it hang
# waiting for STDIN.
#
# `diff-filter=ACM` and $MERGEBASE is to ensure we only lint files that
# exist on both branches.
MERGEBASE="$(git merge-base origin/main HEAD)"
if ! git diff --diff-filter=ACM --quiet --exit-code "$MERGEBASE" -- '*.py' '*.pyi' &>/dev/null; then
git diff --name-only --diff-filter=ACM "$MERGEBASE" -- '*.py' '*.pyi' | xargs \
isort
fi
}
# Run Isort
# This flag runs the import-order check on individual files. --files *must* be the first command line
# arg to use this option.
if [[ "$1" == '--files' ]]; then
isort_check "${@:2}"
# If `--all` is passed, then any further arguments are ignored and the
# entire python directory is linted.
elif [[ "$1" == '--all' ]]; then
isort_check_all
else
   # Check import order only of the files that changed in last commit.
isort_check_changed
fi
echo 'vLLM isort: Done'
# Clang-format section
# Exclude some files for formatting because they are vendored
# NOTE: Keep up to date with .github/workflows/clang-format.yml
CLANG_FORMAT_EXCLUDES=(
'csrc/moe/topk_softmax_kernels.cu'
'csrc/quantization/gguf/ggml-common.h'
'csrc/quantization/gguf/dequantize.cuh'
'csrc/quantization/gguf/vecdotq.cuh'
'csrc/quantization/gguf/mmq.cuh'
'csrc/quantization/gguf/mmvq.cuh'
)
# Format specified files with clang-format
clang_format() {
clang-format -i "$@"
}
# Format files that differ from main branch with clang-format.
clang_format_changed() {
# The `if` guard ensures that the list of filenames is not empty, which
# could cause clang-format to receive 0 positional arguments, making it hang
# waiting for STDIN.
#
# `diff-filter=ACM` and $MERGEBASE is to ensure we only format files that
# exist on both branches.
MERGEBASE="$(git merge-base origin/main HEAD)"
# Get the list of changed files, excluding the specified ones
changed_files=$(git diff --name-only --diff-filter=ACM "$MERGEBASE" -- '*.h' '*.cpp' '*.cu' '*.cuh' | (grep -vFf <(printf "%s\n" "${CLANG_FORMAT_EXCLUDES[@]}") || echo -e))
if [ -n "$changed_files" ]; then
echo "$changed_files" | xargs -P 5 clang-format -i
fi
}
# Format all files with clang-format
clang_format_all() {
find csrc/ \( -name '*.h' -o -name '*.cpp' -o -name '*.cu' -o -name '*.cuh' \) -print \
| grep -vFf <(printf "%s\n" "${CLANG_FORMAT_EXCLUDES[@]}") \
| xargs clang-format -i
}
# Run clang-format
if [[ "$1" == '--files' ]]; then
clang_format "${@:2}"
elif [[ "$1" == '--all' ]]; then
clang_format_all
else
clang_format_changed
fi
echo 'vLLM clang-format: Done'
echo 'vLLM actionlint:'
tools/actionlint.sh -color
echo 'vLLM actionlint: Done'
echo 'vLLM shellcheck:'
tools/shellcheck.sh
echo 'vLLM shellcheck: Done'
echo 'excalidraw png check:'
tools/png-lint.sh
echo 'excalidraw png check: Done'
if ! git diff --quiet &>/dev/null; then
echo
echo "🔍🔍There are files changed by the format checker or by you that are not added and committed:"
git --no-pager diff --name-only
echo "🔍🔍Format checker passed, but please add, commit and push all the files above to include changes made by the format checker."
exit 1
else
echo "✨🎉 Format check passed! Congratulations! 🎉✨"
fi
echo 'vLLM doc-lint:'
tools/doc-lint.sh
echo 'vLLM doc-lint: Done'

View File

@ -15,6 +15,11 @@ build-backend = "setuptools.build_meta"
[tool.setuptools_scm]
# version_file = "vllm/_version.py" # currently handled by `setup.py:get_version()`

[tool.yapfignore]
ignore_patterns = [
    "build/**",
]

[tool.ruff]
# Allow lines to be as long as 80.
line-length = 80
@ -52,6 +57,9 @@ ignore = [
"B007", "B007",
# f-string format # f-string format
"UP032", "UP032",
# Python 3.8 typing
"UP006", "UP035",
] ]
[tool.mypy] [tool.mypy]

View File

@ -1,15 +1,2 @@
 # formatting
-yapf==0.32.0
+pre-commit==4.0.1
-toml==0.10.2
-tomli==2.0.2
-ruff==0.6.5
-codespell==2.3.0
-isort==5.13.2
-clang-format==18.1.5
-pymarkdownlnt==0.9.26
-# type checking
-mypy==1.11.1
-types-PyYAML
-types-requests
-types-setuptools

View File

@ -1,13 +0,0 @@
#!/bin/bash
if command -v actionlint &> /dev/null; then
actionlint "$@"
exit 0
elif [ -x ./actionlint ]; then
./actionlint "$@"
exit 0
fi
# download a binary to the current directory - v1.7.3
bash <(curl https://raw.githubusercontent.com/rhysd/actionlint/aa0a7be8e566b096e64a5df8ff290ec24fa58fbc/scripts/download-actionlint.bash)
./actionlint "$@"

View File

@ -1,3 +0,0 @@
#!/bin/bash
pymarkdownlnt scan docs -r