Move linting to pre-commit
(#11975)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
parent
51ef828f10
commit
3ea7b94523
@ -43,7 +43,7 @@ main() {
|
||||
|
||||
|
||||
|
||||
# The figures should be genereated by a separate process outside the CI/CD pipeline
|
||||
# The figures should be generated by a separate process outside the CI/CD pipeline
|
||||
|
||||
# # generate figures
|
||||
# python3 -m pip install tabulate pandas matplotlib
|
||||
|
40
.github/workflows/actionlint.yml
vendored
40
.github/workflows/actionlint.yml
vendored
@ -1,40 +0,0 @@
|
||||
name: Lint GitHub Actions workflows
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- "main"
|
||||
paths:
|
||||
- '.github/workflows/*.ya?ml'
|
||||
- '.github/workflows/actionlint.*'
|
||||
- '.github/workflows/matchers/actionlint.json'
|
||||
pull_request:
|
||||
branches:
|
||||
- "main"
|
||||
paths:
|
||||
- '.github/workflows/*.ya?ml'
|
||||
- '.github/workflows/actionlint.*'
|
||||
- '.github/workflows/matchers/actionlint.json'
|
||||
|
||||
env:
|
||||
LC_ALL: en_US.UTF-8
|
||||
|
||||
defaults:
|
||||
run:
|
||||
shell: bash
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
jobs:
|
||||
actionlint:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: "Checkout"
|
||||
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
- name: "Run actionlint"
|
||||
run: |
|
||||
echo "::add-matcher::.github/workflows/matchers/actionlint.json"
|
||||
tools/actionlint.sh -color
|
53
.github/workflows/clang-format.yml
vendored
53
.github/workflows/clang-format.yml
vendored
@ -1,53 +0,0 @@
|
||||
name: clang-format
|
||||
|
||||
on:
|
||||
# Trigger the workflow on push or pull request,
|
||||
# but only for the main branch
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
paths:
|
||||
- '**/*.h'
|
||||
- '**/*.cpp'
|
||||
- '**/*.cu'
|
||||
- '**/*.cuh'
|
||||
- '.github/workflows/clang-format.yml'
|
||||
pull_request:
|
||||
branches:
|
||||
- main
|
||||
paths:
|
||||
- '**/*.h'
|
||||
- '**/*.cpp'
|
||||
- '**/*.cu'
|
||||
- '**/*.cuh'
|
||||
- '.github/workflows/clang-format.yml'
|
||||
|
||||
jobs:
|
||||
clang-format:
|
||||
runs-on: ubuntu-latest
|
||||
strategy:
|
||||
matrix:
|
||||
python-version: ["3.11"]
|
||||
steps:
|
||||
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
|
||||
- name: Set up Python ${{ matrix.python-version }}
|
||||
uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0
|
||||
with:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
pip install clang-format==18.1.5
|
||||
- name: Running clang-format
|
||||
run: |
|
||||
EXCLUDES=(
|
||||
'csrc/moe/topk_softmax_kernels.cu'
|
||||
'csrc/quantization/gguf/ggml-common.h'
|
||||
'csrc/quantization/gguf/dequantize.cuh'
|
||||
'csrc/quantization/gguf/vecdotq.cuh'
|
||||
'csrc/quantization/gguf/mmq.cuh'
|
||||
'csrc/quantization/gguf/mmvq.cuh'
|
||||
)
|
||||
find csrc/ \( -name '*.h' -o -name '*.cpp' -o -name '*.cu' -o -name '*.cuh' \) -print \
|
||||
| grep -vFf <(printf "%s\n" "${EXCLUDES[@]}") \
|
||||
| xargs clang-format --dry-run --Werror
|
45
.github/workflows/codespell.yml
vendored
45
.github/workflows/codespell.yml
vendored
@ -1,45 +0,0 @@
|
||||
name: codespell
|
||||
|
||||
on:
|
||||
# Trigger the workflow on push or pull request,
|
||||
# but only for the main branch
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
paths:
|
||||
- "**/*.py"
|
||||
- "**/*.md"
|
||||
- "**/*.rst"
|
||||
- pyproject.toml
|
||||
- requirements-lint.txt
|
||||
- .github/workflows/codespell.yml
|
||||
pull_request:
|
||||
branches:
|
||||
- main
|
||||
paths:
|
||||
- "**/*.py"
|
||||
- "**/*.md"
|
||||
- "**/*.rst"
|
||||
- pyproject.toml
|
||||
- requirements-lint.txt
|
||||
- .github/workflows/codespell.yml
|
||||
|
||||
jobs:
|
||||
codespell:
|
||||
runs-on: ubuntu-latest
|
||||
strategy:
|
||||
matrix:
|
||||
python-version: ["3.12"]
|
||||
steps:
|
||||
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
|
||||
- name: Set up Python ${{ matrix.python-version }}
|
||||
uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0
|
||||
with:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
pip install -r requirements-lint.txt
|
||||
- name: Spelling check with codespell
|
||||
run: |
|
||||
codespell --toml pyproject.toml
|
32
.github/workflows/doc-lint.yml
vendored
32
.github/workflows/doc-lint.yml
vendored
@ -1,32 +0,0 @@
|
||||
name: Lint documentation
|
||||
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
paths:
|
||||
- "docs/**"
|
||||
pull_request:
|
||||
branches:
|
||||
- main
|
||||
paths:
|
||||
- "docs/**"
|
||||
|
||||
jobs:
|
||||
doc-lint:
|
||||
runs-on: ubuntu-latest
|
||||
strategy:
|
||||
matrix:
|
||||
python-version: ["3.12"]
|
||||
steps:
|
||||
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
|
||||
- name: Set up Python ${{ matrix.python-version }}
|
||||
uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0
|
||||
with:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
pip install -r requirements-lint.txt
|
||||
- name: Linting docs
|
||||
run: tools/doc-lint.sh
|
20
.github/workflows/dummy.yml
vendored
Normal file
20
.github/workflows/dummy.yml
vendored
Normal file
@ -0,0 +1,20 @@
|
||||
name: dummy-checks
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
|
||||
jobs:
|
||||
mypy:
|
||||
runs-on: ubuntu-latest
|
||||
strategy:
|
||||
matrix:
|
||||
python-version: ["3.12"]
|
||||
steps:
|
||||
- run: echo "This is a dummy step that always passes"
|
||||
ruff:
|
||||
runs-on: ubuntu-latest
|
||||
strategy:
|
||||
matrix:
|
||||
python-version: ["3.12"]
|
||||
steps:
|
||||
- run: echo "This is a dummy step that always passes"
|
17
.github/workflows/matchers/ruff.json
vendored
17
.github/workflows/matchers/ruff.json
vendored
@ -1,17 +0,0 @@
|
||||
{
|
||||
"problemMatcher": [
|
||||
{
|
||||
"owner": "ruff",
|
||||
"pattern": [
|
||||
{
|
||||
"regexp": "^(.+?):(\\d+):(\\d+): (\\w+): (.+)$",
|
||||
"file": 1,
|
||||
"line": 2,
|
||||
"column": 3,
|
||||
"code": 4,
|
||||
"message": 5
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
51
.github/workflows/mypy.yaml
vendored
51
.github/workflows/mypy.yaml
vendored
@ -1,51 +0,0 @@
|
||||
name: mypy
|
||||
|
||||
on:
|
||||
# Trigger the workflow on push or pull request,
|
||||
# but only for the main branch
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
paths:
|
||||
- '**/*.py'
|
||||
- '.github/workflows/mypy.yaml'
|
||||
- 'tools/mypy.sh'
|
||||
- 'pyproject.toml'
|
||||
pull_request:
|
||||
branches:
|
||||
- main
|
||||
# This workflow is only relevant when one of the following files changes.
|
||||
# However, we have github configured to expect and require this workflow
|
||||
# to run and pass before github with auto-merge a pull request. Until github
|
||||
# allows more flexible auto-merge policy, we can just run this on every PR.
|
||||
# It doesn't take that long to run, anyway.
|
||||
#paths:
|
||||
# - '**/*.py'
|
||||
# - '.github/workflows/mypy.yaml'
|
||||
# - 'tools/mypy.sh'
|
||||
# - 'pyproject.toml'
|
||||
|
||||
jobs:
|
||||
mypy:
|
||||
runs-on: ubuntu-latest
|
||||
strategy:
|
||||
matrix:
|
||||
python-version: ["3.9", "3.10", "3.11", "3.12"]
|
||||
steps:
|
||||
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
|
||||
- name: Set up Python ${{ matrix.python-version }}
|
||||
uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0
|
||||
with:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
pip install mypy==1.11.1
|
||||
pip install types-setuptools
|
||||
pip install types-PyYAML
|
||||
pip install types-requests
|
||||
pip install types-setuptools
|
||||
- name: Mypy
|
||||
run: |
|
||||
echo "::add-matcher::.github/workflows/matchers/mypy.json"
|
||||
tools/mypy.sh 1 ${{ matrix.python-version }}
|
37
.github/workflows/png-lint.yml
vendored
37
.github/workflows/png-lint.yml
vendored
@ -1,37 +0,0 @@
|
||||
name: Lint PNG exports from excalidraw
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- "main"
|
||||
paths:
|
||||
- '*.excalidraw.png'
|
||||
- '.github/workflows/png-lint.yml'
|
||||
pull_request:
|
||||
branches:
|
||||
- "main"
|
||||
paths:
|
||||
- '*.excalidraw.png'
|
||||
- '.github/workflows/png-lint.yml'
|
||||
|
||||
env:
|
||||
LC_ALL: en_US.UTF-8
|
||||
|
||||
defaults:
|
||||
run:
|
||||
shell: bash
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
jobs:
|
||||
actionlint:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: "Checkout"
|
||||
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
- name: "Run png-lint.sh to check excalidraw exported images"
|
||||
run: |
|
||||
tools/png-lint.sh
|
17
.github/workflows/pre-commit.yml
vendored
Normal file
17
.github/workflows/pre-commit.yml
vendored
Normal file
@ -0,0 +1,17 @@
|
||||
name: pre-commit
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
push:
|
||||
branches: [main]
|
||||
|
||||
jobs:
|
||||
pre-commit:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
|
||||
- uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0
|
||||
with:
|
||||
python-version: "3.12"
|
||||
- run: echo "::add-matcher::.github/workflows/matchers/actionlint.json"
|
||||
- uses: pre-commit/action@2c7b3805fd2a0fd8c1884dcaebf91fc102a13ecd # v3.0.1
|
52
.github/workflows/ruff.yml
vendored
52
.github/workflows/ruff.yml
vendored
@ -1,52 +0,0 @@
|
||||
name: ruff
|
||||
|
||||
on:
|
||||
# Trigger the workflow on push or pull request,
|
||||
# but only for the main branch
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
paths:
|
||||
- "**/*.py"
|
||||
- pyproject.toml
|
||||
- requirements-lint.txt
|
||||
- .github/workflows/matchers/ruff.json
|
||||
- .github/workflows/ruff.yml
|
||||
pull_request:
|
||||
branches:
|
||||
- main
|
||||
# This workflow is only relevant when one of the following files changes.
|
||||
# However, we have github configured to expect and require this workflow
|
||||
# to run and pass before github with auto-merge a pull request. Until github
|
||||
# allows more flexible auto-merge policy, we can just run this on every PR.
|
||||
# It doesn't take that long to run, anyway.
|
||||
#paths:
|
||||
# - "**/*.py"
|
||||
# - pyproject.toml
|
||||
# - requirements-lint.txt
|
||||
# - .github/workflows/matchers/ruff.json
|
||||
# - .github/workflows/ruff.yml
|
||||
|
||||
jobs:
|
||||
ruff:
|
||||
runs-on: ubuntu-latest
|
||||
strategy:
|
||||
matrix:
|
||||
python-version: ["3.12"]
|
||||
steps:
|
||||
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
|
||||
- name: Set up Python ${{ matrix.python-version }}
|
||||
uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0
|
||||
with:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
pip install -r requirements-lint.txt
|
||||
- name: Analysing the code with ruff
|
||||
run: |
|
||||
echo "::add-matcher::.github/workflows/matchers/ruff.json"
|
||||
ruff check --output-format github .
|
||||
- name: Run isort
|
||||
run: |
|
||||
isort . --check-only
|
37
.github/workflows/shellcheck.yml
vendored
37
.github/workflows/shellcheck.yml
vendored
@ -1,37 +0,0 @@
|
||||
name: Lint shell scripts
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- "main"
|
||||
paths:
|
||||
- '**/*.sh'
|
||||
- '.github/workflows/shellcheck.yml'
|
||||
pull_request:
|
||||
branches:
|
||||
- "main"
|
||||
paths:
|
||||
- '**/*.sh'
|
||||
- '.github/workflows/shellcheck.yml'
|
||||
|
||||
env:
|
||||
LC_ALL: en_US.UTF-8
|
||||
|
||||
defaults:
|
||||
run:
|
||||
shell: bash
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
jobs:
|
||||
shellcheck:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: "Checkout"
|
||||
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
- name: "Check shell scripts"
|
||||
run: |
|
||||
tools/shellcheck.sh
|
38
.github/workflows/yapf.yml
vendored
38
.github/workflows/yapf.yml
vendored
@ -1,38 +0,0 @@
|
||||
name: yapf
|
||||
|
||||
on:
|
||||
# Trigger the workflow on push or pull request,
|
||||
# but only for the main branch
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
paths:
|
||||
- "**/*.py"
|
||||
- .github/workflows/yapf.yml
|
||||
pull_request:
|
||||
branches:
|
||||
- main
|
||||
paths:
|
||||
- "**/*.py"
|
||||
- .github/workflows/yapf.yml
|
||||
|
||||
jobs:
|
||||
yapf:
|
||||
runs-on: ubuntu-latest
|
||||
strategy:
|
||||
matrix:
|
||||
python-version: ["3.12"]
|
||||
steps:
|
||||
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
|
||||
- name: Set up Python ${{ matrix.python-version }}
|
||||
uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0
|
||||
with:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
pip install yapf==0.32.0
|
||||
pip install toml==0.10.2
|
||||
- name: Running yapf
|
||||
run: |
|
||||
yapf --diff --recursive .
|
73
.pre-commit-config.yaml
Normal file
73
.pre-commit-config.yaml
Normal file
@ -0,0 +1,73 @@
|
||||
repos:
|
||||
- repo: https://github.com/google/yapf
|
||||
rev: v0.32.0
|
||||
hooks:
|
||||
- id: yapf
|
||||
args: [--in-place, --verbose]
|
||||
additional_dependencies: [toml] # TODO: Remove when yapf is upgraded
|
||||
- repo: https://github.com/astral-sh/ruff-pre-commit
|
||||
rev: v0.6.5
|
||||
hooks:
|
||||
- id: ruff
|
||||
args: [--output-format, github]
|
||||
- repo: https://github.com/codespell-project/codespell
|
||||
rev: v2.3.0
|
||||
hooks:
|
||||
- id: codespell
|
||||
exclude: 'benchmarks/sonnet.txt|(build|tests/(lora/data|models/fixtures|prompts))/.*'
|
||||
- repo: https://github.com/PyCQA/isort
|
||||
rev: 5.13.2
|
||||
hooks:
|
||||
- id: isort
|
||||
- repo: https://github.com/pre-commit/mirrors-clang-format
|
||||
rev: v18.1.5
|
||||
hooks:
|
||||
- id: clang-format
|
||||
exclude: 'csrc/(moe/topk_softmax_kernels.cu|quantization/gguf/(ggml-common.h|dequantize.cuh|vecdotq.cuh|mmq.cuh|mmvq.cuh))'
|
||||
types_or: [c++, cuda]
|
||||
args: [--style=file, --verbose]
|
||||
- repo: https://github.com/jackdewinter/pymarkdown
|
||||
rev: v0.9.27
|
||||
hooks:
|
||||
- id: pymarkdown
|
||||
files: docs/.*
|
||||
- repo: local
|
||||
hooks:
|
||||
- id: mypy-3.9 # TODO: Use https://github.com/pre-commit/mirrors-mypy when mypy setup is less awkward
|
||||
name: Run mypy for Python 3.9
|
||||
entry: tools/mypy.sh 1 "3.9"
|
||||
language: python
|
||||
types: [python]
|
||||
additional_dependencies: &mypy_deps [mypy==1.11.1, types-setuptools, types-PyYAML, types-requests]
|
||||
- id: mypy-3.10 # TODO: Use https://github.com/pre-commit/mirrors-mypy when mypy setup is less awkward
|
||||
name: Run mypy for Python 3.10
|
||||
entry: tools/mypy.sh 1 "3.10"
|
||||
language: python
|
||||
types: [python]
|
||||
additional_dependencies: *mypy_deps
|
||||
- id: mypy-3.11 # TODO: Use https://github.com/pre-commit/mirrors-mypy when mypy setup is less awkward
|
||||
name: Run mypy for Python 3.11
|
||||
entry: tools/mypy.sh 1 "3.11"
|
||||
language: python
|
||||
types: [python]
|
||||
additional_dependencies: *mypy_deps
|
||||
- id: mypy-3.12 # TODO: Use https://github.com/pre-commit/mirrors-mypy when mypy setup is less awkward
|
||||
name: Run mypy for Python 3.12
|
||||
entry: tools/mypy.sh 1 "3.12"
|
||||
language: python
|
||||
types: [python]
|
||||
additional_dependencies: *mypy_deps
|
||||
- id: shellcheck
|
||||
name: Lint shell scripts
|
||||
entry: tools/shellcheck.sh
|
||||
language: script
|
||||
types: [shell]
|
||||
- id: png-lint
|
||||
name: Lint PNG exports from excalidraw
|
||||
entry: tools/png-lint.sh
|
||||
language: script
|
||||
types: [png]
|
||||
- repo: https://github.com/rhysd/actionlint
|
||||
rev: v1.7.6
|
||||
hooks:
|
||||
- id: actionlint
|
@ -32,7 +32,7 @@ class ScalarType {
|
||||
signed_(signed_),
|
||||
bias(bias),
|
||||
finite_values_only(finite_values_only),
|
||||
nan_repr(nan_repr){};
|
||||
nan_repr(nan_repr) {};
|
||||
|
||||
static constexpr ScalarType int_(uint8_t size_bits, int32_t bias = 0) {
|
||||
return ScalarType(0, size_bits - 1, true, bias);
|
||||
|
@ -2,13 +2,13 @@
|
||||
#define CPU_TYPES_HPP
|
||||
|
||||
#if defined(__x86_64__)
|
||||
//x86 implementation
|
||||
// x86 implementation
|
||||
#include "cpu_types_x86.hpp"
|
||||
#elif defined(__POWER9_VECTOR__)
|
||||
//ppc implementation
|
||||
// ppc implementation
|
||||
#include "cpu_types_vsx.hpp"
|
||||
#elif defined(__aarch64__)
|
||||
//arm implementation
|
||||
// arm implementation
|
||||
#include "cpu_types_arm.hpp"
|
||||
#else
|
||||
#warning "unsupported vLLM cpu implementation"
|
||||
|
@ -5,44 +5,46 @@
|
||||
namespace vec_op {
|
||||
|
||||
#ifdef ARM_BF16_SUPPORT
|
||||
#define VLLM_DISPATCH_CASE_FLOATING_TYPES(...) \
|
||||
AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \
|
||||
AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \
|
||||
#define VLLM_DISPATCH_CASE_FLOATING_TYPES(...) \
|
||||
AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \
|
||||
AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \
|
||||
AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__)
|
||||
#else
|
||||
#define VLLM_DISPATCH_CASE_FLOATING_TYPES(...) \
|
||||
AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \
|
||||
#define VLLM_DISPATCH_CASE_FLOATING_TYPES(...) \
|
||||
AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \
|
||||
AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__)
|
||||
#endif
|
||||
|
||||
#define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \
|
||||
#define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \
|
||||
AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__))
|
||||
|
||||
#ifndef CPU_OP_GUARD
|
||||
#define CPU_KERNEL_GUARD_IN(NAME)
|
||||
#define CPU_KERNEL_GUARD_OUT(NAME)
|
||||
#define CPU_KERNEL_GUARD_IN(NAME)
|
||||
#define CPU_KERNEL_GUARD_OUT(NAME)
|
||||
#else
|
||||
#define CPU_KERNEL_GUARD_IN(NAME) \
|
||||
std::cout << #NAME << " invoked." << std::endl;
|
||||
#define CPU_KERNEL_GUARD_OUT(NAME) std::cout << #NAME << " exit." << std::endl;
|
||||
#define CPU_KERNEL_GUARD_IN(NAME) \
|
||||
std::cout << #NAME << " invoked." << std::endl;
|
||||
#define CPU_KERNEL_GUARD_OUT(NAME) \
|
||||
std::cout << #NAME << " exit." << std::endl;
|
||||
#endif
|
||||
|
||||
#define FORCE_INLINE __attribute__((always_inline)) inline
|
||||
|
||||
namespace {
|
||||
template <typename T, T... indexes, typename F>
|
||||
constexpr void unroll_loop_item(std::integer_sequence<T, indexes...>, F &&f) {
|
||||
(f(std::integral_constant<T, indexes>{}), ...);
|
||||
};
|
||||
template <typename T, T... indexes, typename F>
|
||||
constexpr void unroll_loop_item(std::integer_sequence<T, indexes...>, F&& f) {
|
||||
(f(std::integral_constant<T, indexes>{}), ...);
|
||||
};
|
||||
}; // namespace
|
||||
|
||||
template <typename T, T count, typename F,
|
||||
typename = std::enable_if_t<std::is_invocable_v<F, T>>>
|
||||
constexpr void unroll_loop(F &&f) {
|
||||
constexpr void unroll_loop(F&& f) {
|
||||
unroll_loop_item(std::make_integer_sequence<T, count>{}, std::forward<F>(f));
|
||||
}
|
||||
|
||||
template <typename T> struct Vec {
|
||||
template <typename T>
|
||||
struct Vec {
|
||||
constexpr static int get_elem_num() { return T::VEC_ELEM_NUM; };
|
||||
};
|
||||
|
||||
@ -54,127 +56,124 @@ struct FP16Vec8 : public Vec<FP16Vec8> {
|
||||
|
||||
float16x8_t reg;
|
||||
|
||||
explicit FP16Vec8(const void *ptr)
|
||||
: reg(vld1q_f16(static_cast<const __fp16 *>(ptr))) {};
|
||||
explicit FP16Vec8(const void* ptr)
|
||||
: reg(vld1q_f16(static_cast<const __fp16*>(ptr))) {};
|
||||
|
||||
explicit FP16Vec8(const FP32Vec8 &);
|
||||
explicit FP16Vec8(const FP32Vec8&);
|
||||
|
||||
void save(void *ptr) const {
|
||||
vst1q_f16(static_cast<__fp16 *>(ptr), reg);
|
||||
}
|
||||
void save(void* ptr) const { vst1q_f16(static_cast<__fp16*>(ptr), reg); }
|
||||
};
|
||||
|
||||
struct FP16Vec16 : public Vec<FP16Vec16> {
|
||||
constexpr static int VEC_ELEM_NUM = 16;
|
||||
constexpr static int VEC_ELEM_NUM = 16;
|
||||
|
||||
float16x8x2_t reg;
|
||||
float16x8x2_t reg;
|
||||
|
||||
explicit FP16Vec16(const void *ptr) {
|
||||
reg.val[0] = vld1q_f16(reinterpret_cast<const __fp16*>(ptr));
|
||||
reg.val[1] = vld1q_f16(reinterpret_cast<const __fp16*>(ptr) + 8);
|
||||
}
|
||||
explicit FP16Vec16(const void* ptr) {
|
||||
reg.val[0] = vld1q_f16(reinterpret_cast<const __fp16*>(ptr));
|
||||
reg.val[1] = vld1q_f16(reinterpret_cast<const __fp16*>(ptr) + 8);
|
||||
}
|
||||
|
||||
explicit FP16Vec16(const FP32Vec16& vec);
|
||||
explicit FP16Vec16(const FP32Vec16& vec);
|
||||
|
||||
void save(void *ptr) const {
|
||||
vst1q_f16(reinterpret_cast<__fp16*>(ptr), reg.val[0]);
|
||||
void save(void* ptr) const {
|
||||
vst1q_f16(reinterpret_cast<__fp16*>(ptr), reg.val[0]);
|
||||
vst1q_f16(reinterpret_cast<__fp16*>(ptr) + 8, reg.val[1]);
|
||||
}
|
||||
|
||||
void save(void* ptr, const int elem_num) const {
|
||||
int full_blocks = elem_num / 8;
|
||||
int remainder = elem_num % 8;
|
||||
|
||||
if (full_blocks > 0) {
|
||||
vst1q_f16(reinterpret_cast<__fp16*>(ptr), reg.val[0]);
|
||||
if (full_blocks > 1) {
|
||||
vst1q_f16(reinterpret_cast<__fp16*>(ptr) + 8, reg.val[1]);
|
||||
}
|
||||
}
|
||||
|
||||
void save(void *ptr, const int elem_num) const {
|
||||
int full_blocks = elem_num / 8;
|
||||
int remainder = elem_num % 8;
|
||||
// Note: below is the unrolled version of the following code:
|
||||
//
|
||||
// for (int i = 0; i < remainder; ++i) {
|
||||
// reinterpret_cast<__fp16*>(ptr)[full_blocks * 8 + i] =
|
||||
// vgetq_lane_f16(temp, i);
|
||||
// }
|
||||
//
|
||||
// For macOS build (Clang), the arm/neon intrinsics function
|
||||
// `vgetq_lane_f16` needs the parameter `i` to be constant at compile
|
||||
// time.
|
||||
|
||||
if (full_blocks > 0) {
|
||||
vst1q_f16(reinterpret_cast<__fp16*>(ptr), reg.val[0]);
|
||||
if (full_blocks > 1) {
|
||||
vst1q_f16(reinterpret_cast<__fp16*>(ptr) + 8, reg.val[1]);
|
||||
}
|
||||
}
|
||||
if (remainder > 0) {
|
||||
float16x8_t temp = reg.val[full_blocks];
|
||||
__fp16* fp16_ptr = reinterpret_cast<__fp16*>(ptr);
|
||||
switch (remainder) {
|
||||
case 1:
|
||||
fp16_ptr[full_blocks * 8 + 0] = vgetq_lane_f16(temp, 0);
|
||||
break;
|
||||
case 2:
|
||||
fp16_ptr[full_blocks * 8 + 0] = vgetq_lane_f16(temp, 0);
|
||||
fp16_ptr[full_blocks * 8 + 1] = vgetq_lane_f16(temp, 1);
|
||||
break;
|
||||
case 3:
|
||||
fp16_ptr[full_blocks * 8 + 0] = vgetq_lane_f16(temp, 0);
|
||||
fp16_ptr[full_blocks * 8 + 1] = vgetq_lane_f16(temp, 1);
|
||||
fp16_ptr[full_blocks * 8 + 2] = vgetq_lane_f16(temp, 2);
|
||||
break;
|
||||
case 4:
|
||||
fp16_ptr[full_blocks * 8 + 0] = vgetq_lane_f16(temp, 0);
|
||||
fp16_ptr[full_blocks * 8 + 1] = vgetq_lane_f16(temp, 1);
|
||||
fp16_ptr[full_blocks * 8 + 2] = vgetq_lane_f16(temp, 2);
|
||||
fp16_ptr[full_blocks * 8 + 3] = vgetq_lane_f16(temp, 3);
|
||||
break;
|
||||
case 5:
|
||||
fp16_ptr[full_blocks * 8 + 0] = vgetq_lane_f16(temp, 0);
|
||||
fp16_ptr[full_blocks * 8 + 1] = vgetq_lane_f16(temp, 1);
|
||||
fp16_ptr[full_blocks * 8 + 2] = vgetq_lane_f16(temp, 2);
|
||||
fp16_ptr[full_blocks * 8 + 3] = vgetq_lane_f16(temp, 3);
|
||||
fp16_ptr[full_blocks * 8 + 4] = vgetq_lane_f16(temp, 4);
|
||||
break;
|
||||
case 6:
|
||||
fp16_ptr[full_blocks * 8 + 0] = vgetq_lane_f16(temp, 0);
|
||||
fp16_ptr[full_blocks * 8 + 1] = vgetq_lane_f16(temp, 1);
|
||||
fp16_ptr[full_blocks * 8 + 2] = vgetq_lane_f16(temp, 2);
|
||||
fp16_ptr[full_blocks * 8 + 3] = vgetq_lane_f16(temp, 3);
|
||||
fp16_ptr[full_blocks * 8 + 4] = vgetq_lane_f16(temp, 4);
|
||||
fp16_ptr[full_blocks * 8 + 5] = vgetq_lane_f16(temp, 5);
|
||||
break;
|
||||
case 7:
|
||||
fp16_ptr[full_blocks * 8 + 0] = vgetq_lane_f16(temp, 0);
|
||||
fp16_ptr[full_blocks * 8 + 1] = vgetq_lane_f16(temp, 1);
|
||||
fp16_ptr[full_blocks * 8 + 2] = vgetq_lane_f16(temp, 2);
|
||||
fp16_ptr[full_blocks * 8 + 3] = vgetq_lane_f16(temp, 3);
|
||||
fp16_ptr[full_blocks * 8 + 4] = vgetq_lane_f16(temp, 4);
|
||||
fp16_ptr[full_blocks * 8 + 5] = vgetq_lane_f16(temp, 5);
|
||||
fp16_ptr[full_blocks * 8 + 6] = vgetq_lane_f16(temp, 6);
|
||||
break;
|
||||
|
||||
// Note: below is the unrolled version of the following code:
|
||||
//
|
||||
// for (int i = 0; i < remainder; ++i) {
|
||||
// reinterpret_cast<__fp16*>(ptr)[full_blocks * 8 + i] =
|
||||
// vgetq_lane_f16(temp, i);
|
||||
// }
|
||||
//
|
||||
// For macOS build (Clang), the arm/neon intrinsics function
|
||||
// `vgetq_lane_f16` needs the parameter `i` to be constant at compile
|
||||
// time.
|
||||
|
||||
if (remainder > 0) {
|
||||
float16x8_t temp = reg.val[full_blocks];
|
||||
__fp16* fp16_ptr = reinterpret_cast<__fp16*>(ptr);
|
||||
switch (remainder)
|
||||
{
|
||||
case 1:
|
||||
fp16_ptr[full_blocks * 8 + 0] = vgetq_lane_f16(temp, 0);
|
||||
break;
|
||||
case 2:
|
||||
fp16_ptr[full_blocks * 8 + 0] = vgetq_lane_f16(temp, 0);
|
||||
fp16_ptr[full_blocks * 8 + 1] = vgetq_lane_f16(temp, 1);
|
||||
break;
|
||||
case 3:
|
||||
fp16_ptr[full_blocks * 8 + 0] = vgetq_lane_f16(temp, 0);
|
||||
fp16_ptr[full_blocks * 8 + 1] = vgetq_lane_f16(temp, 1);
|
||||
fp16_ptr[full_blocks * 8 + 2] = vgetq_lane_f16(temp, 2);
|
||||
break;
|
||||
case 4:
|
||||
fp16_ptr[full_blocks * 8 + 0] = vgetq_lane_f16(temp, 0);
|
||||
fp16_ptr[full_blocks * 8 + 1] = vgetq_lane_f16(temp, 1);
|
||||
fp16_ptr[full_blocks * 8 + 2] = vgetq_lane_f16(temp, 2);
|
||||
fp16_ptr[full_blocks * 8 + 3] = vgetq_lane_f16(temp, 3);
|
||||
break;
|
||||
case 5:
|
||||
fp16_ptr[full_blocks * 8 + 0] = vgetq_lane_f16(temp, 0);
|
||||
fp16_ptr[full_blocks * 8 + 1] = vgetq_lane_f16(temp, 1);
|
||||
fp16_ptr[full_blocks * 8 + 2] = vgetq_lane_f16(temp, 2);
|
||||
fp16_ptr[full_blocks * 8 + 3] = vgetq_lane_f16(temp, 3);
|
||||
fp16_ptr[full_blocks * 8 + 4] = vgetq_lane_f16(temp, 4);
|
||||
break;
|
||||
case 6:
|
||||
fp16_ptr[full_blocks * 8 + 0] = vgetq_lane_f16(temp, 0);
|
||||
fp16_ptr[full_blocks * 8 + 1] = vgetq_lane_f16(temp, 1);
|
||||
fp16_ptr[full_blocks * 8 + 2] = vgetq_lane_f16(temp, 2);
|
||||
fp16_ptr[full_blocks * 8 + 3] = vgetq_lane_f16(temp, 3);
|
||||
fp16_ptr[full_blocks * 8 + 4] = vgetq_lane_f16(temp, 4);
|
||||
fp16_ptr[full_blocks * 8 + 5] = vgetq_lane_f16(temp, 5);
|
||||
break;
|
||||
case 7:
|
||||
fp16_ptr[full_blocks * 8 + 0] = vgetq_lane_f16(temp, 0);
|
||||
fp16_ptr[full_blocks * 8 + 1] = vgetq_lane_f16(temp, 1);
|
||||
fp16_ptr[full_blocks * 8 + 2] = vgetq_lane_f16(temp, 2);
|
||||
fp16_ptr[full_blocks * 8 + 3] = vgetq_lane_f16(temp, 3);
|
||||
fp16_ptr[full_blocks * 8 + 4] = vgetq_lane_f16(temp, 4);
|
||||
fp16_ptr[full_blocks * 8 + 5] = vgetq_lane_f16(temp, 5);
|
||||
fp16_ptr[full_blocks * 8 + 6] = vgetq_lane_f16(temp, 6);
|
||||
break;
|
||||
|
||||
default:
|
||||
break;
|
||||
}
|
||||
}
|
||||
default:
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
#ifdef ARM_BF16_SUPPORT
|
||||
struct BF16Vec8 : public Vec<BF16Vec8> {
|
||||
constexpr static int VEC_ELEM_NUM = 8;
|
||||
|
||||
bfloat16x8_t reg;
|
||||
|
||||
explicit BF16Vec8(const void *ptr)
|
||||
: reg(*reinterpret_cast<const bfloat16x8_t *>(ptr)) {};
|
||||
explicit BF16Vec8(const void* ptr)
|
||||
: reg(*reinterpret_cast<const bfloat16x8_t*>(ptr)) {};
|
||||
|
||||
explicit BF16Vec8(bfloat16x8_t data) : reg(data) {};
|
||||
|
||||
explicit BF16Vec8(const FP32Vec8 &);
|
||||
explicit BF16Vec8(const FP32Vec8&);
|
||||
|
||||
explicit BF16Vec8(float32x4x2_t v) : reg(vcvtq_high_bf16_f32(vcvtq_low_bf16_f32(v.val[0]), v.val[1])) {};
|
||||
explicit BF16Vec8(float32x4x2_t v)
|
||||
: reg(vcvtq_high_bf16_f32(vcvtq_low_bf16_f32(v.val[0]), v.val[1])) {};
|
||||
|
||||
void save(void *ptr) const { *reinterpret_cast<bfloat16x8_t *>(ptr) = reg; }
|
||||
void save(void* ptr) const { *reinterpret_cast<bfloat16x8_t*>(ptr) = reg; }
|
||||
};
|
||||
|
||||
struct BF16Vec16 : public Vec<BF16Vec16> {
|
||||
@ -182,19 +181,18 @@ struct BF16Vec16 : public Vec<BF16Vec16> {
|
||||
|
||||
bfloat16x8x2_t reg;
|
||||
|
||||
explicit BF16Vec16(const void *ptr)
|
||||
: reg(*reinterpret_cast<const bfloat16x8x2_t *>(ptr)) {};
|
||||
explicit BF16Vec16(const void* ptr)
|
||||
: reg(*reinterpret_cast<const bfloat16x8x2_t*>(ptr)) {};
|
||||
|
||||
explicit BF16Vec16(bfloat16x8x2_t data) : reg(data) {};
|
||||
|
||||
explicit BF16Vec16(const FP32Vec16 &);
|
||||
explicit BF16Vec16(const FP32Vec16&);
|
||||
|
||||
explicit BF16Vec16(float32x4x4_t v) : reg({
|
||||
vcvtq_high_bf16_f32(vcvtq_low_bf16_f32(v.val[0]), v.val[1]),
|
||||
vcvtq_high_bf16_f32(vcvtq_low_bf16_f32(v.val[2]), v.val[3])
|
||||
}){};
|
||||
explicit BF16Vec16(float32x4x4_t v)
|
||||
: reg({vcvtq_high_bf16_f32(vcvtq_low_bf16_f32(v.val[0]), v.val[1]),
|
||||
vcvtq_high_bf16_f32(vcvtq_low_bf16_f32(v.val[2]), v.val[3])}) {};
|
||||
|
||||
void save(void *ptr) const { *reinterpret_cast<bfloat16x8x2_t *>(ptr) = reg; };
|
||||
void save(void* ptr) const { *reinterpret_cast<bfloat16x8x2_t*>(ptr) = reg; };
|
||||
};
|
||||
|
||||
struct BF16Vec32 : public Vec<BF16Vec32> {
|
||||
@ -202,19 +200,15 @@ struct BF16Vec32 : public Vec<BF16Vec32> {
|
||||
|
||||
bfloat16x8x4_t reg;
|
||||
|
||||
explicit BF16Vec32(const void *ptr)
|
||||
: reg(*reinterpret_cast<const bfloat16x8x4_t *>(ptr)) {};
|
||||
explicit BF16Vec32(const void* ptr)
|
||||
: reg(*reinterpret_cast<const bfloat16x8x4_t*>(ptr)) {};
|
||||
|
||||
explicit BF16Vec32(bfloat16x8x4_t data) : reg(data) {};
|
||||
|
||||
explicit BF16Vec32(const BF16Vec8 &vec8_data) : reg({
|
||||
vec8_data.reg,
|
||||
vec8_data.reg,
|
||||
vec8_data.reg,
|
||||
vec8_data.reg
|
||||
}) {};
|
||||
explicit BF16Vec32(const BF16Vec8& vec8_data)
|
||||
: reg({vec8_data.reg, vec8_data.reg, vec8_data.reg, vec8_data.reg}) {};
|
||||
|
||||
void save(void *ptr) const { *reinterpret_cast<bfloat16x8x4_t *>(ptr) = reg; };
|
||||
void save(void* ptr) const { *reinterpret_cast<bfloat16x8x4_t*>(ptr) = reg; };
|
||||
};
|
||||
#endif
|
||||
|
||||
@ -232,11 +226,11 @@ struct FP32Vec4 : public Vec<FP32Vec4> {
|
||||
|
||||
explicit FP32Vec4() : reg(vdupq_n_f32(0.0f)) {};
|
||||
|
||||
explicit FP32Vec4(const float *ptr) : reg(vld1q_f32(ptr)) {};
|
||||
explicit FP32Vec4(const float* ptr) : reg(vld1q_f32(ptr)) {};
|
||||
|
||||
explicit FP32Vec4(float32x4_t data) : reg(data) {};
|
||||
|
||||
explicit FP32Vec4(const FP32Vec4 &data) : reg(data.reg) {};
|
||||
explicit FP32Vec4(const FP32Vec4& data) : reg(data.reg) {};
|
||||
};
|
||||
|
||||
struct FP32Vec8 : public Vec<FP32Vec8> {
|
||||
@ -252,32 +246,37 @@ struct FP32Vec8 : public Vec<FP32Vec8> {
|
||||
|
||||
explicit FP32Vec8() : reg({vmovq_n_f32(0.0), vmovq_n_f32(0.0)}) {};
|
||||
|
||||
explicit FP32Vec8(const float *ptr) : reg({vld1q_f32(ptr), vld1q_f32(ptr + 4)}) {};
|
||||
explicit FP32Vec8(const float* ptr)
|
||||
: reg({vld1q_f32(ptr), vld1q_f32(ptr + 4)}) {};
|
||||
|
||||
explicit FP32Vec8(float32x4x2_t data) : reg(data) {};
|
||||
|
||||
explicit FP32Vec8(const FP32Vec8 &data) : reg(data.reg) {};
|
||||
explicit FP32Vec8(const FP32Vec8& data) : reg(data.reg) {};
|
||||
|
||||
explicit FP32Vec8(const FP16Vec8 &v) {
|
||||
reg.val[0] = vcvt_f32_f16(vget_low_f16(v.reg));
|
||||
reg.val[1] = vcvt_f32_f16(vget_high_f16(v.reg));
|
||||
};
|
||||
explicit FP32Vec8(const FP16Vec8& v) {
|
||||
reg.val[0] = vcvt_f32_f16(vget_low_f16(v.reg));
|
||||
reg.val[1] = vcvt_f32_f16(vget_high_f16(v.reg));
|
||||
};
|
||||
|
||||
explicit FP32Vec8(float16x8_t v) : reg({vcvt_f32_f16(vget_low_f16(v)), vcvt_f32_f16(vget_high_f16(v))}) {};
|
||||
explicit FP32Vec8(float16x8_t v)
|
||||
: reg({vcvt_f32_f16(vget_low_f16(v)), vcvt_f32_f16(vget_high_f16(v))}) {};
|
||||
|
||||
#ifdef ARM_BF16_SUPPORT
|
||||
#ifdef ARM_BF16_SUPPORT
|
||||
|
||||
explicit FP32Vec8(bfloat16x8_t v) : reg({vcvtq_low_f32_bf16(v), vcvtq_high_f32_bf16(v)}) {};
|
||||
explicit FP32Vec8(bfloat16x8_t v)
|
||||
: reg({vcvtq_low_f32_bf16(v), vcvtq_high_f32_bf16(v)}) {};
|
||||
|
||||
explicit FP32Vec8(const BF16Vec8 &v) : reg({vcvtq_low_f32_bf16(v.reg), vcvtq_high_f32_bf16(v.reg)}) {};
|
||||
explicit FP32Vec8(const BF16Vec8& v)
|
||||
: reg({vcvtq_low_f32_bf16(v.reg), vcvtq_high_f32_bf16(v.reg)}) {};
|
||||
|
||||
#endif
|
||||
#endif
|
||||
|
||||
float reduce_sum() const {
|
||||
AliasReg ar;
|
||||
ar.reg = reg;
|
||||
float answer = 0;
|
||||
unroll_loop<int, VEC_ELEM_NUM>([&answer, &ar](int i) { answer += ar.values[i]; });
|
||||
unroll_loop<int, VEC_ELEM_NUM>(
|
||||
[&answer, &ar](int i) { answer += ar.values[i]; });
|
||||
|
||||
return answer;
|
||||
}
|
||||
@ -324,10 +323,14 @@ struct FP32Vec8 : public Vec<FP32Vec8> {
|
||||
AliasReg ar;
|
||||
ar.reg = reg;
|
||||
|
||||
float32x2_t er_vec0 = {static_cast<float32_t>(erf(ar.values[0])), static_cast<float32_t>(erf(ar.values[1]))};
|
||||
float32x2_t er_vec1 = {static_cast<float32_t>(erf(ar.values[2])), static_cast<float32_t>(erf(ar.values[3]))};
|
||||
float32x2_t er_vec2 = {static_cast<float32_t>(erf(ar.values[4])), static_cast<float32_t>(erf(ar.values[5]))};
|
||||
float32x2_t er_vec3 = {static_cast<float32_t>(erf(ar.values[6])), static_cast<float32_t>(erf(ar.values[7]))};
|
||||
float32x2_t er_vec0 = {static_cast<float32_t>(erf(ar.values[0])),
|
||||
static_cast<float32_t>(erf(ar.values[1]))};
|
||||
float32x2_t er_vec1 = {static_cast<float32_t>(erf(ar.values[2])),
|
||||
static_cast<float32_t>(erf(ar.values[3]))};
|
||||
float32x2_t er_vec2 = {static_cast<float32_t>(erf(ar.values[4])),
|
||||
static_cast<float32_t>(erf(ar.values[5]))};
|
||||
float32x2_t er_vec3 = {static_cast<float32_t>(erf(ar.values[6])),
|
||||
static_cast<float32_t>(erf(ar.values[7]))};
|
||||
|
||||
float32x4_t result0 = vcombine_f32(er_vec0, er_vec1);
|
||||
float32x4_t result1 = vcombine_f32(er_vec2, er_vec3);
|
||||
@ -339,23 +342,27 @@ struct FP32Vec8 : public Vec<FP32Vec8> {
|
||||
return FP32Vec8(result);
|
||||
}
|
||||
|
||||
FP32Vec8 operator*(const FP32Vec8 &b) const {
|
||||
return FP32Vec8(float32x4x2_t({vmulq_f32(reg.val[0], b.reg.val[0]), vmulq_f32(reg.val[1], b.reg.val[1])}));
|
||||
FP32Vec8 operator*(const FP32Vec8& b) const {
|
||||
return FP32Vec8(float32x4x2_t({vmulq_f32(reg.val[0], b.reg.val[0]),
|
||||
vmulq_f32(reg.val[1], b.reg.val[1])}));
|
||||
}
|
||||
|
||||
FP32Vec8 operator+(const FP32Vec8 &b) const {
|
||||
return FP32Vec8(float32x4x2_t({vaddq_f32(reg.val[0], b.reg.val[0]), vaddq_f32(reg.val[1], b.reg.val[1])}));
|
||||
FP32Vec8 operator+(const FP32Vec8& b) const {
|
||||
return FP32Vec8(float32x4x2_t({vaddq_f32(reg.val[0], b.reg.val[0]),
|
||||
vaddq_f32(reg.val[1], b.reg.val[1])}));
|
||||
}
|
||||
|
||||
FP32Vec8 operator-(const FP32Vec8 &b) const {
|
||||
return FP32Vec8(float32x4x2_t({vsubq_f32(reg.val[0], b.reg.val[0]), vsubq_f32(reg.val[1], b.reg.val[1])}));
|
||||
FP32Vec8 operator-(const FP32Vec8& b) const {
|
||||
return FP32Vec8(float32x4x2_t({vsubq_f32(reg.val[0], b.reg.val[0]),
|
||||
vsubq_f32(reg.val[1], b.reg.val[1])}));
|
||||
}
|
||||
|
||||
FP32Vec8 operator/(const FP32Vec8 &b) const {
|
||||
return FP32Vec8(float32x4x2_t({vdivq_f32(reg.val[0], b.reg.val[0]), vdivq_f32(reg.val[1], b.reg.val[1])}));
|
||||
FP32Vec8 operator/(const FP32Vec8& b) const {
|
||||
return FP32Vec8(float32x4x2_t({vdivq_f32(reg.val[0], b.reg.val[0]),
|
||||
vdivq_f32(reg.val[1], b.reg.val[1])}));
|
||||
}
|
||||
|
||||
void save(float *ptr) const {
|
||||
void save(float* ptr) const {
|
||||
vst1q_f32(ptr, reg.val[0]);
|
||||
vst1q_f32(ptr + 4, reg.val[1]);
|
||||
}
|
||||
@ -370,103 +377,100 @@ struct FP32Vec16 : public Vec<FP32Vec16> {
|
||||
|
||||
float32x4x4_t reg;
|
||||
|
||||
explicit FP32Vec16(float v) : reg({vmovq_n_f32(v), vmovq_n_f32(v), vmovq_n_f32(v), vmovq_n_f32(v)}) {}
|
||||
explicit FP32Vec16(float v)
|
||||
: reg({vmovq_n_f32(v), vmovq_n_f32(v), vmovq_n_f32(v), vmovq_n_f32(v)}) {}
|
||||
|
||||
explicit FP32Vec16() : reg({vmovq_n_f32(0.0), vmovq_n_f32(0.0), vmovq_n_f32(0.0), vmovq_n_f32(0.0)}) {}
|
||||
explicit FP32Vec16()
|
||||
: reg({vmovq_n_f32(0.0), vmovq_n_f32(0.0), vmovq_n_f32(0.0),
|
||||
vmovq_n_f32(0.0)}) {}
|
||||
|
||||
explicit FP32Vec16(const float *ptr) : reg({vld1q_f32(ptr), vld1q_f32(ptr + 4), vld1q_f32(ptr + 8), vld1q_f32(ptr + 12)}) {}
|
||||
explicit FP32Vec16(const float* ptr)
|
||||
: reg({vld1q_f32(ptr), vld1q_f32(ptr + 4), vld1q_f32(ptr + 8),
|
||||
vld1q_f32(ptr + 12)}) {}
|
||||
|
||||
explicit FP32Vec16(float32x4x4_t data) : reg(data) {}
|
||||
|
||||
explicit FP32Vec16(const FP32Vec8 &data) {
|
||||
reg.val[0] = data.reg.val[0];
|
||||
reg.val[1] = data.reg.val[1];
|
||||
reg.val[2] = data.reg.val[0];
|
||||
reg.val[3] = data.reg.val[1];
|
||||
explicit FP32Vec16(const FP32Vec8& data) {
|
||||
reg.val[0] = data.reg.val[0];
|
||||
reg.val[1] = data.reg.val[1];
|
||||
reg.val[2] = data.reg.val[0];
|
||||
reg.val[3] = data.reg.val[1];
|
||||
}
|
||||
|
||||
explicit FP32Vec16(const FP32Vec16 &data) : reg(data.reg) {}
|
||||
explicit FP32Vec16(const FP32Vec16& data) : reg(data.reg) {}
|
||||
|
||||
explicit FP32Vec16(const FP16Vec8 &v) : FP32Vec16(FP32Vec8(v.reg)) {}
|
||||
explicit FP32Vec16(const FP16Vec8& v) : FP32Vec16(FP32Vec8(v.reg)) {}
|
||||
|
||||
#ifdef ARM_BF16_SUPPORT
|
||||
explicit FP32Vec16(bfloat16x8x2_t v) : reg({
|
||||
vcvtq_low_f32_bf16(v.val[0]),
|
||||
vcvtq_high_f32_bf16(v.val[0]),
|
||||
vcvtq_low_f32_bf16(v.val[1]),
|
||||
vcvtq_high_f32_bf16(v.val[1])
|
||||
}) {};
|
||||
#endif
|
||||
#ifdef ARM_BF16_SUPPORT
|
||||
explicit FP32Vec16(bfloat16x8x2_t v)
|
||||
: reg({vcvtq_low_f32_bf16(v.val[0]), vcvtq_high_f32_bf16(v.val[0]),
|
||||
vcvtq_low_f32_bf16(v.val[1]), vcvtq_high_f32_bf16(v.val[1])}) {};
|
||||
#endif
|
||||
|
||||
explicit FP32Vec16(const FP32Vec4 &data) {
|
||||
explicit FP32Vec16(const FP32Vec4& data) {
|
||||
reg.val[0] = data.reg;
|
||||
reg.val[1] = data.reg;
|
||||
reg.val[2] = data.reg;
|
||||
reg.val[3] = data.reg;
|
||||
};
|
||||
|
||||
#ifdef ARM_BF16_SUPPORT
|
||||
explicit FP32Vec16(const BF16Vec16 &v) : reg({
|
||||
vcvtq_low_f32_bf16(v.reg.val[0]),
|
||||
vcvtq_high_f32_bf16(v.reg.val[0]),
|
||||
vcvtq_low_f32_bf16(v.reg.val[1]),
|
||||
vcvtq_high_f32_bf16(v.reg.val[1])
|
||||
}) {};
|
||||
#ifdef ARM_BF16_SUPPORT
|
||||
explicit FP32Vec16(const BF16Vec16& v)
|
||||
: reg({vcvtq_low_f32_bf16(v.reg.val[0]),
|
||||
vcvtq_high_f32_bf16(v.reg.val[0]),
|
||||
vcvtq_low_f32_bf16(v.reg.val[1]),
|
||||
vcvtq_high_f32_bf16(v.reg.val[1])}) {};
|
||||
|
||||
explicit FP32Vec16(const BF16Vec8 &v) : FP32Vec16(FP32Vec8(v)) {};
|
||||
#endif
|
||||
explicit FP32Vec16(const BF16Vec8& v) : FP32Vec16(FP32Vec8(v)) {};
|
||||
#endif
|
||||
|
||||
explicit FP32Vec16(const FP16Vec16 &v) {
|
||||
reg.val[0] = vcvt_f32_f16(vget_low_f16(v.reg.val[0]));
|
||||
reg.val[1] = vcvt_f32_f16(vget_high_f16(v.reg.val[0]));
|
||||
reg.val[2] = vcvt_f32_f16(vget_low_f16(v.reg.val[1]));
|
||||
reg.val[3] = vcvt_f32_f16(vget_high_f16(v.reg.val[1]));
|
||||
explicit FP32Vec16(const FP16Vec16& v) {
|
||||
reg.val[0] = vcvt_f32_f16(vget_low_f16(v.reg.val[0]));
|
||||
reg.val[1] = vcvt_f32_f16(vget_high_f16(v.reg.val[0]));
|
||||
reg.val[2] = vcvt_f32_f16(vget_low_f16(v.reg.val[1]));
|
||||
reg.val[3] = vcvt_f32_f16(vget_high_f16(v.reg.val[1]));
|
||||
};
|
||||
|
||||
FP32Vec16 operator+(const FP32Vec16 &b) const {
|
||||
return FP32Vec16(float32x4x4_t({
|
||||
vaddq_f32(reg.val[0], b.reg.val[0]),
|
||||
vaddq_f32(reg.val[1], b.reg.val[1]),
|
||||
vaddq_f32(reg.val[2], b.reg.val[2]),
|
||||
vaddq_f32(reg.val[3], b.reg.val[3])}));
|
||||
FP32Vec16 operator+(const FP32Vec16& b) const {
|
||||
return FP32Vec16(float32x4x4_t({vaddq_f32(reg.val[0], b.reg.val[0]),
|
||||
vaddq_f32(reg.val[1], b.reg.val[1]),
|
||||
vaddq_f32(reg.val[2], b.reg.val[2]),
|
||||
vaddq_f32(reg.val[3], b.reg.val[3])}));
|
||||
};
|
||||
|
||||
FP32Vec16 operator*(const FP32Vec16 &b) const {
|
||||
return FP32Vec16(float32x4x4_t({
|
||||
vmulq_f32(reg.val[0], b.reg.val[0]),
|
||||
vmulq_f32(reg.val[1], b.reg.val[1]),
|
||||
vmulq_f32(reg.val[2], b.reg.val[2]),
|
||||
vmulq_f32(reg.val[3], b.reg.val[3])}));
|
||||
FP32Vec16 operator*(const FP32Vec16& b) const {
|
||||
return FP32Vec16(float32x4x4_t({vmulq_f32(reg.val[0], b.reg.val[0]),
|
||||
vmulq_f32(reg.val[1], b.reg.val[1]),
|
||||
vmulq_f32(reg.val[2], b.reg.val[2]),
|
||||
vmulq_f32(reg.val[3], b.reg.val[3])}));
|
||||
};
|
||||
|
||||
FP32Vec16 operator-(const FP32Vec16 &b) const {
|
||||
return FP32Vec16(float32x4x4_t({
|
||||
vsubq_f32(reg.val[0], b.reg.val[0]),
|
||||
vsubq_f32(reg.val[1], b.reg.val[1]),
|
||||
vsubq_f32(reg.val[2], b.reg.val[2]),
|
||||
vsubq_f32(reg.val[3], b.reg.val[3])
|
||||
}));
|
||||
FP32Vec16 operator-(const FP32Vec16& b) const {
|
||||
return FP32Vec16(float32x4x4_t({vsubq_f32(reg.val[0], b.reg.val[0]),
|
||||
vsubq_f32(reg.val[1], b.reg.val[1]),
|
||||
vsubq_f32(reg.val[2], b.reg.val[2]),
|
||||
vsubq_f32(reg.val[3], b.reg.val[3])}));
|
||||
};
|
||||
|
||||
FP32Vec16 operator/(const FP32Vec16 &b) const {
|
||||
return FP32Vec16(float32x4x4_t({
|
||||
vdivq_f32(reg.val[0], b.reg.val[0]),
|
||||
vdivq_f32(reg.val[1], b.reg.val[1]),
|
||||
vdivq_f32(reg.val[2], b.reg.val[2]),
|
||||
vdivq_f32(reg.val[3], b.reg.val[3])
|
||||
}));
|
||||
FP32Vec16 operator/(const FP32Vec16& b) const {
|
||||
return FP32Vec16(float32x4x4_t({vdivq_f32(reg.val[0], b.reg.val[0]),
|
||||
vdivq_f32(reg.val[1], b.reg.val[1]),
|
||||
vdivq_f32(reg.val[2], b.reg.val[2]),
|
||||
vdivq_f32(reg.val[3], b.reg.val[3])}));
|
||||
};
|
||||
|
||||
float reduce_sum() const {
|
||||
AliasReg ar;
|
||||
ar.reg = reg;
|
||||
float answer = 0;
|
||||
unroll_loop<int, VEC_ELEM_NUM>([&answer, &ar](int i) { answer += ar.values[i]; });
|
||||
unroll_loop<int, VEC_ELEM_NUM>(
|
||||
[&answer, &ar](int i) { answer += ar.values[i]; });
|
||||
|
||||
return answer;
|
||||
};
|
||||
|
||||
template <int group_size> float reduce_sub_sum(int idx) {
|
||||
template <int group_size>
|
||||
float reduce_sub_sum(int idx) {
|
||||
static_assert(VEC_ELEM_NUM % group_size == 0);
|
||||
|
||||
AliasReg ar;
|
||||
@ -479,7 +483,7 @@ struct FP32Vec16 : public Vec<FP32Vec16> {
|
||||
return answer;
|
||||
};
|
||||
|
||||
void save(float *ptr) const {
|
||||
void save(float* ptr) const {
|
||||
vst1q_f32(ptr, reg.val[0]);
|
||||
vst1q_f32(ptr + 4, reg.val[1]);
|
||||
vst1q_f32(ptr + 8, reg.val[2]);
|
||||
@ -487,43 +491,59 @@ struct FP32Vec16 : public Vec<FP32Vec16> {
|
||||
};
|
||||
};
|
||||
|
||||
template <typename T> struct VecType { using vec_type = void; };
|
||||
template <typename T>
|
||||
struct VecType {
|
||||
using vec_type = void;
|
||||
};
|
||||
|
||||
template <typename T> using vec_t = typename VecType<T>::vec_type;
|
||||
template <typename T>
|
||||
using vec_t = typename VecType<T>::vec_type;
|
||||
|
||||
template <> struct VecType<float> { using vec_type = FP32Vec8; };
|
||||
template <>
|
||||
struct VecType<float> {
|
||||
using vec_type = FP32Vec8;
|
||||
};
|
||||
|
||||
template <> struct VecType<c10::Half> { using vec_type = FP16Vec8; };
|
||||
template <>
|
||||
struct VecType<c10::Half> {
|
||||
using vec_type = FP16Vec8;
|
||||
};
|
||||
|
||||
#ifdef ARM_BF16_SUPPORT
|
||||
template <> struct VecType<c10::BFloat16> { using vec_type = BF16Vec8; };
|
||||
template <>
|
||||
struct VecType<c10::BFloat16> {
|
||||
using vec_type = BF16Vec8;
|
||||
};
|
||||
#endif
|
||||
|
||||
template <typename T> void storeFP32(float v, T *ptr) { *ptr = v; }
|
||||
|
||||
template <> inline void storeFP32<c10::Half>(float v, c10::Half *ptr) {
|
||||
*reinterpret_cast<__fp16 *>(ptr) = v;
|
||||
template <typename T>
|
||||
void storeFP32(float v, T* ptr) {
|
||||
*ptr = v;
|
||||
}
|
||||
|
||||
inline FP16Vec16::FP16Vec16(const FP32Vec16 &v) {
|
||||
float16x4_t low_0 = vcvt_f16_f32(v.reg.val[0]);
|
||||
float16x4_t high_0 = vcvt_f16_f32(v.reg.val[1]);
|
||||
float16x4_t low_1 = vcvt_f16_f32(v.reg.val[2]);
|
||||
float16x4_t high_1 = vcvt_f16_f32(v.reg.val[3]);
|
||||
template <>
|
||||
inline void storeFP32<c10::Half>(float v, c10::Half* ptr) {
|
||||
*reinterpret_cast<__fp16*>(ptr) = v;
|
||||
}
|
||||
|
||||
reg.val[0] = vcombine_f16(low_0, high_0);
|
||||
reg.val[1] = vcombine_f16(low_1, high_1);
|
||||
inline FP16Vec16::FP16Vec16(const FP32Vec16& v) {
|
||||
float16x4_t low_0 = vcvt_f16_f32(v.reg.val[0]);
|
||||
float16x4_t high_0 = vcvt_f16_f32(v.reg.val[1]);
|
||||
float16x4_t low_1 = vcvt_f16_f32(v.reg.val[2]);
|
||||
float16x4_t high_1 = vcvt_f16_f32(v.reg.val[3]);
|
||||
|
||||
reg.val[0] = vcombine_f16(low_0, high_0);
|
||||
reg.val[1] = vcombine_f16(low_1, high_1);
|
||||
};
|
||||
|
||||
inline FP16Vec8 :: FP16Vec8(const FP32Vec8 &v) {
|
||||
float16x4_t lower_half = vcvt_f16_f32(v.reg.val[0]);
|
||||
float16x4_t upper_half = vcvt_f16_f32(v.reg.val[1]);
|
||||
inline FP16Vec8 ::FP16Vec8(const FP32Vec8& v) {
|
||||
float16x4_t lower_half = vcvt_f16_f32(v.reg.val[0]);
|
||||
float16x4_t upper_half = vcvt_f16_f32(v.reg.val[1]);
|
||||
|
||||
reg = vcombine_f16(lower_half, upper_half);
|
||||
reg = vcombine_f16(lower_half, upper_half);
|
||||
};
|
||||
|
||||
inline void fma(FP32Vec16 &acc, FP32Vec16 &a, FP32Vec16 &b) {
|
||||
|
||||
inline void fma(FP32Vec16& acc, FP32Vec16& a, FP32Vec16& b) {
|
||||
acc.reg.val[0] = vfmaq_f32(acc.reg.val[0], a.reg.val[0], b.reg.val[0]);
|
||||
acc.reg.val[1] = vfmaq_f32(acc.reg.val[1], a.reg.val[1], b.reg.val[1]);
|
||||
acc.reg.val[2] = vfmaq_f32(acc.reg.val[2], a.reg.val[2], b.reg.val[2]);
|
||||
@ -531,8 +551,7 @@ inline void fma(FP32Vec16 &acc, FP32Vec16 &a, FP32Vec16 &b) {
|
||||
};
|
||||
|
||||
#ifdef ARM_BF16_SUPPORT
|
||||
inline void fma(FP32Vec16 &acc, BF16Vec32 &a, BF16Vec32 &b) {
|
||||
|
||||
inline void fma(FP32Vec16& acc, BF16Vec32& a, BF16Vec32& b) {
|
||||
float32x4_t a0_low = vcvt_f32_bf16(vget_low_bf16(a.reg.val[0]));
|
||||
float32x4_t a0_high = vcvt_f32_bf16(vget_high_bf16(a.reg.val[0]));
|
||||
float32x4_t a1_low = vcvt_f32_bf16(vget_low_bf16(a.reg.val[1]));
|
||||
@ -551,22 +570,22 @@ inline void fma(FP32Vec16 &acc, BF16Vec32 &a, BF16Vec32 &b) {
|
||||
#endif
|
||||
|
||||
#ifdef ARM_BF16_SUPPORT
|
||||
inline BF16Vec8::BF16Vec8(const FP32Vec8 &v) : reg(vcvtq_high_bf16_f32(vcvtq_low_bf16_f32(v.reg.val[0]), v.reg.val[1])) {};
|
||||
inline BF16Vec8::BF16Vec8(const FP32Vec8& v)
|
||||
: reg(vcvtq_high_bf16_f32(vcvtq_low_bf16_f32(v.reg.val[0]), v.reg.val[1])) {
|
||||
};
|
||||
|
||||
inline BF16Vec16::BF16Vec16(const FP32Vec16 &v) : reg({
|
||||
vcvtq_high_bf16_f32(vcvtq_low_bf16_f32(v.reg.val[0]), v.reg.val[1]),
|
||||
vcvtq_high_bf16_f32(vcvtq_low_bf16_f32(v.reg.val[2]), v.reg.val[3])
|
||||
}){};
|
||||
inline BF16Vec16::BF16Vec16(const FP32Vec16& v)
|
||||
: reg({vcvtq_high_bf16_f32(vcvtq_low_bf16_f32(v.reg.val[0]), v.reg.val[1]),
|
||||
vcvtq_high_bf16_f32(vcvtq_low_bf16_f32(v.reg.val[2]),
|
||||
v.reg.val[3])}) {};
|
||||
#endif
|
||||
|
||||
inline void prefetch(const void *addr) {
|
||||
__builtin_prefetch(addr, 0, 1);
|
||||
};
|
||||
inline void prefetch(const void* addr) { __builtin_prefetch(addr, 0, 1); };
|
||||
|
||||
#ifdef ARM_BF16_SUPPORT
|
||||
template <>
|
||||
inline void storeFP32<c10::BFloat16>(float v, c10::BFloat16 *ptr) {
|
||||
*reinterpret_cast<__bf16 *>(ptr) = vcvth_bf16_f32(v);
|
||||
inline void storeFP32<c10::BFloat16>(float v, c10::BFloat16* ptr) {
|
||||
*reinterpret_cast<__bf16*>(ptr) = vcvth_bf16_f32(v);
|
||||
};
|
||||
#endif
|
||||
};
|
||||
}; // namespace vec_op
|
@ -9,38 +9,40 @@
|
||||
namespace vec_op {
|
||||
|
||||
// FIXME: FP16 is not fully supported in Torch-CPU
|
||||
#define VLLM_DISPATCH_CASE_FLOATING_TYPES(...) \
|
||||
AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \
|
||||
#define VLLM_DISPATCH_CASE_FLOATING_TYPES(...) \
|
||||
AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \
|
||||
AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__)
|
||||
|
||||
#define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \
|
||||
#define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \
|
||||
AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__))
|
||||
|
||||
#ifndef CPU_OP_GUARD
|
||||
#define CPU_KERNEL_GUARD_IN(NAME)
|
||||
#define CPU_KERNEL_GUARD_OUT(NAME)
|
||||
#define CPU_KERNEL_GUARD_IN(NAME)
|
||||
#define CPU_KERNEL_GUARD_OUT(NAME)
|
||||
#else
|
||||
#define CPU_KERNEL_GUARD_IN(NAME) \
|
||||
std::cout << #NAME << " invoked." << std::endl;
|
||||
#define CPU_KERNEL_GUARD_OUT(NAME) std::cout << #NAME << " exit." << std::endl;
|
||||
#define CPU_KERNEL_GUARD_IN(NAME) \
|
||||
std::cout << #NAME << " invoked." << std::endl;
|
||||
#define CPU_KERNEL_GUARD_OUT(NAME) \
|
||||
std::cout << #NAME << " exit." << std::endl;
|
||||
#endif
|
||||
|
||||
#define FORCE_INLINE __attribute__((always_inline)) inline
|
||||
|
||||
namespace {
|
||||
template <typename T, T... indexes, typename F>
|
||||
constexpr void unroll_loop_item(std::integer_sequence<T, indexes...>, F &&f) {
|
||||
constexpr void unroll_loop_item(std::integer_sequence<T, indexes...>, F&& f) {
|
||||
(f(std::integral_constant<T, indexes>{}), ...);
|
||||
}
|
||||
}; // namespace
|
||||
}; // namespace
|
||||
|
||||
template <typename T, T count, typename F,
|
||||
typename = std::enable_if_t<std::is_invocable_v<F, T>>>
|
||||
constexpr void unroll_loop(F &&f) {
|
||||
constexpr void unroll_loop(F&& f) {
|
||||
unroll_loop_item(std::make_integer_sequence<T, count>{}, std::forward<F>(f));
|
||||
}
|
||||
|
||||
template <typename T> struct Vec {
|
||||
template <typename T>
|
||||
struct Vec {
|
||||
constexpr static int get_elem_num() { return T::VEC_ELEM_NUM; }
|
||||
};
|
||||
|
||||
@ -68,12 +70,14 @@ struct BF16Vec8 : public Vec<BF16Vec8> {
|
||||
|
||||
__vector signed short reg;
|
||||
|
||||
explicit BF16Vec8(const void *ptr)
|
||||
: reg((__vector signed short)vec_xl(0, (__vector signed short *)ptr)) {}
|
||||
explicit BF16Vec8(const void* ptr)
|
||||
: reg((__vector signed short)vec_xl(0, (__vector signed short*)ptr)) {}
|
||||
|
||||
explicit BF16Vec8(const FP32Vec8 &);
|
||||
explicit BF16Vec8(const FP32Vec8&);
|
||||
|
||||
void save(void *ptr) const { *reinterpret_cast<__vector signed short *>(ptr) = reg; }
|
||||
void save(void* ptr) const {
|
||||
*reinterpret_cast<__vector signed short*>(ptr) = reg;
|
||||
}
|
||||
};
|
||||
|
||||
struct BF16Vec16 : public Vec<BF16Vec16> {
|
||||
@ -81,18 +85,18 @@ struct BF16Vec16 : public Vec<BF16Vec16> {
|
||||
|
||||
ss16x8x2_t reg;
|
||||
|
||||
explicit BF16Vec16(const void *ptr) {
|
||||
explicit BF16Vec16(const void* ptr) {
|
||||
// Load 256 bits in two parts
|
||||
reg.val[0] = (__vector signed short)vec_xl(0, (signed short *)ptr);
|
||||
reg.val[1] = (__vector signed short)vec_xl(16, (signed short *)ptr);
|
||||
reg.val[0] = (__vector signed short)vec_xl(0, (signed short*)ptr);
|
||||
reg.val[1] = (__vector signed short)vec_xl(16, (signed short*)ptr);
|
||||
}
|
||||
|
||||
explicit BF16Vec16(const FP32Vec16 &);
|
||||
explicit BF16Vec16(const FP32Vec16&);
|
||||
|
||||
void save(void *ptr) const {
|
||||
void save(void* ptr) const {
|
||||
// Save 256 bits in two parts
|
||||
vec_xst(reg.val[0], 0, (signed short *)ptr);
|
||||
vec_xst(reg.val[1], 16, (signed short *)ptr);
|
||||
vec_xst(reg.val[0], 0, (signed short*)ptr);
|
||||
vec_xst(reg.val[1], 16, (signed short*)ptr);
|
||||
}
|
||||
};
|
||||
|
||||
@ -102,19 +106,15 @@ struct BF16Vec32 : public Vec<BF16Vec32> {
|
||||
constexpr static int VEC_ELEM_NUM = 32;
|
||||
|
||||
ss16x8x4_t reg;
|
||||
explicit BF16Vec32(const void *ptr)
|
||||
: reg(*reinterpret_cast<const ss16x8x4_t *>(ptr)) {}
|
||||
explicit BF16Vec32(const void* ptr)
|
||||
: reg(*reinterpret_cast<const ss16x8x4_t*>(ptr)) {}
|
||||
|
||||
explicit BF16Vec32(ss16x8x4_t data) : reg(data) {}
|
||||
|
||||
explicit BF16Vec32(const BF16Vec8 &vec8_data) : reg({
|
||||
vec8_data.reg,
|
||||
vec8_data.reg,
|
||||
vec8_data.reg,
|
||||
vec8_data.reg
|
||||
}) {}
|
||||
explicit BF16Vec32(const BF16Vec8& vec8_data)
|
||||
: reg({vec8_data.reg, vec8_data.reg, vec8_data.reg, vec8_data.reg}) {}
|
||||
|
||||
void save(void *ptr) const { *reinterpret_cast<ss16x8x4_t *>(ptr) = reg; }
|
||||
void save(void* ptr) const { *reinterpret_cast<ss16x8x4_t*>(ptr) = reg; }
|
||||
};
|
||||
|
||||
struct FP32Vec4 : public Vec<FP32Vec4> {
|
||||
@ -130,11 +130,11 @@ struct FP32Vec4 : public Vec<FP32Vec4> {
|
||||
|
||||
explicit FP32Vec4() : reg(vec_splats(0.0f)) {}
|
||||
|
||||
explicit FP32Vec4(const float *ptr) : reg(vec_xl(0, ptr)) {}
|
||||
explicit FP32Vec4(const float* ptr) : reg(vec_xl(0, ptr)) {}
|
||||
|
||||
explicit FP32Vec4(__vector float data) : reg(data) {}
|
||||
|
||||
explicit FP32Vec4(const FP32Vec4 &data) : reg(data.reg) {}
|
||||
explicit FP32Vec4(const FP32Vec4& data) : reg(data.reg) {}
|
||||
};
|
||||
|
||||
struct FP32Vec8 : public Vec<FP32Vec8> {
|
||||
@ -156,19 +156,19 @@ struct FP32Vec8 : public Vec<FP32Vec8> {
|
||||
reg.val[1] = vec_splats(0.0f);
|
||||
}
|
||||
|
||||
explicit FP32Vec8(const float *ptr) {
|
||||
explicit FP32Vec8(const float* ptr) {
|
||||
reg.val[0] = vec_xl(0, ptr);
|
||||
reg.val[1] = vec_xl(16, ptr);
|
||||
}
|
||||
|
||||
explicit FP32Vec8(f32x4x2_t data) : reg(data) {}
|
||||
|
||||
explicit FP32Vec8(const FP32Vec8 &data) {
|
||||
explicit FP32Vec8(const FP32Vec8& data) {
|
||||
reg.val[0] = data.reg.val[0];
|
||||
reg.val[1] = data.reg.val[1];
|
||||
}
|
||||
|
||||
explicit FP32Vec8(const BF16Vec8 &v) {
|
||||
explicit FP32Vec8(const BF16Vec8& v) {
|
||||
reg.val[0] = (__vector float)vec_mergeh(zero, v.reg);
|
||||
reg.val[1] = (__vector float)vec_mergel(zero, v.reg);
|
||||
}
|
||||
@ -177,7 +177,8 @@ struct FP32Vec8 : public Vec<FP32Vec8> {
|
||||
AliasReg ar;
|
||||
ar.reg = reg;
|
||||
float result = 0;
|
||||
unroll_loop<int, VEC_ELEM_NUM>([&result, &ar](int i) { result += ar.values[i]; });
|
||||
unroll_loop<int, VEC_ELEM_NUM>(
|
||||
[&result, &ar](int i) { result += ar.values[i]; });
|
||||
|
||||
return result;
|
||||
}
|
||||
@ -230,23 +231,27 @@ struct FP32Vec8 : public Vec<FP32Vec8> {
|
||||
return FP32Vec8(f32x4x2_t({ret.val[0], ret.val[1]}));
|
||||
}
|
||||
|
||||
FP32Vec8 operator*(const FP32Vec8 &b) const {
|
||||
return FP32Vec8({vec_mul(reg.val[0], b.reg.val[0]), vec_mul(reg.val[1], b.reg.val[1])});
|
||||
FP32Vec8 operator*(const FP32Vec8& b) const {
|
||||
return FP32Vec8(
|
||||
{vec_mul(reg.val[0], b.reg.val[0]), vec_mul(reg.val[1], b.reg.val[1])});
|
||||
}
|
||||
|
||||
FP32Vec8 operator+(const FP32Vec8 &b) const {
|
||||
return FP32Vec8({vec_add(reg.val[0], b.reg.val[0]), vec_add(reg.val[1], b.reg.val[1])});
|
||||
FP32Vec8 operator+(const FP32Vec8& b) const {
|
||||
return FP32Vec8(
|
||||
{vec_add(reg.val[0], b.reg.val[0]), vec_add(reg.val[1], b.reg.val[1])});
|
||||
}
|
||||
|
||||
FP32Vec8 operator-(const FP32Vec8 &b) const {
|
||||
return FP32Vec8({vec_sub(reg.val[0], b.reg.val[0]), vec_sub(reg.val[1], b.reg.val[1])});
|
||||
FP32Vec8 operator-(const FP32Vec8& b) const {
|
||||
return FP32Vec8(
|
||||
{vec_sub(reg.val[0], b.reg.val[0]), vec_sub(reg.val[1], b.reg.val[1])});
|
||||
}
|
||||
|
||||
FP32Vec8 operator/(const FP32Vec8 &b) const {
|
||||
return FP32Vec8({vec_div(reg.val[0], b.reg.val[0]), vec_div(reg.val[1], b.reg.val[1])});
|
||||
FP32Vec8 operator/(const FP32Vec8& b) const {
|
||||
return FP32Vec8(
|
||||
{vec_div(reg.val[0], b.reg.val[0]), vec_div(reg.val[1], b.reg.val[1])});
|
||||
}
|
||||
|
||||
void save(float *ptr) const {
|
||||
void save(float* ptr) const {
|
||||
vec_xst(reg.val[0], 0, ptr);
|
||||
vec_xst(reg.val[1], 16, ptr);
|
||||
}
|
||||
@ -275,7 +280,7 @@ struct FP32Vec16 : public Vec<FP32Vec16> {
|
||||
reg.val[3] = vec_splats(0.0f);
|
||||
}
|
||||
|
||||
explicit FP32Vec16(const float *ptr) {
|
||||
explicit FP32Vec16(const float* ptr) {
|
||||
reg.val[0] = vec_xl(0, ptr);
|
||||
reg.val[1] = vec_xl(16, ptr);
|
||||
reg.val[2] = vec_xl(32, ptr);
|
||||
@ -284,78 +289,76 @@ struct FP32Vec16 : public Vec<FP32Vec16> {
|
||||
|
||||
explicit FP32Vec16(f32x4x4_t data) : reg(data) {}
|
||||
|
||||
explicit FP32Vec16(const FP32Vec16 &data) {
|
||||
explicit FP32Vec16(const FP32Vec16& data) {
|
||||
reg.val[0] = data.reg.val[0];
|
||||
reg.val[1] = data.reg.val[1];
|
||||
reg.val[2] = data.reg.val[2];
|
||||
reg.val[3] = data.reg.val[3];
|
||||
}
|
||||
|
||||
explicit FP32Vec16(const FP32Vec4 &data) {
|
||||
explicit FP32Vec16(const FP32Vec4& data) {
|
||||
reg.val[0] = data.reg;
|
||||
reg.val[1] = data.reg;
|
||||
reg.val[2] = data.reg;
|
||||
reg.val[3] = data.reg;
|
||||
}
|
||||
|
||||
explicit FP32Vec16(const FP32Vec8 &data) {
|
||||
explicit FP32Vec16(const FP32Vec8& data) {
|
||||
reg.val[0] = data.reg.val[0];
|
||||
reg.val[1] = data.reg.val[1];
|
||||
reg.val[2] = data.reg.val[0];
|
||||
reg.val[3] = data.reg.val[1];
|
||||
}
|
||||
|
||||
explicit FP32Vec16(const BF16Vec16 &v) {
|
||||
explicit FP32Vec16(const BF16Vec16& v) {
|
||||
reg.val[0] = (__vector float)vec_mergeh(zero, v.reg.val[0]);
|
||||
reg.val[1] = (__vector float)vec_mergel(zero, v.reg.val[0]);
|
||||
reg.val[2] = (__vector float)vec_mergeh(zero, v.reg.val[1]);
|
||||
reg.val[3] = (__vector float)vec_mergel(zero, v.reg.val[1]);
|
||||
}
|
||||
|
||||
explicit FP32Vec16(const BF16Vec8 &v) : FP32Vec16(FP32Vec8(v)) {}
|
||||
explicit FP32Vec16(const BF16Vec8& v) : FP32Vec16(FP32Vec8(v)) {}
|
||||
|
||||
FP32Vec16 operator*(const FP32Vec16 &b) const {
|
||||
return FP32Vec16(f32x4x4_t({
|
||||
vec_mul(reg.val[0], b.reg.val[0]),
|
||||
vec_mul(reg.val[1], b.reg.val[1]),
|
||||
vec_mul(reg.val[2], b.reg.val[2]),
|
||||
vec_mul(reg.val[3], b.reg.val[3])}));
|
||||
FP32Vec16 operator*(const FP32Vec16& b) const {
|
||||
return FP32Vec16(f32x4x4_t({vec_mul(reg.val[0], b.reg.val[0]),
|
||||
vec_mul(reg.val[1], b.reg.val[1]),
|
||||
vec_mul(reg.val[2], b.reg.val[2]),
|
||||
vec_mul(reg.val[3], b.reg.val[3])}));
|
||||
}
|
||||
|
||||
FP32Vec16 operator+(const FP32Vec16 &b) const {
|
||||
return FP32Vec16(f32x4x4_t({
|
||||
vec_add(reg.val[0], b.reg.val[0]),
|
||||
vec_add(reg.val[1], b.reg.val[1]),
|
||||
vec_add(reg.val[2], b.reg.val[2]),
|
||||
vec_add(reg.val[3], b.reg.val[3])}));
|
||||
FP32Vec16 operator+(const FP32Vec16& b) const {
|
||||
return FP32Vec16(f32x4x4_t({vec_add(reg.val[0], b.reg.val[0]),
|
||||
vec_add(reg.val[1], b.reg.val[1]),
|
||||
vec_add(reg.val[2], b.reg.val[2]),
|
||||
vec_add(reg.val[3], b.reg.val[3])}));
|
||||
}
|
||||
|
||||
FP32Vec16 operator-(const FP32Vec16 &b) const {
|
||||
return FP32Vec16(f32x4x4_t({
|
||||
vec_sub(reg.val[0], b.reg.val[0]),
|
||||
vec_sub(reg.val[1], b.reg.val[1]),
|
||||
vec_sub(reg.val[2], b.reg.val[2]),
|
||||
vec_sub(reg.val[3], b.reg.val[3])}));
|
||||
FP32Vec16 operator-(const FP32Vec16& b) const {
|
||||
return FP32Vec16(f32x4x4_t({vec_sub(reg.val[0], b.reg.val[0]),
|
||||
vec_sub(reg.val[1], b.reg.val[1]),
|
||||
vec_sub(reg.val[2], b.reg.val[2]),
|
||||
vec_sub(reg.val[3], b.reg.val[3])}));
|
||||
}
|
||||
|
||||
FP32Vec16 operator/(const FP32Vec16 &b) const {
|
||||
return FP32Vec16(f32x4x4_t({
|
||||
vec_div(reg.val[0], b.reg.val[0]),
|
||||
vec_div(reg.val[1], b.reg.val[1]),
|
||||
vec_div(reg.val[2], b.reg.val[2]),
|
||||
vec_div(reg.val[3], b.reg.val[3])}));
|
||||
FP32Vec16 operator/(const FP32Vec16& b) const {
|
||||
return FP32Vec16(f32x4x4_t({vec_div(reg.val[0], b.reg.val[0]),
|
||||
vec_div(reg.val[1], b.reg.val[1]),
|
||||
vec_div(reg.val[2], b.reg.val[2]),
|
||||
vec_div(reg.val[3], b.reg.val[3])}));
|
||||
}
|
||||
|
||||
float reduce_sum() const {
|
||||
AliasReg ar;
|
||||
ar.reg = reg;
|
||||
float result = 0;
|
||||
unroll_loop<int, VEC_ELEM_NUM>([&result, &ar](int i) { result += ar.values[i]; });
|
||||
unroll_loop<int, VEC_ELEM_NUM>(
|
||||
[&result, &ar](int i) { result += ar.values[i]; });
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
template <int group_size> float reduce_sub_sum(int idx) {
|
||||
template <int group_size>
|
||||
float reduce_sub_sum(int idx) {
|
||||
static_assert(VEC_ELEM_NUM % group_size == 0);
|
||||
|
||||
AliasReg ar;
|
||||
@ -368,7 +371,7 @@ struct FP32Vec16 : public Vec<FP32Vec16> {
|
||||
return result;
|
||||
}
|
||||
|
||||
void save(float *ptr) const {
|
||||
void save(float* ptr) const {
|
||||
vec_xst(reg.val[0], 0, ptr);
|
||||
vec_xst(reg.val[1], 16, ptr);
|
||||
vec_xst(reg.val[2], 32, ptr);
|
||||
@ -376,43 +379,62 @@ struct FP32Vec16 : public Vec<FP32Vec16> {
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T> struct VecType { using vec_type = void; };
|
||||
template <typename T>
|
||||
struct VecType {
|
||||
using vec_type = void;
|
||||
};
|
||||
|
||||
template <typename T> using vec_t = typename VecType<T>::vec_type;
|
||||
template <typename T>
|
||||
using vec_t = typename VecType<T>::vec_type;
|
||||
|
||||
template <> struct VecType<float> { using vec_type = FP32Vec8; };
|
||||
template <>
|
||||
struct VecType<float> {
|
||||
using vec_type = FP32Vec8;
|
||||
};
|
||||
|
||||
template <> struct VecType<c10::BFloat16> { using vec_type = BF16Vec8; };
|
||||
template <>
|
||||
struct VecType<c10::BFloat16> {
|
||||
using vec_type = BF16Vec8;
|
||||
};
|
||||
|
||||
template <typename T> void storeFP32(float v, T *ptr) { *ptr = v; }
|
||||
template <typename T>
|
||||
void storeFP32(float v, T* ptr) {
|
||||
*ptr = v;
|
||||
}
|
||||
|
||||
inline void fma(FP32Vec16 &acc, FP32Vec16 &a, FP32Vec16 &b) {
|
||||
inline void fma(FP32Vec16& acc, FP32Vec16& a, FP32Vec16& b) {
|
||||
acc = acc + a * b;
|
||||
}
|
||||
|
||||
template <> inline void storeFP32<c10::BFloat16>(float v, c10::BFloat16 *ptr) {
|
||||
c10::BFloat16 __attribute__((__may_alias__)) *v_ptr =
|
||||
reinterpret_cast<c10::BFloat16 *>(&v);
|
||||
template <>
|
||||
inline void storeFP32<c10::BFloat16>(float v, c10::BFloat16* ptr) {
|
||||
c10::BFloat16 __attribute__((__may_alias__))* v_ptr =
|
||||
reinterpret_cast<c10::BFloat16*>(&v);
|
||||
*ptr = *(v_ptr + 1);
|
||||
}
|
||||
|
||||
#ifndef __VEC_CLASS_FP_NAN
|
||||
#define __VEC_CLASS_FP_NAN (1 << 6)
|
||||
#define __VEC_CLASS_FP_NAN (1 << 6)
|
||||
#endif
|
||||
|
||||
const static __vector unsigned char omask = { 0, 1, 4, 5, 8, 9, 12, 13, 16, 17, 20, 21, 24, 25, 28, 29 };
|
||||
const static __vector unsigned char omask = {0, 1, 4, 5, 8, 9, 12, 13,
|
||||
16, 17, 20, 21, 24, 25, 28, 29};
|
||||
#ifndef _ARCH_PWR10
|
||||
const static __vector unsigned int bias = { 0x00007fff, 0x00007fff, 0x00007fff, 0x00007fff };
|
||||
const static __vector unsigned int nan = { 0x7fc00000, 0x7fc00000, 0x7fc00000, 0x7fc00000 };
|
||||
const static __vector unsigned int sh16 = { 16, 16, 16, 16 };
|
||||
const static __vector unsigned int one = { 1, 1, 1, 1 };
|
||||
const static __vector unsigned int bias = {0x00007fff, 0x00007fff, 0x00007fff,
|
||||
0x00007fff};
|
||||
const static __vector unsigned int nan = {0x7fc00000, 0x7fc00000, 0x7fc00000,
|
||||
0x7fc00000};
|
||||
const static __vector unsigned int sh16 = {16, 16, 16, 16};
|
||||
const static __vector unsigned int one = {1, 1, 1, 1};
|
||||
#endif
|
||||
|
||||
inline BF16Vec8::BF16Vec8(const FP32Vec8 &v) {
|
||||
inline BF16Vec8::BF16Vec8(const FP32Vec8& v) {
|
||||
#ifdef _ARCH_PWR10
|
||||
__vector signed short ret[2];
|
||||
ret[0] = (__vector signed short)__builtin_vsx_xvcvspbf16((__vector unsigned char)v.reg.val[0]);
|
||||
ret[1] = (__vector signed short)__builtin_vsx_xvcvspbf16((__vector unsigned char)v.reg.val[1]);
|
||||
ret[0] = (__vector signed short)__builtin_vsx_xvcvspbf16(
|
||||
(__vector unsigned char)v.reg.val[0]);
|
||||
ret[1] = (__vector signed short)__builtin_vsx_xvcvspbf16(
|
||||
(__vector unsigned char)v.reg.val[1]);
|
||||
reg = vec_perm(ret[0], ret[1], omask);
|
||||
#elif defined(_ARCH_PWR9)
|
||||
__vector unsigned int inp0 = (__vector unsigned int)(v.reg.val[0]);
|
||||
@ -425,8 +447,10 @@ inline BF16Vec8::BF16Vec8(const FP32Vec8 &v) {
|
||||
__vector unsigned int rnd1 = vec_add(lsb1, bias);
|
||||
inp0 = vec_add(inp0, rnd0);
|
||||
inp1 = vec_add(inp1, rnd1);
|
||||
__vector __bool int sel0 = vec_test_data_class(v.reg.val[0], __VEC_CLASS_FP_NAN);
|
||||
__vector __bool int sel1 = vec_test_data_class(v.reg.val[1], __VEC_CLASS_FP_NAN);
|
||||
__vector __bool int sel0 =
|
||||
vec_test_data_class(v.reg.val[0], __VEC_CLASS_FP_NAN);
|
||||
__vector __bool int sel1 =
|
||||
vec_test_data_class(v.reg.val[1], __VEC_CLASS_FP_NAN);
|
||||
inp0 = vec_sel(inp0, nan, sel0);
|
||||
inp1 = vec_sel(inp1, nan, sel1);
|
||||
inp0 = vec_sr(inp0, sh16);
|
||||
@ -435,13 +459,17 @@ inline BF16Vec8::BF16Vec8(const FP32Vec8 &v) {
|
||||
#endif
|
||||
}
|
||||
|
||||
inline BF16Vec16::BF16Vec16(const FP32Vec16 &v) {
|
||||
inline BF16Vec16::BF16Vec16(const FP32Vec16& v) {
|
||||
#ifdef _ARCH_PWR10
|
||||
__vector signed short ret[4];
|
||||
ret[0] = (__vector signed short)__builtin_vsx_xvcvspbf16((__vector unsigned char)v.reg.val[0]);
|
||||
ret[1] = (__vector signed short)__builtin_vsx_xvcvspbf16((__vector unsigned char)v.reg.val[1]);
|
||||
ret[2] = (__vector signed short)__builtin_vsx_xvcvspbf16((__vector unsigned char)v.reg.val[2]);
|
||||
ret[3] = (__vector signed short)__builtin_vsx_xvcvspbf16((__vector unsigned char)v.reg.val[3]);
|
||||
ret[0] = (__vector signed short)__builtin_vsx_xvcvspbf16(
|
||||
(__vector unsigned char)v.reg.val[0]);
|
||||
ret[1] = (__vector signed short)__builtin_vsx_xvcvspbf16(
|
||||
(__vector unsigned char)v.reg.val[1]);
|
||||
ret[2] = (__vector signed short)__builtin_vsx_xvcvspbf16(
|
||||
(__vector unsigned char)v.reg.val[2]);
|
||||
ret[3] = (__vector signed short)__builtin_vsx_xvcvspbf16(
|
||||
(__vector unsigned char)v.reg.val[3]);
|
||||
reg.val[0] = vec_perm(ret[0], ret[1], omask);
|
||||
reg.val[1] = vec_perm(ret[2], ret[3], omask);
|
||||
#elif defined(_ARCH_PWR9)
|
||||
@ -465,10 +493,14 @@ inline BF16Vec16::BF16Vec16(const FP32Vec16 &v) {
|
||||
inp1 = vec_add(inp1, rnd1);
|
||||
inp2 = vec_add(inp2, rnd2);
|
||||
inp3 = vec_add(inp3, rnd3);
|
||||
__vector __bool int sel0 = vec_test_data_class(v.reg.val[0], __VEC_CLASS_FP_NAN);
|
||||
__vector __bool int sel1 = vec_test_data_class(v.reg.val[1], __VEC_CLASS_FP_NAN);
|
||||
__vector __bool int sel2 = vec_test_data_class(v.reg.val[2], __VEC_CLASS_FP_NAN);
|
||||
__vector __bool int sel3 = vec_test_data_class(v.reg.val[3], __VEC_CLASS_FP_NAN);
|
||||
__vector __bool int sel0 =
|
||||
vec_test_data_class(v.reg.val[0], __VEC_CLASS_FP_NAN);
|
||||
__vector __bool int sel1 =
|
||||
vec_test_data_class(v.reg.val[1], __VEC_CLASS_FP_NAN);
|
||||
__vector __bool int sel2 =
|
||||
vec_test_data_class(v.reg.val[2], __VEC_CLASS_FP_NAN);
|
||||
__vector __bool int sel3 =
|
||||
vec_test_data_class(v.reg.val[3], __VEC_CLASS_FP_NAN);
|
||||
inp0 = vec_sel(inp0, nan, sel0);
|
||||
inp1 = vec_sel(inp1, nan, sel1);
|
||||
inp2 = vec_sel(inp2, nan, sel2);
|
||||
@ -482,10 +514,10 @@ inline BF16Vec16::BF16Vec16(const FP32Vec16 &v) {
|
||||
#endif
|
||||
}
|
||||
|
||||
inline void prefetch(const void *addr) {
|
||||
inline void prefetch(const void* addr) {
|
||||
__asm__ __volatile__("dcbt 0, %0" : : "r"(addr) : "memory");
|
||||
}
|
||||
|
||||
}; // namespace vec_op
|
||||
}; // namespace vec_op
|
||||
|
||||
#endif
|
||||
|
@ -11,39 +11,40 @@ static_assert(false, "AVX2 must be supported for the current implementation.");
|
||||
|
||||
namespace vec_op {
|
||||
|
||||
#define VLLM_DISPATCH_CASE_FLOATING_TYPES(...) \
|
||||
AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \
|
||||
AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) \
|
||||
#define VLLM_DISPATCH_CASE_FLOATING_TYPES(...) \
|
||||
AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \
|
||||
AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) \
|
||||
AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__)
|
||||
|
||||
#define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \
|
||||
#define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \
|
||||
AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__))
|
||||
|
||||
#ifndef CPU_OP_GUARD
|
||||
#define CPU_KERNEL_GUARD_IN(NAME)
|
||||
#define CPU_KERNEL_GUARD_OUT(NAME)
|
||||
#define CPU_KERNEL_GUARD_IN(NAME)
|
||||
#define CPU_KERNEL_GUARD_OUT(NAME)
|
||||
#else
|
||||
#define CPU_KERNEL_GUARD_IN(NAME) \
|
||||
RECORD_FUNCTION(#NAME, c10::ArrayRef<c10::IValue>({}));
|
||||
#define CPU_KERNEL_GUARD_OUT(NAME)
|
||||
#define CPU_KERNEL_GUARD_IN(NAME) \
|
||||
RECORD_FUNCTION(#NAME, c10::ArrayRef<c10::IValue>({}));
|
||||
#define CPU_KERNEL_GUARD_OUT(NAME)
|
||||
#endif
|
||||
|
||||
#define FORCE_INLINE __attribute__((always_inline)) inline
|
||||
|
||||
namespace {
|
||||
template <typename T, T... indexes, typename F>
|
||||
constexpr void unroll_loop_item(std::integer_sequence<T, indexes...>, F &&f) {
|
||||
constexpr void unroll_loop_item(std::integer_sequence<T, indexes...>, F&& f) {
|
||||
(f(std::integral_constant<T, indexes>{}), ...);
|
||||
}
|
||||
}; // namespace
|
||||
}; // namespace
|
||||
|
||||
template <typename T, T count, typename F,
|
||||
typename = std::enable_if_t<std::is_invocable_v<F, T>>>
|
||||
constexpr void unroll_loop(F &&f) {
|
||||
constexpr void unroll_loop(F&& f) {
|
||||
unroll_loop_item(std::make_integer_sequence<T, count>{}, std::forward<F>(f));
|
||||
}
|
||||
|
||||
template <typename T> struct Vec {
|
||||
template <typename T>
|
||||
struct Vec {
|
||||
constexpr static int get_elem_num() { return T::VEC_ELEM_NUM; }
|
||||
};
|
||||
|
||||
@ -55,12 +56,12 @@ struct FP16Vec8 : public Vec<FP16Vec8> {
|
||||
|
||||
__m128i reg;
|
||||
|
||||
explicit FP16Vec8(const void *ptr)
|
||||
: reg((__m128i)_mm_loadu_si128((__m128i *)ptr)) {}
|
||||
explicit FP16Vec8(const void* ptr)
|
||||
: reg((__m128i)_mm_loadu_si128((__m128i*)ptr)) {}
|
||||
|
||||
explicit FP16Vec8(const FP32Vec8 &);
|
||||
explicit FP16Vec8(const FP32Vec8&);
|
||||
|
||||
void save(void *ptr) const { *reinterpret_cast<__m128i *>(ptr) = reg; }
|
||||
void save(void* ptr) const { *reinterpret_cast<__m128i*>(ptr) = reg; }
|
||||
};
|
||||
|
||||
struct FP16Vec16 : public Vec<FP16Vec16> {
|
||||
@ -68,12 +69,12 @@ struct FP16Vec16 : public Vec<FP16Vec16> {
|
||||
|
||||
__m256i reg;
|
||||
|
||||
explicit FP16Vec16(const void *ptr)
|
||||
: reg((__m256i)_mm256_loadu_si256((__m256i *)ptr)) {}
|
||||
explicit FP16Vec16(const void* ptr)
|
||||
: reg((__m256i)_mm256_loadu_si256((__m256i*)ptr)) {}
|
||||
|
||||
explicit FP16Vec16(const FP32Vec16 &);
|
||||
explicit FP16Vec16(const FP32Vec16&);
|
||||
|
||||
void save(void *ptr) const { *reinterpret_cast<__m256i *>(ptr) = reg; }
|
||||
void save(void* ptr) const { *reinterpret_cast<__m256i*>(ptr) = reg; }
|
||||
|
||||
void save(void* ptr, const int elem_num) const {
|
||||
constexpr uint32_t M = 0xFFFFFFFF;
|
||||
@ -87,12 +88,12 @@ struct BF16Vec8 : public Vec<BF16Vec8> {
|
||||
|
||||
__m128i reg;
|
||||
|
||||
explicit BF16Vec8(const void *ptr)
|
||||
: reg((__m128i)_mm_loadu_si128((__m128i *)ptr)) {}
|
||||
explicit BF16Vec8(const void* ptr)
|
||||
: reg((__m128i)_mm_loadu_si128((__m128i*)ptr)) {}
|
||||
|
||||
explicit BF16Vec8(const FP32Vec8 &);
|
||||
explicit BF16Vec8(const FP32Vec8&);
|
||||
|
||||
void save(void *ptr) const { *reinterpret_cast<__m128i *>(ptr) = reg; }
|
||||
void save(void* ptr) const { *reinterpret_cast<__m128i*>(ptr) = reg; }
|
||||
};
|
||||
|
||||
struct BF16Vec16 : public Vec<BF16Vec16> {
|
||||
@ -100,12 +101,12 @@ struct BF16Vec16 : public Vec<BF16Vec16> {
|
||||
|
||||
__m256i reg;
|
||||
|
||||
explicit BF16Vec16(const void *ptr)
|
||||
: reg((__m256i)_mm256_loadu_si256((__m256i *)ptr)) {}
|
||||
explicit BF16Vec16(const void* ptr)
|
||||
: reg((__m256i)_mm256_loadu_si256((__m256i*)ptr)) {}
|
||||
|
||||
explicit BF16Vec16(const FP32Vec16 &);
|
||||
explicit BF16Vec16(const FP32Vec16&);
|
||||
|
||||
void save(void *ptr) const { *reinterpret_cast<__m256i *>(ptr) = reg; }
|
||||
void save(void* ptr) const { *reinterpret_cast<__m256i*>(ptr) = reg; }
|
||||
|
||||
void save(void* ptr, const int elem_num) const {
|
||||
constexpr uint32_t M = 0xFFFFFFFF;
|
||||
@ -120,11 +121,11 @@ struct BF16Vec32 : public Vec<BF16Vec32> {
|
||||
|
||||
__m512i reg;
|
||||
|
||||
explicit BF16Vec32(const void *ptr) : reg((__m512i)_mm512_loadu_si512(ptr)) {}
|
||||
explicit BF16Vec32(const void* ptr) : reg((__m512i)_mm512_loadu_si512(ptr)) {}
|
||||
|
||||
explicit BF16Vec32(__m512i data) : reg(data) {}
|
||||
|
||||
explicit BF16Vec32(BF16Vec8 &vec8_data)
|
||||
explicit BF16Vec32(BF16Vec8& vec8_data)
|
||||
: reg((__m512i)_mm512_inserti32x4(
|
||||
_mm512_inserti32x4(_mm512_inserti32x4(_mm512_castsi128_si512(
|
||||
(__m128i)vec8_data.reg),
|
||||
@ -132,7 +133,7 @@ struct BF16Vec32 : public Vec<BF16Vec32> {
|
||||
(__m128i)vec8_data.reg, 2),
|
||||
(__m128i)vec8_data.reg, 3)) {}
|
||||
|
||||
void save(void *ptr) const { *reinterpret_cast<__m512i *>(ptr) = reg; }
|
||||
void save(void* ptr) const { *reinterpret_cast<__m512i*>(ptr) = reg; }
|
||||
};
|
||||
#else
|
||||
struct BF16Vec32 : public Vec<BF16Vec32> {
|
||||
@ -141,24 +142,24 @@ struct BF16Vec32 : public Vec<BF16Vec32> {
|
||||
__m256i reg_low;
|
||||
__m256i reg_high;
|
||||
|
||||
explicit BF16Vec32(const void *ptr)
|
||||
: reg_low(_mm256_loadu_si256((__m256i const *)ptr)),
|
||||
reg_high(_mm256_loadu_si256((__m256i const *)ptr + 1)) {}
|
||||
explicit BF16Vec32(const void* ptr)
|
||||
: reg_low(_mm256_loadu_si256((__m256i const*)ptr)),
|
||||
reg_high(_mm256_loadu_si256((__m256i const*)ptr + 1)) {}
|
||||
|
||||
explicit BF16Vec32(__m256i low, __m256i high) : reg_low(low),
|
||||
reg_high(high) {}
|
||||
explicit BF16Vec32(__m256i low, __m256i high)
|
||||
: reg_low(low), reg_high(high) {}
|
||||
|
||||
explicit BF16Vec32(BF16Vec8 &vec8_data)
|
||||
explicit BF16Vec32(BF16Vec8& vec8_data)
|
||||
: reg_low((__m256i)_mm256_inserti32x4(
|
||||
_mm256_castsi128_si256((__m128i)vec8_data.reg),
|
||||
(__m128i)vec8_data.reg, 1)),
|
||||
_mm256_castsi128_si256((__m128i)vec8_data.reg),
|
||||
(__m128i)vec8_data.reg, 1)),
|
||||
reg_high((__m256i)_mm256_inserti32x4(
|
||||
_mm256_castsi128_si256((__m128i)vec8_data.reg),
|
||||
(__m128i)vec8_data.reg, 1)) {}
|
||||
_mm256_castsi128_si256((__m128i)vec8_data.reg),
|
||||
(__m128i)vec8_data.reg, 1)) {}
|
||||
|
||||
void save(void *ptr) const {
|
||||
*reinterpret_cast<__m256i *>(ptr) = reg_low;
|
||||
*reinterpret_cast<__m256i *>((__m256i *)ptr + 1) = reg_high;
|
||||
void save(void* ptr) const {
|
||||
*reinterpret_cast<__m256i*>(ptr) = reg_low;
|
||||
*reinterpret_cast<__m256i*>((__m256i*)ptr + 1) = reg_high;
|
||||
}
|
||||
};
|
||||
#endif
|
||||
@ -176,11 +177,11 @@ struct FP32Vec4 : public Vec<FP32Vec4> {
|
||||
|
||||
explicit FP32Vec4() : reg(_mm_set1_ps(0.0)) {}
|
||||
|
||||
explicit FP32Vec4(const float *ptr) : reg(_mm_loadu_ps(ptr)) {}
|
||||
explicit FP32Vec4(const float* ptr) : reg(_mm_loadu_ps(ptr)) {}
|
||||
|
||||
explicit FP32Vec4(__m128 data) : reg(data) {}
|
||||
|
||||
explicit FP32Vec4(const FP32Vec4 &data) : reg(data.reg) {}
|
||||
explicit FP32Vec4(const FP32Vec4& data) : reg(data.reg) {}
|
||||
};
|
||||
|
||||
struct FP32Vec8 : public Vec<FP32Vec8> {
|
||||
@ -196,15 +197,15 @@ struct FP32Vec8 : public Vec<FP32Vec8> {
|
||||
|
||||
explicit FP32Vec8() : reg(_mm256_set1_ps(0.0)) {}
|
||||
|
||||
explicit FP32Vec8(const float *ptr) : reg(_mm256_loadu_ps(ptr)) {}
|
||||
explicit FP32Vec8(const float* ptr) : reg(_mm256_loadu_ps(ptr)) {}
|
||||
|
||||
explicit FP32Vec8(__m256 data) : reg(data) {}
|
||||
|
||||
explicit FP32Vec8(const FP32Vec8 &data) : reg(data.reg) {}
|
||||
explicit FP32Vec8(const FP32Vec8& data) : reg(data.reg) {}
|
||||
|
||||
explicit FP32Vec8(const FP16Vec8 &v) : reg(_mm256_cvtph_ps(v.reg)) {}
|
||||
explicit FP32Vec8(const FP16Vec8& v) : reg(_mm256_cvtph_ps(v.reg)) {}
|
||||
|
||||
explicit FP32Vec8(const BF16Vec8 &v)
|
||||
explicit FP32Vec8(const BF16Vec8& v)
|
||||
: reg(_mm256_castsi256_ps(
|
||||
_mm256_bslli_epi128(_mm256_cvtepu16_epi32(v.reg), 2))) {}
|
||||
|
||||
@ -212,7 +213,8 @@ struct FP32Vec8 : public Vec<FP32Vec8> {
|
||||
AliasReg ar;
|
||||
ar.reg = reg;
|
||||
float result = 0;
|
||||
unroll_loop<int, VEC_ELEM_NUM>([&result, &ar](int i) { result += ar.values[i]; });
|
||||
unroll_loop<int, VEC_ELEM_NUM>(
|
||||
[&result, &ar](int i) { result += ar.values[i]; });
|
||||
|
||||
return result;
|
||||
}
|
||||
@ -244,27 +246,27 @@ struct FP32Vec8 : public Vec<FP32Vec8> {
|
||||
erf(ar.values[1]), erf(ar.values[0])));
|
||||
}
|
||||
|
||||
FP32Vec8 operator*(const FP32Vec8 &b) const {
|
||||
FP32Vec8 operator*(const FP32Vec8& b) const {
|
||||
return FP32Vec8(_mm256_mul_ps(reg, b.reg));
|
||||
}
|
||||
|
||||
FP32Vec8 operator+(const FP32Vec8 &b) const {
|
||||
FP32Vec8 operator+(const FP32Vec8& b) const {
|
||||
return FP32Vec8(_mm256_add_ps(reg, b.reg));
|
||||
}
|
||||
|
||||
FP32Vec8 operator-(const FP32Vec8 &b) const {
|
||||
FP32Vec8 operator-(const FP32Vec8& b) const {
|
||||
return FP32Vec8(_mm256_sub_ps(reg, b.reg));
|
||||
}
|
||||
|
||||
FP32Vec8 operator/(const FP32Vec8 &b) const {
|
||||
FP32Vec8 operator/(const FP32Vec8& b) const {
|
||||
return FP32Vec8(_mm256_div_ps(reg, b.reg));
|
||||
}
|
||||
|
||||
void save(float *ptr) const { _mm256_storeu_ps(ptr, reg); }
|
||||
void save(float* ptr) const { _mm256_storeu_ps(ptr, reg); }
|
||||
};
|
||||
|
||||
#ifdef __AVX512F__
|
||||
struct INT32Vec16: public Vec<INT32Vec16> {
|
||||
struct INT32Vec16 : public Vec<INT32Vec16> {
|
||||
constexpr static int VEC_ELEM_NUM = 16;
|
||||
union AliasReg {
|
||||
__m512i reg;
|
||||
@ -273,11 +275,10 @@ struct INT32Vec16: public Vec<INT32Vec16> {
|
||||
|
||||
__m512i reg;
|
||||
|
||||
explicit INT32Vec16(const void* data_ptr) : reg(_mm512_loadu_epi32(data_ptr)) {}
|
||||
explicit INT32Vec16(const void* data_ptr)
|
||||
: reg(_mm512_loadu_epi32(data_ptr)) {}
|
||||
|
||||
void save(int32_t* ptr) const {
|
||||
_mm512_storeu_epi32(ptr, reg);
|
||||
}
|
||||
void save(int32_t* ptr) const { _mm512_storeu_epi32(ptr, reg); }
|
||||
|
||||
void save(int32_t* ptr, const int elem_num) const {
|
||||
constexpr uint32_t M = 0xFFFFFFFF;
|
||||
@ -301,11 +302,11 @@ struct FP32Vec16 : public Vec<FP32Vec16> {
|
||||
|
||||
explicit FP32Vec16() : reg(_mm512_set1_ps(0.0)) {}
|
||||
|
||||
explicit FP32Vec16(const float *ptr) : reg(_mm512_loadu_ps(ptr)) {}
|
||||
explicit FP32Vec16(const float* ptr) : reg(_mm512_loadu_ps(ptr)) {}
|
||||
|
||||
explicit FP32Vec16(__m512 data) : reg(data) {}
|
||||
|
||||
explicit FP32Vec16(const FP32Vec4 &data)
|
||||
explicit FP32Vec16(const FP32Vec4& data)
|
||||
: reg((__m512)_mm512_inserti32x4(
|
||||
_mm512_inserti32x4(
|
||||
_mm512_inserti32x4(_mm512_castsi128_si512((__m128i)data.reg),
|
||||
@ -313,36 +314,37 @@ struct FP32Vec16 : public Vec<FP32Vec16> {
|
||||
(__m128i)data.reg, 2),
|
||||
(__m128i)data.reg, 3)) {}
|
||||
|
||||
explicit FP32Vec16(const FP32Vec8 &data)
|
||||
explicit FP32Vec16(const FP32Vec8& data)
|
||||
: reg((__m512)_mm512_inserti32x8(
|
||||
_mm512_castsi256_si512((__m256i)data.reg), (__m256i)data.reg, 1)) {}
|
||||
|
||||
explicit FP32Vec16(const BF16Vec16 &v)
|
||||
explicit FP32Vec16(const BF16Vec16& v)
|
||||
: reg(_mm512_castsi512_ps(
|
||||
_mm512_bslli_epi128(_mm512_cvtepu16_epi32(v.reg), 2))) {}
|
||||
|
||||
explicit FP32Vec16(const FP16Vec16 &v) : reg(_mm512_cvtph_ps(v.reg)) {}
|
||||
explicit FP32Vec16(const FP16Vec16& v) : reg(_mm512_cvtph_ps(v.reg)) {}
|
||||
|
||||
explicit FP32Vec16(const FP16Vec8 &v) : FP32Vec16(FP32Vec8(v)) {}
|
||||
explicit FP32Vec16(const FP16Vec8& v) : FP32Vec16(FP32Vec8(v)) {}
|
||||
|
||||
explicit FP32Vec16(const BF16Vec8 &v) : FP32Vec16(FP32Vec8(v)) {}
|
||||
explicit FP32Vec16(const BF16Vec8& v) : FP32Vec16(FP32Vec8(v)) {}
|
||||
|
||||
explicit FP32Vec16(const INT32Vec16 &v)
|
||||
: reg(_mm512_cvt_roundepi32_ps(v.reg, _MM_FROUND_TO_NEAREST_INT |_MM_FROUND_NO_EXC)) {}
|
||||
explicit FP32Vec16(const INT32Vec16& v)
|
||||
: reg(_mm512_cvt_roundepi32_ps(
|
||||
v.reg, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)) {}
|
||||
|
||||
FP32Vec16 operator*(const FP32Vec16 &b) const {
|
||||
FP32Vec16 operator*(const FP32Vec16& b) const {
|
||||
return FP32Vec16(_mm512_mul_ps(reg, b.reg));
|
||||
}
|
||||
|
||||
FP32Vec16 operator+(const FP32Vec16 &b) const {
|
||||
FP32Vec16 operator+(const FP32Vec16& b) const {
|
||||
return FP32Vec16(_mm512_add_ps(reg, b.reg));
|
||||
}
|
||||
|
||||
FP32Vec16 operator-(const FP32Vec16 &b) const {
|
||||
FP32Vec16 operator-(const FP32Vec16& b) const {
|
||||
return FP32Vec16(_mm512_sub_ps(reg, b.reg));
|
||||
}
|
||||
|
||||
FP32Vec16 operator/(const FP32Vec16 &b) const {
|
||||
FP32Vec16 operator/(const FP32Vec16& b) const {
|
||||
return FP32Vec16(_mm512_div_ps(reg, b.reg));
|
||||
}
|
||||
|
||||
@ -370,9 +372,7 @@ struct FP32Vec16 : public Vec<FP32Vec16> {
|
||||
return FP32Vec16(_mm512_mask_min_ps(reg, mask, reg, b.reg));
|
||||
}
|
||||
|
||||
FP32Vec16 abs() const {
|
||||
return FP32Vec16(_mm512_abs_ps(reg));
|
||||
}
|
||||
FP32Vec16 abs() const { return FP32Vec16(_mm512_abs_ps(reg)); }
|
||||
|
||||
float reduce_sum() const { return _mm512_reduce_add_ps(reg); }
|
||||
|
||||
@ -380,14 +380,15 @@ struct FP32Vec16 : public Vec<FP32Vec16> {
|
||||
|
||||
float reduce_min() const { return _mm512_reduce_min_ps(reg); }
|
||||
|
||||
template <int group_size> float reduce_sub_sum(int idx) {
|
||||
template <int group_size>
|
||||
float reduce_sub_sum(int idx) {
|
||||
static_assert(VEC_ELEM_NUM % group_size == 0);
|
||||
constexpr uint32_t base_mask = (0xFFFF >> (16 - group_size));
|
||||
__mmask16 mask = _cvtu32_mask16(base_mask << (idx * group_size));
|
||||
return _mm512_mask_reduce_add_ps(mask, reg);
|
||||
}
|
||||
|
||||
void save(float *ptr) const { _mm512_storeu_ps(ptr, reg); }
|
||||
void save(float* ptr) const { _mm512_storeu_ps(ptr, reg); }
|
||||
|
||||
void save(float* ptr, const int elem_num) const {
|
||||
constexpr uint32_t M = 0xFFFFFFFF;
|
||||
@ -407,32 +408,30 @@ struct FP32Vec16 : public Vec<FP32Vec16> {
|
||||
__m256 reg_low;
|
||||
__m256 reg_high;
|
||||
|
||||
explicit FP32Vec16(float v) : reg_low(_mm256_set1_ps(v)),
|
||||
reg_high(_mm256_set1_ps(v)) {}
|
||||
explicit FP32Vec16(float v)
|
||||
: reg_low(_mm256_set1_ps(v)), reg_high(_mm256_set1_ps(v)) {}
|
||||
|
||||
explicit FP32Vec16() : reg_low(_mm256_set1_ps(0.0)),
|
||||
reg_high(_mm256_set1_ps(0.0)) {}
|
||||
explicit FP32Vec16()
|
||||
: reg_low(_mm256_set1_ps(0.0)), reg_high(_mm256_set1_ps(0.0)) {}
|
||||
|
||||
explicit FP32Vec16(const float *ptr) : reg_low(_mm256_loadu_ps(ptr)),
|
||||
reg_high(_mm256_loadu_ps(ptr + 8)) {}
|
||||
explicit FP32Vec16(const float* ptr)
|
||||
: reg_low(_mm256_loadu_ps(ptr)), reg_high(_mm256_loadu_ps(ptr + 8)) {}
|
||||
|
||||
explicit FP32Vec16(__m256 low, __m256 high) : reg_low(low), reg_high(high) {}
|
||||
|
||||
explicit FP32Vec16(const FP32Vec16 &data) : reg_low(data.reg_low),
|
||||
reg_high(data.reg_high) {}
|
||||
explicit FP32Vec16(const FP32Vec16& data)
|
||||
: reg_low(data.reg_low), reg_high(data.reg_high) {}
|
||||
|
||||
explicit FP32Vec16(const FP32Vec4 &data)
|
||||
explicit FP32Vec16(const FP32Vec4& data)
|
||||
: reg_low((__m256)_mm256_inserti128_si256(
|
||||
_mm256_castsi128_si256((__m128i)data.reg),
|
||||
(__m128i)data.reg, 1)),
|
||||
_mm256_castsi128_si256((__m128i)data.reg), (__m128i)data.reg, 1)),
|
||||
reg_high((__m256)_mm256_inserti128_si256(
|
||||
_mm256_castsi128_si256((__m128i)data.reg),
|
||||
(__m128i)data.reg, 1)) {}
|
||||
_mm256_castsi128_si256((__m128i)data.reg), (__m128i)data.reg, 1)) {}
|
||||
|
||||
explicit FP32Vec16(const FP32Vec8 &data)
|
||||
explicit FP32Vec16(const FP32Vec8& data)
|
||||
: reg_low(data.reg), reg_high(data.reg) {}
|
||||
|
||||
explicit FP32Vec16(const FP16Vec16 &v) {
|
||||
explicit FP32Vec16(const FP16Vec16& v) {
|
||||
__m128i low = _mm256_extractf128_si256(v.reg, 0);
|
||||
__m128i high = _mm256_extractf128_si256(v.reg, 1);
|
||||
|
||||
@ -440,9 +439,9 @@ struct FP32Vec16 : public Vec<FP32Vec16> {
|
||||
reg_high = _mm256_cvtph_ps(high);
|
||||
}
|
||||
|
||||
explicit FP32Vec16(const FP16Vec8 &v) : FP32Vec16(FP32Vec8(v)) {}
|
||||
explicit FP32Vec16(const FP16Vec8& v) : FP32Vec16(FP32Vec8(v)) {}
|
||||
|
||||
explicit FP32Vec16(const BF16Vec16 &v) {
|
||||
explicit FP32Vec16(const BF16Vec16& v) {
|
||||
__m128i low = _mm256_extractf128_si256(v.reg, 0);
|
||||
__m128i high = _mm256_extractf128_si256(v.reg, 1);
|
||||
|
||||
@ -456,24 +455,24 @@ struct FP32Vec16 : public Vec<FP32Vec16> {
|
||||
reg_high = _mm256_castsi256_ps(v_high_shifted);
|
||||
}
|
||||
|
||||
explicit FP32Vec16(const BF16Vec8 &v) : FP32Vec16(FP32Vec8(v)) {}
|
||||
explicit FP32Vec16(const BF16Vec8& v) : FP32Vec16(FP32Vec8(v)) {}
|
||||
|
||||
FP32Vec16 operator*(const FP32Vec16 &b) const {
|
||||
FP32Vec16 operator*(const FP32Vec16& b) const {
|
||||
return FP32Vec16(_mm256_mul_ps(reg_low, b.reg_low),
|
||||
_mm256_mul_ps(reg_high, b.reg_high));
|
||||
}
|
||||
|
||||
FP32Vec16 operator+(const FP32Vec16 &b) const {
|
||||
FP32Vec16 operator+(const FP32Vec16& b) const {
|
||||
return FP32Vec16(_mm256_add_ps(reg_low, b.reg_low),
|
||||
_mm256_add_ps(reg_high, b.reg_high));
|
||||
}
|
||||
|
||||
FP32Vec16 operator-(const FP32Vec16 &b) const {
|
||||
FP32Vec16 operator-(const FP32Vec16& b) const {
|
||||
return FP32Vec16(_mm256_sub_ps(reg_low, b.reg_low),
|
||||
_mm256_sub_ps(reg_high, b.reg_high));
|
||||
}
|
||||
|
||||
FP32Vec16 operator/(const FP32Vec16 &b) const {
|
||||
FP32Vec16 operator/(const FP32Vec16& b) const {
|
||||
return FP32Vec16(_mm256_div_ps(reg_low, b.reg_low),
|
||||
_mm256_div_ps(reg_high, b.reg_high));
|
||||
}
|
||||
@ -484,7 +483,8 @@ struct FP32Vec16 : public Vec<FP32Vec16> {
|
||||
return low.reduce_sum() + high.reduce_sum();
|
||||
}
|
||||
|
||||
template <int group_size> float reduce_sub_sum(int idx) {
|
||||
template <int group_size>
|
||||
float reduce_sub_sum(int idx) {
|
||||
float sum = 0.0;
|
||||
static_assert(VEC_ELEM_NUM % group_size == 0);
|
||||
constexpr uint32_t base_mask = (0xFFFF >> (16 - group_size));
|
||||
@ -507,7 +507,7 @@ struct FP32Vec16 : public Vec<FP32Vec16> {
|
||||
return sum;
|
||||
}
|
||||
|
||||
void save(float *ptr) const {
|
||||
void save(float* ptr) const {
|
||||
_mm256_storeu_ps(ptr, reg_low);
|
||||
_mm256_storeu_ps(ptr + 8, reg_high);
|
||||
}
|
||||
@ -515,7 +515,7 @@ struct FP32Vec16 : public Vec<FP32Vec16> {
|
||||
#endif
|
||||
|
||||
#ifdef __AVX512F__
|
||||
struct INT8Vec16: public Vec<INT8Vec16> {
|
||||
struct INT8Vec16 : public Vec<INT8Vec16> {
|
||||
constexpr static int VEC_ELEM_NUM = 16;
|
||||
union AliasReg {
|
||||
__m128i reg;
|
||||
@ -524,13 +524,11 @@ struct INT8Vec16: public Vec<INT8Vec16> {
|
||||
|
||||
__m128i reg;
|
||||
|
||||
explicit INT8Vec16(const FP32Vec16& vec) : reg(
|
||||
_mm512_cvtepi32_epi8(_mm512_cvt_roundps_epi32(vec.reg, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC))
|
||||
) {}
|
||||
explicit INT8Vec16(const FP32Vec16& vec)
|
||||
: reg(_mm512_cvtepi32_epi8(_mm512_cvt_roundps_epi32(
|
||||
vec.reg, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC))) {}
|
||||
|
||||
void save(int8_t* ptr) const {
|
||||
_mm_storeu_epi8(ptr, reg);
|
||||
}
|
||||
void save(int8_t* ptr) const { _mm_storeu_epi8(ptr, reg); }
|
||||
|
||||
void save(int8_t* ptr, const int elem_num) const {
|
||||
constexpr uint32_t M = 0xFFFFFFFF;
|
||||
@ -540,71 +538,92 @@ struct INT8Vec16: public Vec<INT8Vec16> {
|
||||
};
|
||||
#endif
|
||||
|
||||
template <typename T> struct VecType { using vec_type = void; };
|
||||
template <typename T>
|
||||
struct VecType {
|
||||
using vec_type = void;
|
||||
};
|
||||
|
||||
template <typename T> using vec_t = typename VecType<T>::vec_type;
|
||||
template <typename T>
|
||||
using vec_t = typename VecType<T>::vec_type;
|
||||
|
||||
template <> struct VecType<float> { using vec_type = FP32Vec8; };
|
||||
template <>
|
||||
struct VecType<float> {
|
||||
using vec_type = FP32Vec8;
|
||||
};
|
||||
|
||||
template <> struct VecType<c10::Half> { using vec_type = FP16Vec8; };
|
||||
template <>
|
||||
struct VecType<c10::Half> {
|
||||
using vec_type = FP16Vec8;
|
||||
};
|
||||
|
||||
template <> struct VecType<c10::BFloat16> { using vec_type = BF16Vec8; };
|
||||
template <>
|
||||
struct VecType<c10::BFloat16> {
|
||||
using vec_type = BF16Vec8;
|
||||
};
|
||||
|
||||
template <typename T> void storeFP32(float v, T *ptr) { *ptr = v; }
|
||||
template <typename T>
|
||||
void storeFP32(float v, T* ptr) {
|
||||
*ptr = v;
|
||||
}
|
||||
|
||||
inline void fma(FP32Vec16 &acc, FP32Vec16 &a, FP32Vec16 &b) {
|
||||
inline void fma(FP32Vec16& acc, FP32Vec16& a, FP32Vec16& b) {
|
||||
acc = acc + a * b;
|
||||
}
|
||||
|
||||
template <> inline void storeFP32<c10::Half>(float v, c10::Half *ptr) {
|
||||
*reinterpret_cast<unsigned short *>(ptr) =
|
||||
template <>
|
||||
inline void storeFP32<c10::Half>(float v, c10::Half* ptr) {
|
||||
*reinterpret_cast<unsigned short*>(ptr) =
|
||||
_cvtss_sh(v, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC);
|
||||
}
|
||||
|
||||
inline FP16Vec8::FP16Vec8(const FP32Vec8 &v)
|
||||
inline FP16Vec8::FP16Vec8(const FP32Vec8& v)
|
||||
: reg(_mm256_cvtps_ph(v.reg,
|
||||
_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)) {}
|
||||
|
||||
#ifdef __AVX512F__
|
||||
inline FP16Vec16::FP16Vec16(const FP32Vec16 &v)
|
||||
inline FP16Vec16::FP16Vec16(const FP32Vec16& v)
|
||||
: reg(_mm512_cvtps_ph(v.reg,
|
||||
_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)) {}
|
||||
#else
|
||||
inline FP16Vec16::FP16Vec16(const FP32Vec16 &v)
|
||||
: reg(_mm256_insertf128_si256(_mm256_castsi128_si256(FP16Vec8(FP32Vec8(v.reg_low)).reg), FP16Vec8(FP32Vec8(v.reg_low)).reg, 1)) {}
|
||||
// Narrows an FP32Vec16 to sixteen FP16 lanes on the non-AVX512F path, where
// FP32Vec16 is stored as two 256-bit halves (reg_low / reg_high). Each half is
// converted to eight FP16 lanes via FP16Vec8(FP32Vec8) and the two 128-bit
// results are packed into a single __m256i: reg_low fills the lower 128 bits,
// reg_high the upper 128 bits.
inline FP16Vec16::FP16Vec16(const FP32Vec16& v)
    : reg(_mm256_insertf128_si256(
          _mm256_castsi128_si256(FP16Vec8(FP32Vec8(v.reg_low)).reg),
          // The upper lane must come from reg_high; taking reg_low here would
          // duplicate elements 0-7 into positions 8-15 and drop the high half.
          FP16Vec8(FP32Vec8(v.reg_high)).reg, 1)) {}
|
||||
#endif
|
||||
|
||||
#ifdef __AVX512BF16__
|
||||
template <> inline void storeFP32<c10::BFloat16>(float v, c10::BFloat16 *ptr) {
|
||||
*reinterpret_cast<__bfloat16 *>(ptr) = _mm_cvtness_sbh(v);
|
||||
template <>
|
||||
inline void storeFP32<c10::BFloat16>(float v, c10::BFloat16* ptr) {
|
||||
*reinterpret_cast<__bfloat16*>(ptr) = _mm_cvtness_sbh(v);
|
||||
}
|
||||
|
||||
inline BF16Vec8::BF16Vec8(const FP32Vec8 &v)
|
||||
inline BF16Vec8::BF16Vec8(const FP32Vec8& v)
|
||||
: reg((__m128i)_mm256_cvtneps_pbh(v.reg)) {}
|
||||
|
||||
inline BF16Vec16::BF16Vec16(const FP32Vec16 &v)
|
||||
inline BF16Vec16::BF16Vec16(const FP32Vec16& v)
|
||||
: reg((__m256i)_mm512_cvtneps_pbh(v.reg)) {}
|
||||
|
||||
inline void fma(FP32Vec16 &acc, BF16Vec32 &a, BF16Vec32 &b) {
|
||||
inline void fma(FP32Vec16& acc, BF16Vec32& a, BF16Vec32& b) {
|
||||
acc.reg = _mm512_dpbf16_ps(acc.reg, (__m512bh)a.reg, (__m512bh)b.reg);
|
||||
}
|
||||
#else
|
||||
template <> inline void storeFP32<c10::BFloat16>(float v, c10::BFloat16 *ptr) {
|
||||
c10::BFloat16 __attribute__((__may_alias__)) *v_ptr =
|
||||
reinterpret_cast<c10::BFloat16 *>(&v);
|
||||
template <>
|
||||
inline void storeFP32<c10::BFloat16>(float v, c10::BFloat16* ptr) {
|
||||
c10::BFloat16 __attribute__((__may_alias__))* v_ptr =
|
||||
reinterpret_cast<c10::BFloat16*>(&v);
|
||||
*ptr = *(v_ptr + 1);
|
||||
}
|
||||
|
||||
#ifdef __AVX512F__
|
||||
inline BF16Vec8::BF16Vec8(const FP32Vec8 &v)
|
||||
#ifdef __AVX512F__
|
||||
inline BF16Vec8::BF16Vec8(const FP32Vec8& v)
|
||||
: reg(_mm256_cvtepi32_epi16(
|
||||
_mm256_bsrli_epi128(_mm256_castps_si256(v.reg), 2))) {}
|
||||
|
||||
inline BF16Vec16::BF16Vec16(const FP32Vec16 &v)
|
||||
inline BF16Vec16::BF16Vec16(const FP32Vec16& v)
|
||||
: reg(_mm512_cvtepi32_epi16(
|
||||
_mm512_bsrli_epi128(_mm512_castps_si512(v.reg), 2))) {}
|
||||
#else
|
||||
namespace{
|
||||
#else
|
||||
namespace {
|
||||
__m128i FP32Vec8_to_BF16Vec8_avx2(__m256 a) {
|
||||
__m256i ai = _mm256_castps_si256(a);
|
||||
ai = _mm256_srli_epi32(ai, 16);
|
||||
@ -612,21 +631,21 @@ __m128i FP32Vec8_to_BF16Vec8_avx2(__m256 a) {
|
||||
ai = _mm256_permute4x64_epi64(ai, 0b00111001);
|
||||
return _mm256_extracti128_si256(ai, 0);
|
||||
}
|
||||
}
|
||||
} // namespace
|
||||
|
||||
inline BF16Vec8::BF16Vec8(const FP32Vec8 &v)
|
||||
inline BF16Vec8::BF16Vec8(const FP32Vec8& v)
|
||||
: reg(FP32Vec8_to_BF16Vec8_avx2(v.reg)) {}
|
||||
|
||||
inline BF16Vec16::BF16Vec16(const FP32Vec16 &v) {
|
||||
inline BF16Vec16::BF16Vec16(const FP32Vec16& v) {
|
||||
BF16Vec8 low = BF16Vec8(FP32Vec8(v.reg_low));
|
||||
BF16Vec8 high = BF16Vec8(FP32Vec8(v.reg_high));
|
||||
reg = _mm256_insertf128_si256(_mm256_castsi128_si256(low.reg), high.reg, 1);
|
||||
}
|
||||
#endif // __AVX512F__
|
||||
#endif // __AVX512BF16__
|
||||
#endif // __AVX512F__
|
||||
#endif // __AVX512BF16__
|
||||
|
||||
inline void prefetch(const void *addr) { _mm_prefetch(addr, _MM_HINT_T1); }
|
||||
inline void prefetch(const void* addr) { _mm_prefetch(addr, _MM_HINT_T1); }
|
||||
|
||||
}; // namespace vec_op
|
||||
}; // namespace vec_op
|
||||
|
||||
#endif
|
||||
|
@ -27,8 +27,7 @@
|
||||
inline int get_cuda_max_shared_memory_per_block_opt_in(int const device) {
|
||||
int max_shared_mem_per_block_opt_in = 0;
|
||||
cudaDeviceGetAttribute(&max_shared_mem_per_block_opt_in,
|
||||
cudaDevAttrMaxSharedMemoryPerBlockOptin,
|
||||
device);
|
||||
cudaDevAttrMaxSharedMemoryPerBlockOptin, device);
|
||||
return max_shared_mem_per_block_opt_in;
|
||||
}
|
||||
|
||||
|
@ -25,10 +25,12 @@ Check out the [building from source](#build-from-source) documentation for detai
|
||||
```bash
|
||||
pip install -r requirements-dev.txt
|
||||
|
||||
# linting and formatting
|
||||
bash format.sh
|
||||
# Static type checking
|
||||
mypy
|
||||
# Linting, formatting and static type checking
|
||||
pre-commit install
|
||||
|
||||
# You can manually run pre-commit with
|
||||
pre-commit run --all-files
|
||||
|
||||
# Unit tests
|
||||
pytest tests/
|
||||
```
|
||||
@ -88,7 +90,8 @@ If the PR spans more than one category, please include all relevant prefixes.
|
||||
The PR needs to meet the following code quality standards:
|
||||
|
||||
- We adhere to [Google Python style guide](https://google.github.io/styleguide/pyguide.html) and [Google C++ style guide](https://google.github.io/styleguide/cppguide.html).
|
||||
- Pass all linter checks. Please use <gh-file:format.sh> to format your code.
|
||||
- Pass all linter checks. Please use `pre-commit` to format your code. See
|
||||
<https://pre-commit.com/#usage> if `pre-commit` is new to you.
|
||||
- The code needs to be well-documented to ensure future contributors can easily
|
||||
understand the code.
|
||||
- Include sufficient tests to ensure the project stays correct and robust. This
|
||||
|
321
format.sh
321
format.sh
@ -1,321 +0,0 @@
|
||||
#!/usr/bin/env bash
|
||||
# YAPF formatter, adapted from ray and skypilot.
|
||||
#
|
||||
# Usage:
|
||||
# # Do work and commit your work.
|
||||
|
||||
# # Format files that differ from origin/main.
|
||||
# bash format.sh
|
||||
|
||||
# # Commit changed files with message 'Run yapf and ruff'
|
||||
#
|
||||
#
|
||||
# YAPF + Clang formatter (if installed). This script formats all changed files from the last mergebase.
|
||||
# You are encouraged to run this locally before pushing changes for review.
|
||||
|
||||
# Cause the script to exit if a single command fails
|
||||
set -eo pipefail
|
||||
|
||||
# this stops git rev-parse from failing if we run this from the .git directory
|
||||
builtin cd "$(dirname "${BASH_SOURCE:-$0}")"
|
||||
ROOT="$(git rev-parse --show-toplevel)"
|
||||
builtin cd "$ROOT" || exit 1
|
||||
|
||||
check_command() {
|
||||
if ! command -v "$1" &> /dev/null; then
|
||||
echo "❓❓$1 is not installed, please run \`pip install -r requirements-lint.txt\`"
|
||||
exit 1
|
||||
fi
|
||||
}
|
||||
|
||||
check_command yapf
|
||||
check_command ruff
|
||||
check_command mypy
|
||||
check_command codespell
|
||||
check_command isort
|
||||
check_command clang-format
|
||||
|
||||
YAPF_VERSION=$(yapf --version | awk '{print $2}')
|
||||
RUFF_VERSION=$(ruff --version | awk '{print $2}')
|
||||
MYPY_VERSION=$(mypy --version | awk '{print $2}')
|
||||
CODESPELL_VERSION=$(codespell --version)
|
||||
ISORT_VERSION=$(isort --vn)
|
||||
CLANGFORMAT_VERSION=$(clang-format --version | awk '{print $3}')
|
||||
PYMARKDOWNLNT_VERSION=$(pymarkdownlnt version | awk '{print $1}')
|
||||
|
||||
# # params: tool name, installed tool version (the required version is read from requirements-lint.txt)
|
||||
tool_version_check() {
|
||||
expected=$(grep "$1" requirements-lint.txt | cut -d'=' -f3)
|
||||
if [[ "$2" != "$expected" ]]; then
|
||||
echo "❓❓Wrong $1 version installed: $expected is required, not $2."
|
||||
exit 1
|
||||
fi
|
||||
}
|
||||
|
||||
tool_version_check "yapf" "$YAPF_VERSION"
|
||||
tool_version_check "ruff" "$RUFF_VERSION"
|
||||
tool_version_check "mypy" "$MYPY_VERSION"
|
||||
tool_version_check "isort" "$ISORT_VERSION"
|
||||
tool_version_check "codespell" "$CODESPELL_VERSION"
|
||||
tool_version_check "clang-format" "$CLANGFORMAT_VERSION"
|
||||
tool_version_check "pymarkdownlnt" "$PYMARKDOWNLNT_VERSION"
|
||||
|
||||
YAPF_FLAGS=(
|
||||
'--recursive'
|
||||
'--parallel'
|
||||
)
|
||||
|
||||
YAPF_EXCLUDES=(
|
||||
'--exclude' 'build/**'
|
||||
)
|
||||
|
||||
# Format specified files
|
||||
format() {
|
||||
yapf --in-place "${YAPF_FLAGS[@]}" "$@"
|
||||
}
|
||||
|
||||
# Format files that differ from main branch. Ignores dirs that are not slated
|
||||
# for autoformat yet.
|
||||
format_changed() {
|
||||
# The `if` guard ensures that the list of filenames is not empty, which
|
||||
# could cause yapf to receive 0 positional arguments, making it hang
|
||||
# waiting for STDIN.
|
||||
#
|
||||
# `diff-filter=ACM` and $MERGEBASE is to ensure we only format files that
|
||||
# exist on both branches.
|
||||
MERGEBASE="$(git merge-base origin/main HEAD)"
|
||||
|
||||
if ! git diff --diff-filter=ACM --quiet --exit-code "$MERGEBASE" -- '*.py' '*.pyi' &>/dev/null; then
|
||||
git diff --name-only --diff-filter=ACM "$MERGEBASE" -- '*.py' '*.pyi' | xargs -P 5 \
|
||||
yapf --in-place "${YAPF_EXCLUDES[@]}" "${YAPF_FLAGS[@]}"
|
||||
fi
|
||||
|
||||
}
|
||||
|
||||
# Format all files
|
||||
format_all() {
|
||||
yapf --in-place "${YAPF_FLAGS[@]}" "${YAPF_EXCLUDES[@]}" .
|
||||
}
|
||||
|
||||
## This flag formats individual files. --files *must* be the first command line
|
||||
## arg to use this option.
|
||||
if [[ "$1" == '--files' ]]; then
|
||||
format "${@:2}"
|
||||
# If `--all` is passed, then any further arguments are ignored and the
|
||||
# entire python directory is formatted.
|
||||
elif [[ "$1" == '--all' ]]; then
|
||||
format_all
|
||||
else
|
||||
# Format only the files that differ from the merge base with origin/main.
|
||||
format_changed
|
||||
fi
|
||||
echo 'vLLM yapf: Done'
|
||||
|
||||
# Run mypy
|
||||
echo 'vLLM mypy:'
|
||||
tools/mypy.sh
|
||||
echo 'vLLM mypy: Done'
|
||||
|
||||
|
||||
# If git diff returns a file that is in the skip list, the file may be checked anyway:
|
||||
# https://github.com/codespell-project/codespell/issues/1915
|
||||
# Avoiding the "./" prefix and using "/**" globs for directories appears to solve the problem
|
||||
CODESPELL_EXCLUDES=(
|
||||
'--skip' 'tests/prompts/**,./benchmarks/sonnet.txt,*tests/lora/data/**,build/**'
|
||||
)
|
||||
|
||||
# check spelling of specified files
|
||||
spell_check() {
|
||||
codespell "$@"
|
||||
}
|
||||
|
||||
spell_check_all(){
|
||||
codespell --toml pyproject.toml "${CODESPELL_EXCLUDES[@]}"
|
||||
}
|
||||
|
||||
# Spelling check of files that differ from main branch.
|
||||
spell_check_changed() {
|
||||
# The `if` guard ensures that the list of filenames is not empty, which
|
||||
# could cause codespell to receive 0 positional arguments, making it hang
|
||||
# waiting for STDIN.
|
||||
#
|
||||
# `diff-filter=ACM` and $MERGEBASE is to ensure we only lint files that
|
||||
# exist on both branches.
|
||||
MERGEBASE="$(git merge-base origin/main HEAD)"
|
||||
if ! git diff --diff-filter=ACM --quiet --exit-code "$MERGEBASE" -- '*.py' '*.pyi' &>/dev/null; then
|
||||
git diff --name-only --diff-filter=ACM "$MERGEBASE" -- '*.py' '*.pyi' | xargs \
|
||||
codespell "${CODESPELL_EXCLUDES[@]}"
|
||||
fi
|
||||
}
|
||||
|
||||
# Run Codespell
|
||||
## This flag runs spell check of individual files. --files *must* be the first command line
|
||||
## arg to use this option.
|
||||
if [[ "$1" == '--files' ]]; then
|
||||
spell_check "${@:2}"
|
||||
# If `--all` is passed, then any further arguments are ignored and the
|
||||
# entire python directory is linted.
|
||||
elif [[ "$1" == '--all' ]]; then
|
||||
spell_check_all
|
||||
else
|
||||
# Check spelling only of the files that differ from the merge base with origin/main.
|
||||
spell_check_changed
|
||||
fi
|
||||
echo 'vLLM codespell: Done'
|
||||
|
||||
|
||||
# Lint specified files
|
||||
lint() {
|
||||
ruff check "$@"
|
||||
}
|
||||
|
||||
# Lint files that differ from main branch. Ignores dirs that are not slated
|
||||
# for autolint yet.
|
||||
lint_changed() {
|
||||
# The `if` guard ensures that the list of filenames is not empty, which
|
||||
# could cause ruff to receive 0 positional arguments, making it hang
|
||||
# waiting for STDIN.
|
||||
#
|
||||
# `diff-filter=ACM` and $MERGEBASE is to ensure we only lint files that
|
||||
# exist on both branches.
|
||||
MERGEBASE="$(git merge-base origin/main HEAD)"
|
||||
|
||||
if ! git diff --diff-filter=ACM --quiet --exit-code "$MERGEBASE" -- '*.py' '*.pyi' &>/dev/null; then
|
||||
git diff --name-only --diff-filter=ACM "$MERGEBASE" -- '*.py' '*.pyi' | xargs \
|
||||
ruff check
|
||||
fi
|
||||
|
||||
}
|
||||
|
||||
# Run Ruff
|
||||
### This flag lints individual files. --files *must* be the first command line
|
||||
### arg to use this option.
|
||||
if [[ "$1" == '--files' ]]; then
|
||||
lint "${@:2}"
|
||||
# If `--all` is passed, then any further arguments are ignored and the
|
||||
# entire python directory is linted.
|
||||
elif [[ "$1" == '--all' ]]; then
|
||||
lint vllm tests
|
||||
else
|
||||
# Lint only the files that differ from the merge base with origin/main.
|
||||
lint_changed
|
||||
fi
|
||||
echo 'vLLM ruff: Done'
|
||||
|
||||
# Sort imports of specified files
|
||||
isort_check() {
|
||||
isort "$@"
|
||||
}
|
||||
|
||||
isort_check_all(){
|
||||
isort .
|
||||
}
|
||||
|
||||
# Sort imports of files that differ from main branch.
|
||||
isort_check_changed() {
|
||||
# The `if` guard ensures that the list of filenames is not empty, which
|
||||
# could cause isort to receive 0 positional arguments, making it hang
|
||||
# waiting for STDIN.
|
||||
#
|
||||
# `diff-filter=ACM` and $MERGEBASE is to ensure we only lint files that
|
||||
# exist on both branches.
|
||||
MERGEBASE="$(git merge-base origin/main HEAD)"
|
||||
|
||||
if ! git diff --diff-filter=ACM --quiet --exit-code "$MERGEBASE" -- '*.py' '*.pyi' &>/dev/null; then
|
||||
git diff --name-only --diff-filter=ACM "$MERGEBASE" -- '*.py' '*.pyi' | xargs \
|
||||
isort
|
||||
fi
|
||||
}
|
||||
|
||||
# Run Isort
|
||||
# This flag runs isort on individual files. --files *must* be the first command line
|
||||
# arg to use this option.
|
||||
if [[ "$1" == '--files' ]]; then
|
||||
isort_check "${@:2}"
|
||||
# If `--all` is passed, then any further arguments are ignored and the
|
||||
# entire python directory is linted.
|
||||
elif [[ "$1" == '--all' ]]; then
|
||||
isort_check_all
|
||||
else
|
||||
# Sort imports only of the files that differ from the merge base with origin/main.
|
||||
isort_check_changed
|
||||
fi
|
||||
echo 'vLLM isort: Done'
|
||||
|
||||
# Clang-format section
|
||||
# Exclude some files for formatting because they are vendored
|
||||
# NOTE: Keep up to date with .github/workflows/clang-format.yml
|
||||
CLANG_FORMAT_EXCLUDES=(
|
||||
'csrc/moe/topk_softmax_kernels.cu'
|
||||
'csrc/quantization/gguf/ggml-common.h'
|
||||
'csrc/quantization/gguf/dequantize.cuh'
|
||||
'csrc/quantization/gguf/vecdotq.cuh'
|
||||
'csrc/quantization/gguf/mmq.cuh'
|
||||
'csrc/quantization/gguf/mmvq.cuh'
|
||||
)
|
||||
|
||||
# Format specified files with clang-format
|
||||
clang_format() {
|
||||
clang-format -i "$@"
|
||||
}
|
||||
|
||||
# Format files that differ from main branch with clang-format.
|
||||
clang_format_changed() {
|
||||
# The `if` guard ensures that the list of filenames is not empty, which
|
||||
# could cause clang-format to receive 0 positional arguments, making it hang
|
||||
# waiting for STDIN.
|
||||
#
|
||||
# `diff-filter=ACM` and $MERGEBASE is to ensure we only format files that
|
||||
# exist on both branches.
|
||||
MERGEBASE="$(git merge-base origin/main HEAD)"
|
||||
|
||||
# Get the list of changed files, excluding the specified ones
|
||||
changed_files=$(git diff --name-only --diff-filter=ACM "$MERGEBASE" -- '*.h' '*.cpp' '*.cu' '*.cuh' | (grep -vFf <(printf "%s\n" "${CLANG_FORMAT_EXCLUDES[@]}") || echo -e))
|
||||
if [ -n "$changed_files" ]; then
|
||||
echo "$changed_files" | xargs -P 5 clang-format -i
|
||||
fi
|
||||
}
|
||||
|
||||
# Format all files with clang-format
|
||||
clang_format_all() {
|
||||
find csrc/ \( -name '*.h' -o -name '*.cpp' -o -name '*.cu' -o -name '*.cuh' \) -print \
|
||||
| grep -vFf <(printf "%s\n" "${CLANG_FORMAT_EXCLUDES[@]}") \
|
||||
| xargs clang-format -i
|
||||
}
|
||||
|
||||
# Run clang-format
|
||||
if [[ "$1" == '--files' ]]; then
|
||||
clang_format "${@:2}"
|
||||
elif [[ "$1" == '--all' ]]; then
|
||||
clang_format_all
|
||||
else
|
||||
clang_format_changed
|
||||
fi
|
||||
echo 'vLLM clang-format: Done'
|
||||
|
||||
echo 'vLLM actionlint:'
|
||||
tools/actionlint.sh -color
|
||||
echo 'vLLM actionlint: Done'
|
||||
|
||||
echo 'vLLM shellcheck:'
|
||||
tools/shellcheck.sh
|
||||
echo 'vLLM shellcheck: Done'
|
||||
|
||||
echo 'excalidraw png check:'
|
||||
tools/png-lint.sh
|
||||
echo 'excalidraw png check: Done'
|
||||
|
||||
if ! git diff --quiet &>/dev/null; then
|
||||
echo
|
||||
echo "🔍🔍There are files changed by the format checker or by you that are not added and committed:"
|
||||
git --no-pager diff --name-only
|
||||
echo "🔍🔍Format checker passed, but please add, commit and push all the files above to include changes made by the format checker."
|
||||
|
||||
exit 1
|
||||
else
|
||||
echo "✨🎉 Format check passed! Congratulations! 🎉✨"
|
||||
fi
|
||||
|
||||
echo 'vLLM doc-lint:'
|
||||
tools/doc-lint.sh
|
||||
echo 'vLLM doc-lint: Done'
|
@ -15,6 +15,11 @@ build-backend = "setuptools.build_meta"
|
||||
[tool.setuptools_scm]
|
||||
# version_file = "vllm/_version.py" # currently handled by `setup.py:get_version()`
|
||||
|
||||
[tool.yapfignore]
|
||||
ignore_patterns = [
|
||||
"build/**",
|
||||
]
|
||||
|
||||
[tool.ruff]
|
||||
# Allow lines to be as long as 80.
|
||||
line-length = 80
|
||||
@ -52,6 +57,9 @@ ignore = [
|
||||
"B007",
|
||||
# f-string format
|
||||
"UP032",
|
||||
# Python 3.8 typing
|
||||
"UP006", "UP035",
|
||||
|
||||
]
|
||||
|
||||
[tool.mypy]
|
||||
|
@ -1,15 +1,2 @@
|
||||
# formatting
|
||||
yapf==0.32.0
|
||||
toml==0.10.2
|
||||
tomli==2.0.2
|
||||
ruff==0.6.5
|
||||
codespell==2.3.0
|
||||
isort==5.13.2
|
||||
clang-format==18.1.5
|
||||
pymarkdownlnt==0.9.26
|
||||
|
||||
# type checking
|
||||
mypy==1.11.1
|
||||
types-PyYAML
|
||||
types-requests
|
||||
types-setuptools
|
||||
pre-commit==4.0.1
|
||||
|
@ -1,13 +0,0 @@
|
||||
#!/bin/bash
|
||||
|
||||
if command -v actionlint &> /dev/null; then
|
||||
actionlint "$@"
|
||||
exit 0
|
||||
elif [ -x ./actionlint ]; then
|
||||
./actionlint "$@"
|
||||
exit 0
|
||||
fi
|
||||
|
||||
# download a binary to the current directory - v1.7.3
|
||||
bash <(curl https://raw.githubusercontent.com/rhysd/actionlint/aa0a7be8e566b096e64a5df8ff290ec24fa58fbc/scripts/download-actionlint.bash)
|
||||
./actionlint "$@"
|
@ -1,3 +0,0 @@
|
||||
#!/bin/bash
|
||||
|
||||
pymarkdownlnt scan docs -r
|
Loading…
x
Reference in New Issue
Block a user