Move linting to pre-commit (#11975)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>

parent 51ef828f10
commit 3ea7b94523
@@ -43,7 +43,7 @@ main() {
-  # The figures should be genereated by a separate process outside the CI/CD pipeline
+  # The figures should be generated by a separate process outside the CI/CD pipeline
   # # generate figures
   # python3 -m pip install tabulate pandas matplotlib

.github/workflows/actionlint.yml (vendored, 40 lines deleted)
@@ -1,40 +0,0 @@
name: Lint GitHub Actions workflows
on:
  push:
    branches:
      - "main"
    paths:
      - '.github/workflows/*.ya?ml'
      - '.github/workflows/actionlint.*'
      - '.github/workflows/matchers/actionlint.json'
  pull_request:
    branches:
      - "main"
    paths:
      - '.github/workflows/*.ya?ml'
      - '.github/workflows/actionlint.*'
      - '.github/workflows/matchers/actionlint.json'

env:
  LC_ALL: en_US.UTF-8

defaults:
  run:
    shell: bash

permissions:
  contents: read

jobs:
  actionlint:
    runs-on: ubuntu-latest
    steps:
      - name: "Checkout"
        uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
        with:
          fetch-depth: 0

      - name: "Run actionlint"
        run: |
          echo "::add-matcher::.github/workflows/matchers/actionlint.json"
          tools/actionlint.sh -color

.github/workflows/clang-format.yml (vendored, 53 lines deleted)
@@ -1,53 +0,0 @@
name: clang-format

on:
  # Trigger the workflow on push or pull request,
  # but only for the main branch
  push:
    branches:
      - main
    paths:
      - '**/*.h'
      - '**/*.cpp'
      - '**/*.cu'
      - '**/*.cuh'
      - '.github/workflows/clang-format.yml'
  pull_request:
    branches:
      - main
    paths:
      - '**/*.h'
      - '**/*.cpp'
      - '**/*.cu'
      - '**/*.cuh'
      - '.github/workflows/clang-format.yml'

jobs:
  clang-format:
    runs-on: ubuntu-latest
    strategy:
      matrix:
        python-version: ["3.11"]
    steps:
    - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
    - name: Set up Python ${{ matrix.python-version }}
      uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0
      with:
        python-version: ${{ matrix.python-version }}
    - name: Install dependencies
      run: |
        python -m pip install --upgrade pip
        pip install clang-format==18.1.5
    - name: Running clang-format
      run: |
        EXCLUDES=(
            'csrc/moe/topk_softmax_kernels.cu'
            'csrc/quantization/gguf/ggml-common.h'
            'csrc/quantization/gguf/dequantize.cuh'
            'csrc/quantization/gguf/vecdotq.cuh'
            'csrc/quantization/gguf/mmq.cuh'
            'csrc/quantization/gguf/mmvq.cuh'
        )
        find csrc/ \( -name '*.h' -o -name '*.cpp' -o -name '*.cu' -o -name '*.cuh' \) -print \
          | grep -vFf <(printf "%s\n" "${EXCLUDES[@]}") \
          | xargs clang-format --dry-run --Werror

.github/workflows/codespell.yml (vendored, 45 lines deleted)
@@ -1,45 +0,0 @@
name: codespell

on:
  # Trigger the workflow on push or pull request,
  # but only for the main branch
  push:
    branches:
      - main
    paths:
      - "**/*.py"
      - "**/*.md"
      - "**/*.rst"
      - pyproject.toml
      - requirements-lint.txt
      - .github/workflows/codespell.yml
  pull_request:
    branches:
      - main
    paths:
      - "**/*.py"
      - "**/*.md"
      - "**/*.rst"
      - pyproject.toml
      - requirements-lint.txt
      - .github/workflows/codespell.yml

jobs:
  codespell:
    runs-on: ubuntu-latest
    strategy:
      matrix:
        python-version: ["3.12"]
    steps:
    - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
    - name: Set up Python ${{ matrix.python-version }}
      uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0
      with:
        python-version: ${{ matrix.python-version }}
    - name: Install dependencies
      run: |
        python -m pip install --upgrade pip
        pip install -r requirements-lint.txt
    - name: Spelling check with codespell
      run: |
        codespell --toml pyproject.toml

.github/workflows/doc-lint.yml (vendored, 32 lines deleted)
@@ -1,32 +0,0 @@
name: Lint documentation

on:
  push:
    branches:
      - main
    paths:
      - "docs/**"
  pull_request:
    branches:
      - main
    paths:
      - "docs/**"

jobs:
  doc-lint:
    runs-on: ubuntu-latest
    strategy:
      matrix:
        python-version: ["3.12"]
    steps:
    - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
    - name: Set up Python ${{ matrix.python-version }}
      uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0
      with:
        python-version: ${{ matrix.python-version }}
    - name: Install dependencies
      run: |
        python -m pip install --upgrade pip
        pip install -r requirements-lint.txt
    - name: Linting docs
      run: tools/doc-lint.sh

.github/workflows/dummy.yml (vendored, new file, 20 lines)
@@ -0,0 +1,20 @@
name: dummy-checks

on:
  pull_request:

jobs:
  mypy:
    runs-on: ubuntu-latest
    strategy:
      matrix:
        python-version: ["3.12"]
    steps:
    - run: echo "This is a dummy step that always passes"
  ruff:
    runs-on: ubuntu-latest
    strategy:
      matrix:
        python-version: ["3.12"]
    steps:
    - run: echo "This is a dummy step that always passes"

.github/workflows/matchers/ruff.json (vendored, 17 lines deleted)
@@ -1,17 +0,0 @@
{
  "problemMatcher": [
    {
      "owner": "ruff",
      "pattern": [
        {
          "regexp": "^(.+?):(\\d+):(\\d+): (\\w+): (.+)$",
          "file": 1,
          "line": 2,
          "column": 3,
          "code": 4,
          "message": 5
        }
      ]
    }
  ]
}

.github/workflows/mypy.yaml (vendored, 51 lines deleted)
@@ -1,51 +0,0 @@
name: mypy

on:
  # Trigger the workflow on push or pull request,
  # but only for the main branch
  push:
    branches:
      - main
    paths:
      - '**/*.py'
      - '.github/workflows/mypy.yaml'
      - 'tools/mypy.sh'
      - 'pyproject.toml'
  pull_request:
    branches:
      - main
    # This workflow is only relevant when one of the following files changes.
    # However, we have github configured to expect and require this workflow
    # to run and pass before github with auto-merge a pull request. Until github
    # allows more flexible auto-merge policy, we can just run this on every PR.
    # It doesn't take that long to run, anyway.
    #paths:
    #  - '**/*.py'
    #  - '.github/workflows/mypy.yaml'
    #  - 'tools/mypy.sh'
    #  - 'pyproject.toml'

jobs:
  mypy:
    runs-on: ubuntu-latest
    strategy:
      matrix:
        python-version: ["3.9", "3.10", "3.11", "3.12"]
    steps:
    - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
    - name: Set up Python ${{ matrix.python-version }}
      uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0
      with:
        python-version: ${{ matrix.python-version }}
    - name: Install dependencies
      run: |
        python -m pip install --upgrade pip
        pip install mypy==1.11.1
        pip install types-setuptools
        pip install types-PyYAML
        pip install types-requests
        pip install types-setuptools
    - name: Mypy
      run: |
        echo "::add-matcher::.github/workflows/matchers/mypy.json"
        tools/mypy.sh 1 ${{ matrix.python-version }}

.github/workflows/png-lint.yml (vendored, 37 lines deleted)
@@ -1,37 +0,0 @@
name: Lint PNG exports from excalidraw
on:
  push:
    branches:
      - "main"
    paths:
      - '*.excalidraw.png'
      - '.github/workflows/png-lint.yml'
  pull_request:
    branches:
      - "main"
    paths:
      - '*.excalidraw.png'
      - '.github/workflows/png-lint.yml'

env:
  LC_ALL: en_US.UTF-8

defaults:
  run:
    shell: bash

permissions:
  contents: read

jobs:
  actionlint:
    runs-on: ubuntu-latest
    steps:
      - name: "Checkout"
        uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
        with:
          fetch-depth: 0

      - name: "Run png-lint.sh to check excalidraw exported images"
        run: |
          tools/png-lint.sh

.github/workflows/pre-commit.yml (vendored, new file, 17 lines)
@@ -0,0 +1,17 @@
name: pre-commit

on:
  pull_request:
  push:
    branches: [main]

jobs:
  pre-commit:
    runs-on: ubuntu-latest
    steps:
    - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
    - uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0
      with:
        python-version: "3.12"
    - run: echo "::add-matcher::.github/workflows/matchers/actionlint.json"
    - uses: pre-commit/action@2c7b3805fd2a0fd8c1884dcaebf91fc102a13ecd # v3.0.1

.github/workflows/ruff.yml (vendored, 52 lines deleted)
@@ -1,52 +0,0 @@
name: ruff

on:
  # Trigger the workflow on push or pull request,
  # but only for the main branch
  push:
    branches:
      - main
    paths:
      - "**/*.py"
      - pyproject.toml
      - requirements-lint.txt
      - .github/workflows/matchers/ruff.json
      - .github/workflows/ruff.yml
  pull_request:
    branches:
      - main
    # This workflow is only relevant when one of the following files changes.
    # However, we have github configured to expect and require this workflow
    # to run and pass before github with auto-merge a pull request. Until github
    # allows more flexible auto-merge policy, we can just run this on every PR.
    # It doesn't take that long to run, anyway.
    #paths:
    #  - "**/*.py"
    #  - pyproject.toml
    #  - requirements-lint.txt
    #  - .github/workflows/matchers/ruff.json
    #  - .github/workflows/ruff.yml

jobs:
  ruff:
    runs-on: ubuntu-latest
    strategy:
      matrix:
        python-version: ["3.12"]
    steps:
    - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
    - name: Set up Python ${{ matrix.python-version }}
      uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0
      with:
        python-version: ${{ matrix.python-version }}
    - name: Install dependencies
      run: |
        python -m pip install --upgrade pip
        pip install -r requirements-lint.txt
    - name: Analysing the code with ruff
      run: |
        echo "::add-matcher::.github/workflows/matchers/ruff.json"
        ruff check --output-format github .
    - name: Run isort
      run: |
        isort . --check-only

.github/workflows/shellcheck.yml (vendored, 37 lines deleted)
@@ -1,37 +0,0 @@
name: Lint shell scripts
on:
  push:
    branches:
      - "main"
    paths:
      - '**/*.sh'
      - '.github/workflows/shellcheck.yml'
  pull_request:
    branches:
      - "main"
    paths:
      - '**/*.sh'
      - '.github/workflows/shellcheck.yml'

env:
  LC_ALL: en_US.UTF-8

defaults:
  run:
    shell: bash

permissions:
  contents: read

jobs:
  shellcheck:
    runs-on: ubuntu-latest
    steps:
      - name: "Checkout"
        uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
        with:
          fetch-depth: 0

      - name: "Check shell scripts"
        run: |
          tools/shellcheck.sh

.github/workflows/yapf.yml (vendored, 38 lines deleted)
@@ -1,38 +0,0 @@
name: yapf

on:
  # Trigger the workflow on push or pull request,
  # but only for the main branch
  push:
    branches:
      - main
    paths:
      - "**/*.py"
      - .github/workflows/yapf.yml
  pull_request:
    branches:
      - main
    paths:
      - "**/*.py"
      - .github/workflows/yapf.yml

jobs:
  yapf:
    runs-on: ubuntu-latest
    strategy:
      matrix:
        python-version: ["3.12"]
    steps:
    - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
    - name: Set up Python ${{ matrix.python-version }}
      uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0
      with:
        python-version: ${{ matrix.python-version }}
    - name: Install dependencies
      run: |
        python -m pip install --upgrade pip
        pip install yapf==0.32.0
        pip install toml==0.10.2
    - name: Running yapf
      run: |
        yapf --diff --recursive .

.pre-commit-config.yaml (new file, 73 lines)
@@ -0,0 +1,73 @@
repos:
- repo: https://github.com/google/yapf
  rev: v0.32.0
  hooks:
  - id: yapf
    args: [--in-place, --verbose]
    additional_dependencies: [toml] # TODO: Remove when yapf is upgraded
- repo: https://github.com/astral-sh/ruff-pre-commit
  rev: v0.6.5
  hooks:
  - id: ruff
    args: [--output-format, github]
- repo: https://github.com/codespell-project/codespell
  rev: v2.3.0
  hooks:
  - id: codespell
    exclude: 'benchmarks/sonnet.txt|(build|tests/(lora/data|models/fixtures|prompts))/.*'
- repo: https://github.com/PyCQA/isort
  rev: 5.13.2
  hooks:
  - id: isort
- repo: https://github.com/pre-commit/mirrors-clang-format
  rev: v18.1.5
  hooks:
  - id: clang-format
    exclude: 'csrc/(moe/topk_softmax_kernels.cu|quantization/gguf/(ggml-common.h|dequantize.cuh|vecdotq.cuh|mmq.cuh|mmvq.cuh))'
    types_or: [c++, cuda]
    args: [--style=file, --verbose]
- repo: https://github.com/jackdewinter/pymarkdown
  rev: v0.9.27
  hooks:
  - id: pymarkdown
    files: docs/.*
- repo: local
  hooks:
  - id: mypy-3.9 # TODO: Use https://github.com/pre-commit/mirrors-mypy when mypy setup is less awkward
    name: Run mypy for Python 3.9
    entry: tools/mypy.sh 1 "3.9"
    language: python
    types: [python]
    additional_dependencies: &mypy_deps [mypy==1.11.1, types-setuptools, types-PyYAML, types-requests]
  - id: mypy-3.10 # TODO: Use https://github.com/pre-commit/mirrors-mypy when mypy setup is less awkward
    name: Run mypy for Python 3.10
    entry: tools/mypy.sh 1 "3.10"
    language: python
    types: [python]
    additional_dependencies: *mypy_deps
  - id: mypy-3.11 # TODO: Use https://github.com/pre-commit/mirrors-mypy when mypy setup is less awkward
    name: Run mypy for Python 3.11
    entry: tools/mypy.sh 1 "3.11"
    language: python
    types: [python]
    additional_dependencies: *mypy_deps
  - id: mypy-3.12 # TODO: Use https://github.com/pre-commit/mirrors-mypy when mypy setup is less awkward
    name: Run mypy for Python 3.12
    entry: tools/mypy.sh 1 "3.12"
    language: python
    types: [python]
    additional_dependencies: *mypy_deps
  - id: shellcheck
    name: Lint shell scripts
    entry: tools/shellcheck.sh
    language: script
    types: [shell]
  - id: png-lint
    name: Lint PNG exports from excalidraw
    entry: tools/png-lint.sh
    language: script
    types: [png]
- repo: https://github.com/rhysd/actionlint
  rev: v1.7.6
  hooks:
  - id: actionlint
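
With every lint tool now declared in .pre-commit-config.yaml, the same checks can be reproduced outside CI. A minimal sketch of the usual local workflow, assuming pre-commit is installed from PyPI (the hook ids are the ones defined in the config above):

    # Install the hook manager and register the hooks from .pre-commit-config.yaml
    pip install pre-commit
    pre-commit install

    # Run the full suite once against every file in the repository
    pre-commit run --all-files

    # Or run a single hook, e.g. just the ruff check
    pre-commit run ruff --all-files
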

@@ -32,7 +32,7 @@ class ScalarType {
         signed_(signed_),
         bias(bias),
         finite_values_only(finite_values_only),
-        nan_repr(nan_repr){};
+        nan_repr(nan_repr) {};
 
   static constexpr ScalarType int_(uint8_t size_bits, int32_t bias = 0) {
     return ScalarType(0, size_bits - 1, true, bias);

@@ -2,13 +2,13 @@
 #define CPU_TYPES_HPP
 
 #if defined(__x86_64__)
-  //x86 implementation
+  // x86 implementation
   #include "cpu_types_x86.hpp"
 #elif defined(__POWER9_VECTOR__)
-  //ppc implementation
+  // ppc implementation
   #include "cpu_types_vsx.hpp"
 #elif defined(__aarch64__)
-  //arm implementation
+  // arm implementation
   #include "cpu_types_arm.hpp"
 #else
 #warning "unsupported vLLM cpu implementation"

@@ -19,30 +19,32 @@ namespace vec_op {
   AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__))
 
 #ifndef CPU_OP_GUARD
 #define CPU_KERNEL_GUARD_IN(NAME)
 #define CPU_KERNEL_GUARD_OUT(NAME)
 #else
 #define CPU_KERNEL_GUARD_IN(NAME) \
   std::cout << #NAME << " invoked." << std::endl;
-#define CPU_KERNEL_GUARD_OUT(NAME) std::cout << #NAME << " exit." << std::endl;
+#define CPU_KERNEL_GUARD_OUT(NAME) \
+  std::cout << #NAME << " exit." << std::endl;
 #endif
 
 #define FORCE_INLINE __attribute__((always_inline)) inline
 
 namespace {
 template <typename T, T... indexes, typename F>
-constexpr void unroll_loop_item(std::integer_sequence<T, indexes...>, F &&f) {
+constexpr void unroll_loop_item(std::integer_sequence<T, indexes...>, F&& f) {
   (f(std::integral_constant<T, indexes>{}), ...);
-};
 };
+}; // namespace
 
 template <typename T, T count, typename F,
           typename = std::enable_if_t<std::is_invocable_v<F, T>>>
-constexpr void unroll_loop(F &&f) {
+constexpr void unroll_loop(F&& f) {
   unroll_loop_item(std::make_integer_sequence<T, count>{}, std::forward<F>(f));
 }
 
-template <typename T> struct Vec {
+template <typename T>
+struct Vec {
   constexpr static int get_elem_num() { return T::VEC_ELEM_NUM; };
 };
 
@@ -54,14 +56,12 @@ struct FP16Vec8 : public Vec<FP16Vec8> {
 
   float16x8_t reg;
 
-  explicit FP16Vec8(const void *ptr)
-      : reg(vld1q_f16(static_cast<const __fp16 *>(ptr))) {};
+  explicit FP16Vec8(const void* ptr)
+      : reg(vld1q_f16(static_cast<const __fp16*>(ptr))) {};
 
-  explicit FP16Vec8(const FP32Vec8 &);
+  explicit FP16Vec8(const FP32Vec8&);
 
-  void save(void *ptr) const {
-    vst1q_f16(static_cast<__fp16 *>(ptr), reg);
-  }
+  void save(void* ptr) const { vst1q_f16(static_cast<__fp16*>(ptr), reg); }
 };
 
 struct FP16Vec16 : public Vec<FP16Vec16> {
@@ -69,19 +69,19 @@ struct FP16Vec16 : public Vec<FP16Vec16> {
 
   float16x8x2_t reg;
 
-  explicit FP16Vec16(const void *ptr) {
+  explicit FP16Vec16(const void* ptr) {
     reg.val[0] = vld1q_f16(reinterpret_cast<const __fp16*>(ptr));
     reg.val[1] = vld1q_f16(reinterpret_cast<const __fp16*>(ptr) + 8);
   }
 
   explicit FP16Vec16(const FP32Vec16& vec);
 
-  void save(void *ptr) const {
+  void save(void* ptr) const {
     vst1q_f16(reinterpret_cast<__fp16*>(ptr), reg.val[0]);
     vst1q_f16(reinterpret_cast<__fp16*>(ptr) + 8, reg.val[1]);
   }
 
-  void save(void *ptr, const int elem_num) const {
+  void save(void* ptr, const int elem_num) const {
     int full_blocks = elem_num / 8;
     int remainder = elem_num % 8;
 
@@ -106,8 +106,7 @@ struct FP16Vec16 : public Vec<FP16Vec16> {
     if (remainder > 0) {
       float16x8_t temp = reg.val[full_blocks];
       __fp16* fp16_ptr = reinterpret_cast<__fp16*>(ptr);
-      switch (remainder)
-      {
+      switch (remainder) {
         case 1:
           fp16_ptr[full_blocks * 8 + 0] = vgetq_lane_f16(temp, 0);
           break;
@@ -158,23 +157,23 @@ struct FP16Vec16 : public Vec<FP16Vec16> {
   }
 };
 
 
 #ifdef ARM_BF16_SUPPORT
 struct BF16Vec8 : public Vec<BF16Vec8> {
   constexpr static int VEC_ELEM_NUM = 8;
 
   bfloat16x8_t reg;
 
-  explicit BF16Vec8(const void *ptr)
-      : reg(*reinterpret_cast<const bfloat16x8_t *>(ptr)) {};
+  explicit BF16Vec8(const void* ptr)
+      : reg(*reinterpret_cast<const bfloat16x8_t*>(ptr)) {};
 
   explicit BF16Vec8(bfloat16x8_t data) : reg(data) {};
 
-  explicit BF16Vec8(const FP32Vec8 &);
+  explicit BF16Vec8(const FP32Vec8&);
 
-  explicit BF16Vec8(float32x4x2_t v) : reg(vcvtq_high_bf16_f32(vcvtq_low_bf16_f32(v.val[0]), v.val[1])) {};
+  explicit BF16Vec8(float32x4x2_t v)
+      : reg(vcvtq_high_bf16_f32(vcvtq_low_bf16_f32(v.val[0]), v.val[1])) {};
 
-  void save(void *ptr) const { *reinterpret_cast<bfloat16x8_t *>(ptr) = reg; }
+  void save(void* ptr) const { *reinterpret_cast<bfloat16x8_t*>(ptr) = reg; }
 };
 
 struct BF16Vec16 : public Vec<BF16Vec16> {
@@ -182,19 +181,18 @@ struct BF16Vec16 : public Vec<BF16Vec16> {
 
   bfloat16x8x2_t reg;
 
-  explicit BF16Vec16(const void *ptr)
-      : reg(*reinterpret_cast<const bfloat16x8x2_t *>(ptr)) {};
+  explicit BF16Vec16(const void* ptr)
+      : reg(*reinterpret_cast<const bfloat16x8x2_t*>(ptr)) {};
 
   explicit BF16Vec16(bfloat16x8x2_t data) : reg(data) {};
 
-  explicit BF16Vec16(const FP32Vec16 &);
+  explicit BF16Vec16(const FP32Vec16&);
 
-  explicit BF16Vec16(float32x4x4_t v) : reg({
-    vcvtq_high_bf16_f32(vcvtq_low_bf16_f32(v.val[0]), v.val[1]),
-    vcvtq_high_bf16_f32(vcvtq_low_bf16_f32(v.val[2]), v.val[3])
-  }){};
+  explicit BF16Vec16(float32x4x4_t v)
+      : reg({vcvtq_high_bf16_f32(vcvtq_low_bf16_f32(v.val[0]), v.val[1]),
+             vcvtq_high_bf16_f32(vcvtq_low_bf16_f32(v.val[2]), v.val[3])}) {};
 
-  void save(void *ptr) const { *reinterpret_cast<bfloat16x8x2_t *>(ptr) = reg; };
+  void save(void* ptr) const { *reinterpret_cast<bfloat16x8x2_t*>(ptr) = reg; };
 };
 
 struct BF16Vec32 : public Vec<BF16Vec32> {
@@ -202,19 +200,15 @@ struct BF16Vec32 : public Vec<BF16Vec32> {
 
   bfloat16x8x4_t reg;
 
-  explicit BF16Vec32(const void *ptr)
-      : reg(*reinterpret_cast<const bfloat16x8x4_t *>(ptr)) {};
+  explicit BF16Vec32(const void* ptr)
+      : reg(*reinterpret_cast<const bfloat16x8x4_t*>(ptr)) {};
 
   explicit BF16Vec32(bfloat16x8x4_t data) : reg(data) {};
 
-  explicit BF16Vec32(const BF16Vec8 &vec8_data) : reg({
-    vec8_data.reg,
-    vec8_data.reg,
-    vec8_data.reg,
-    vec8_data.reg
-  }) {};
+  explicit BF16Vec32(const BF16Vec8& vec8_data)
+      : reg({vec8_data.reg, vec8_data.reg, vec8_data.reg, vec8_data.reg}) {};
 
-  void save(void *ptr) const { *reinterpret_cast<bfloat16x8x4_t *>(ptr) = reg; };
+  void save(void* ptr) const { *reinterpret_cast<bfloat16x8x4_t*>(ptr) = reg; };
 };
 #endif
 
@@ -232,11 +226,11 @@ struct FP32Vec4 : public Vec<FP32Vec4> {
 
   explicit FP32Vec4() : reg(vdupq_n_f32(0.0f)) {};
 
-  explicit FP32Vec4(const float *ptr) : reg(vld1q_f32(ptr)) {};
+  explicit FP32Vec4(const float* ptr) : reg(vld1q_f32(ptr)) {};
 
   explicit FP32Vec4(float32x4_t data) : reg(data) {};
 
-  explicit FP32Vec4(const FP32Vec4 &data) : reg(data.reg) {};
+  explicit FP32Vec4(const FP32Vec4& data) : reg(data.reg) {};
 };
 
 struct FP32Vec8 : public Vec<FP32Vec8> {
@@ -252,32 +246,37 @@ struct FP32Vec8 : public Vec<FP32Vec8> {
 
   explicit FP32Vec8() : reg({vmovq_n_f32(0.0), vmovq_n_f32(0.0)}) {};
 
-  explicit FP32Vec8(const float *ptr) : reg({vld1q_f32(ptr), vld1q_f32(ptr + 4)}) {};
+  explicit FP32Vec8(const float* ptr)
+      : reg({vld1q_f32(ptr), vld1q_f32(ptr + 4)}) {};
 
   explicit FP32Vec8(float32x4x2_t data) : reg(data) {};
 
-  explicit FP32Vec8(const FP32Vec8 &data) : reg(data.reg) {};
+  explicit FP32Vec8(const FP32Vec8& data) : reg(data.reg) {};
 
-  explicit FP32Vec8(const FP16Vec8 &v) {
+  explicit FP32Vec8(const FP16Vec8& v) {
     reg.val[0] = vcvt_f32_f16(vget_low_f16(v.reg));
     reg.val[1] = vcvt_f32_f16(vget_high_f16(v.reg));
   };
 
-  explicit FP32Vec8(float16x8_t v) : reg({vcvt_f32_f16(vget_low_f16(v)), vcvt_f32_f16(vget_high_f16(v))}) {};
+  explicit FP32Vec8(float16x8_t v)
+      : reg({vcvt_f32_f16(vget_low_f16(v)), vcvt_f32_f16(vget_high_f16(v))}) {};
 
 #ifdef ARM_BF16_SUPPORT
 
-  explicit FP32Vec8(bfloat16x8_t v) : reg({vcvtq_low_f32_bf16(v), vcvtq_high_f32_bf16(v)}) {};
+  explicit FP32Vec8(bfloat16x8_t v)
+      : reg({vcvtq_low_f32_bf16(v), vcvtq_high_f32_bf16(v)}) {};
 
-  explicit FP32Vec8(const BF16Vec8 &v) : reg({vcvtq_low_f32_bf16(v.reg), vcvtq_high_f32_bf16(v.reg)}) {};
+  explicit FP32Vec8(const BF16Vec8& v)
+      : reg({vcvtq_low_f32_bf16(v.reg), vcvtq_high_f32_bf16(v.reg)}) {};
 
 #endif
 
   float reduce_sum() const {
     AliasReg ar;
     ar.reg = reg;
     float answer = 0;
-    unroll_loop<int, VEC_ELEM_NUM>([&answer, &ar](int i) { answer += ar.values[i]; });
+    unroll_loop<int, VEC_ELEM_NUM>(
+        [&answer, &ar](int i) { answer += ar.values[i]; });
 
     return answer;
   }
@@ -324,10 +323,14 @@ struct FP32Vec8 : public Vec<FP32Vec8> {
     AliasReg ar;
     ar.reg = reg;
 
-    float32x2_t er_vec0 = {static_cast<float32_t>(erf(ar.values[0])), static_cast<float32_t>(erf(ar.values[1]))};
-    float32x2_t er_vec1 = {static_cast<float32_t>(erf(ar.values[2])), static_cast<float32_t>(erf(ar.values[3]))};
-    float32x2_t er_vec2 = {static_cast<float32_t>(erf(ar.values[4])), static_cast<float32_t>(erf(ar.values[5]))};
-    float32x2_t er_vec3 = {static_cast<float32_t>(erf(ar.values[6])), static_cast<float32_t>(erf(ar.values[7]))};
+    float32x2_t er_vec0 = {static_cast<float32_t>(erf(ar.values[0])),
+                           static_cast<float32_t>(erf(ar.values[1]))};
+    float32x2_t er_vec1 = {static_cast<float32_t>(erf(ar.values[2])),
+                           static_cast<float32_t>(erf(ar.values[3]))};
+    float32x2_t er_vec2 = {static_cast<float32_t>(erf(ar.values[4])),
+                           static_cast<float32_t>(erf(ar.values[5]))};
+    float32x2_t er_vec3 = {static_cast<float32_t>(erf(ar.values[6])),
+                           static_cast<float32_t>(erf(ar.values[7]))};
 
     float32x4_t result0 = vcombine_f32(er_vec0, er_vec1);
     float32x4_t result1 = vcombine_f32(er_vec2, er_vec3);
@@ -339,23 +342,27 @@ struct FP32Vec8 : public Vec<FP32Vec8> {
     return FP32Vec8(result);
   }
 
-  FP32Vec8 operator*(const FP32Vec8 &b) const {
-    return FP32Vec8(float32x4x2_t({vmulq_f32(reg.val[0], b.reg.val[0]), vmulq_f32(reg.val[1], b.reg.val[1])}));
+  FP32Vec8 operator*(const FP32Vec8& b) const {
+    return FP32Vec8(float32x4x2_t({vmulq_f32(reg.val[0], b.reg.val[0]),
+                                   vmulq_f32(reg.val[1], b.reg.val[1])}));
   }
 
-  FP32Vec8 operator+(const FP32Vec8 &b) const {
-    return FP32Vec8(float32x4x2_t({vaddq_f32(reg.val[0], b.reg.val[0]), vaddq_f32(reg.val[1], b.reg.val[1])}));
+  FP32Vec8 operator+(const FP32Vec8& b) const {
+    return FP32Vec8(float32x4x2_t({vaddq_f32(reg.val[0], b.reg.val[0]),
+                                   vaddq_f32(reg.val[1], b.reg.val[1])}));
   }
 
-  FP32Vec8 operator-(const FP32Vec8 &b) const {
-    return FP32Vec8(float32x4x2_t({vsubq_f32(reg.val[0], b.reg.val[0]), vsubq_f32(reg.val[1], b.reg.val[1])}));
+  FP32Vec8 operator-(const FP32Vec8& b) const {
+    return FP32Vec8(float32x4x2_t({vsubq_f32(reg.val[0], b.reg.val[0]),
+                                   vsubq_f32(reg.val[1], b.reg.val[1])}));
   }
 
-  FP32Vec8 operator/(const FP32Vec8 &b) const {
-    return FP32Vec8(float32x4x2_t({vdivq_f32(reg.val[0], b.reg.val[0]), vdivq_f32(reg.val[1], b.reg.val[1])}));
+  FP32Vec8 operator/(const FP32Vec8& b) const {
+    return FP32Vec8(float32x4x2_t({vdivq_f32(reg.val[0], b.reg.val[0]),
+                                   vdivq_f32(reg.val[1], b.reg.val[1])}));
   }
 
-  void save(float *ptr) const {
+  void save(float* ptr) const {
     vst1q_f32(ptr, reg.val[0]);
     vst1q_f32(ptr + 4, reg.val[1]);
   }
@@ -370,103 +377,100 @@ struct FP32Vec16 : public Vec<FP32Vec16> {
 
   float32x4x4_t reg;
 
-  explicit FP32Vec16(float v) : reg({vmovq_n_f32(v), vmovq_n_f32(v), vmovq_n_f32(v), vmovq_n_f32(v)}) {}
+  explicit FP32Vec16(float v)
+      : reg({vmovq_n_f32(v), vmovq_n_f32(v), vmovq_n_f32(v), vmovq_n_f32(v)}) {}
 
-  explicit FP32Vec16() : reg({vmovq_n_f32(0.0), vmovq_n_f32(0.0), vmovq_n_f32(0.0), vmovq_n_f32(0.0)}) {}
+  explicit FP32Vec16()
+      : reg({vmovq_n_f32(0.0), vmovq_n_f32(0.0), vmovq_n_f32(0.0),
+             vmovq_n_f32(0.0)}) {}
 
-  explicit FP32Vec16(const float *ptr) : reg({vld1q_f32(ptr), vld1q_f32(ptr + 4), vld1q_f32(ptr + 8), vld1q_f32(ptr + 12)}) {}
+  explicit FP32Vec16(const float* ptr)
+      : reg({vld1q_f32(ptr), vld1q_f32(ptr + 4), vld1q_f32(ptr + 8),
+             vld1q_f32(ptr + 12)}) {}
 
   explicit FP32Vec16(float32x4x4_t data) : reg(data) {}
 
-  explicit FP32Vec16(const FP32Vec8 &data) {
+  explicit FP32Vec16(const FP32Vec8& data) {
     reg.val[0] = data.reg.val[0];
     reg.val[1] = data.reg.val[1];
     reg.val[2] = data.reg.val[0];
    reg.val[3] = data.reg.val[1];
   }
 
-  explicit FP32Vec16(const FP32Vec16 &data) : reg(data.reg) {}
+  explicit FP32Vec16(const FP32Vec16& data) : reg(data.reg) {}
 
-  explicit FP32Vec16(const FP16Vec8 &v) : FP32Vec16(FP32Vec8(v.reg)) {}
+  explicit FP32Vec16(const FP16Vec8& v) : FP32Vec16(FP32Vec8(v.reg)) {}
 
 #ifdef ARM_BF16_SUPPORT
-  explicit FP32Vec16(bfloat16x8x2_t v) : reg({
-    vcvtq_low_f32_bf16(v.val[0]),
-    vcvtq_high_f32_bf16(v.val[0]),
-    vcvtq_low_f32_bf16(v.val[1]),
-    vcvtq_high_f32_bf16(v.val[1])
-  }) {};
+  explicit FP32Vec16(bfloat16x8x2_t v)
+      : reg({vcvtq_low_f32_bf16(v.val[0]), vcvtq_high_f32_bf16(v.val[0]),
+             vcvtq_low_f32_bf16(v.val[1]), vcvtq_high_f32_bf16(v.val[1])}) {};
 #endif
 
-  explicit FP32Vec16(const FP32Vec4 &data) {
+  explicit FP32Vec16(const FP32Vec4& data) {
     reg.val[0] = data.reg;
     reg.val[1] = data.reg;
     reg.val[2] = data.reg;
     reg.val[3] = data.reg;
   };
 
 #ifdef ARM_BF16_SUPPORT
-  explicit FP32Vec16(const BF16Vec16 &v) : reg({
-    vcvtq_low_f32_bf16(v.reg.val[0]),
-    vcvtq_high_f32_bf16(v.reg.val[0]),
-    vcvtq_low_f32_bf16(v.reg.val[1]),
-    vcvtq_high_f32_bf16(v.reg.val[1])
-  }) {};
+  explicit FP32Vec16(const BF16Vec16& v)
+      : reg({vcvtq_low_f32_bf16(v.reg.val[0]),
+             vcvtq_high_f32_bf16(v.reg.val[0]),
+             vcvtq_low_f32_bf16(v.reg.val[1]),
+             vcvtq_high_f32_bf16(v.reg.val[1])}) {};
 
-  explicit FP32Vec16(const BF16Vec8 &v) : FP32Vec16(FP32Vec8(v)) {};
+  explicit FP32Vec16(const BF16Vec8& v) : FP32Vec16(FP32Vec8(v)) {};
 #endif
 
-  explicit FP32Vec16(const FP16Vec16 &v) {
+  explicit FP32Vec16(const FP16Vec16& v) {
     reg.val[0] = vcvt_f32_f16(vget_low_f16(v.reg.val[0]));
     reg.val[1] = vcvt_f32_f16(vget_high_f16(v.reg.val[0]));
     reg.val[2] = vcvt_f32_f16(vget_low_f16(v.reg.val[1]));
     reg.val[3] = vcvt_f32_f16(vget_high_f16(v.reg.val[1]));
   };
 
-  FP32Vec16 operator+(const FP32Vec16 &b) const {
-    return FP32Vec16(float32x4x4_t({
-      vaddq_f32(reg.val[0], b.reg.val[0]),
-      vaddq_f32(reg.val[1], b.reg.val[1]),
-      vaddq_f32(reg.val[2], b.reg.val[2]),
-      vaddq_f32(reg.val[3], b.reg.val[3])}));
+  FP32Vec16 operator+(const FP32Vec16& b) const {
+    return FP32Vec16(float32x4x4_t({vaddq_f32(reg.val[0], b.reg.val[0]),
+                                    vaddq_f32(reg.val[1], b.reg.val[1]),
+                                    vaddq_f32(reg.val[2], b.reg.val[2]),
+                                    vaddq_f32(reg.val[3], b.reg.val[3])}));
   };
 
-  FP32Vec16 operator*(const FP32Vec16 &b) const {
-    return FP32Vec16(float32x4x4_t({
-      vmulq_f32(reg.val[0], b.reg.val[0]),
-      vmulq_f32(reg.val[1], b.reg.val[1]),
-      vmulq_f32(reg.val[2], b.reg.val[2]),
-      vmulq_f32(reg.val[3], b.reg.val[3])}));
+  FP32Vec16 operator*(const FP32Vec16& b) const {
+    return FP32Vec16(float32x4x4_t({vmulq_f32(reg.val[0], b.reg.val[0]),
+                                    vmulq_f32(reg.val[1], b.reg.val[1]),
+                                    vmulq_f32(reg.val[2], b.reg.val[2]),
+                                    vmulq_f32(reg.val[3], b.reg.val[3])}));
   };
 
-  FP32Vec16 operator-(const FP32Vec16 &b) const {
-    return FP32Vec16(float32x4x4_t({
-      vsubq_f32(reg.val[0], b.reg.val[0]),
-      vsubq_f32(reg.val[1], b.reg.val[1]),
-      vsubq_f32(reg.val[2], b.reg.val[2]),
-      vsubq_f32(reg.val[3], b.reg.val[3])
-    }));
+  FP32Vec16 operator-(const FP32Vec16& b) const {
+    return FP32Vec16(float32x4x4_t({vsubq_f32(reg.val[0], b.reg.val[0]),
+                                    vsubq_f32(reg.val[1], b.reg.val[1]),
+                                    vsubq_f32(reg.val[2], b.reg.val[2]),
+                                    vsubq_f32(reg.val[3], b.reg.val[3])}));
   };
 
-  FP32Vec16 operator/(const FP32Vec16 &b) const {
-    return FP32Vec16(float32x4x4_t({
-      vdivq_f32(reg.val[0], b.reg.val[0]),
-      vdivq_f32(reg.val[1], b.reg.val[1]),
-      vdivq_f32(reg.val[2], b.reg.val[2]),
-      vdivq_f32(reg.val[3], b.reg.val[3])
-    }));
+  FP32Vec16 operator/(const FP32Vec16& b) const {
+    return FP32Vec16(float32x4x4_t({vdivq_f32(reg.val[0], b.reg.val[0]),
+                                    vdivq_f32(reg.val[1], b.reg.val[1]),
+                                    vdivq_f32(reg.val[2], b.reg.val[2]),
+                                    vdivq_f32(reg.val[3], b.reg.val[3])}));
   };
 
   float reduce_sum() const {
     AliasReg ar;
     ar.reg = reg;
     float answer = 0;
-    unroll_loop<int, VEC_ELEM_NUM>([&answer, &ar](int i) { answer += ar.values[i]; });
+    unroll_loop<int, VEC_ELEM_NUM>(
+        [&answer, &ar](int i) { answer += ar.values[i]; });
 
     return answer;
   };
 
-  template <int group_size> float reduce_sub_sum(int idx) {
+  template <int group_size>
+  float reduce_sub_sum(int idx) {
     static_assert(VEC_ELEM_NUM % group_size == 0);
 
     AliasReg ar;
@@ -479,7 +483,7 @@ struct FP32Vec16 : public Vec<FP32Vec16> {
     return answer;
   };
 
-  void save(float *ptr) const {
+  void save(float* ptr) const {
     vst1q_f32(ptr, reg.val[0]);
     vst1q_f32(ptr + 4, reg.val[1]);
     vst1q_f32(ptr + 8, reg.val[2]);
@@ -487,25 +491,42 @@ struct FP32Vec16 : public Vec<FP32Vec16> {
   };
 };
 
-template <typename T> struct VecType { using vec_type = void; };
+template <typename T>
+struct VecType {
+  using vec_type = void;
+};
 
-template <typename T> using vec_t = typename VecType<T>::vec_type;
+template <typename T>
+using vec_t = typename VecType<T>::vec_type;
 
-template <> struct VecType<float> { using vec_type = FP32Vec8; };
+template <>
+struct VecType<float> {
+  using vec_type = FP32Vec8;
+};
 
-template <> struct VecType<c10::Half> { using vec_type = FP16Vec8; };
+template <>
+struct VecType<c10::Half> {
+  using vec_type = FP16Vec8;
+};
 
 #ifdef ARM_BF16_SUPPORT
-template <> struct VecType<c10::BFloat16> { using vec_type = BF16Vec8; };
+template <>
+struct VecType<c10::BFloat16> {
+  using vec_type = BF16Vec8;
+};
 #endif
 
-template <typename T> void storeFP32(float v, T *ptr) { *ptr = v; }
+template <typename T>
+void storeFP32(float v, T* ptr) {
+  *ptr = v;
+}
 
-template <> inline void storeFP32<c10::Half>(float v, c10::Half *ptr) {
-  *reinterpret_cast<__fp16 *>(ptr) = v;
+template <>
+inline void storeFP32<c10::Half>(float v, c10::Half* ptr) {
+  *reinterpret_cast<__fp16*>(ptr) = v;
 }
 
-inline FP16Vec16::FP16Vec16(const FP32Vec16 &v) {
+inline FP16Vec16::FP16Vec16(const FP32Vec16& v) {
   float16x4_t low_0 = vcvt_f16_f32(v.reg.val[0]);
   float16x4_t high_0 = vcvt_f16_f32(v.reg.val[1]);
   float16x4_t low_1 = vcvt_f16_f32(v.reg.val[2]);
@@ -515,15 +536,14 @@ inline FP16Vec16::FP16Vec16(const FP32Vec16 &v) {
   reg.val[1] = vcombine_f16(low_1, high_1);
 };
 
-inline FP16Vec8 :: FP16Vec8(const FP32Vec8 &v) {
+inline FP16Vec8 ::FP16Vec8(const FP32Vec8& v) {
   float16x4_t lower_half = vcvt_f16_f32(v.reg.val[0]);
   float16x4_t upper_half = vcvt_f16_f32(v.reg.val[1]);
 
   reg = vcombine_f16(lower_half, upper_half);
 };
 
-inline void fma(FP32Vec16 &acc, FP32Vec16 &a, FP32Vec16 &b) {
-
+inline void fma(FP32Vec16& acc, FP32Vec16& a, FP32Vec16& b) {
   acc.reg.val[0] = vfmaq_f32(acc.reg.val[0], a.reg.val[0], b.reg.val[0]);
   acc.reg.val[1] = vfmaq_f32(acc.reg.val[1], a.reg.val[1], b.reg.val[1]);
   acc.reg.val[2] = vfmaq_f32(acc.reg.val[2], a.reg.val[2], b.reg.val[2]);
@@ -531,8 +551,7 @@ inline void fma(FP32Vec16 &acc, FP32Vec16 &a, FP32Vec16 &b) {
 };
 
 #ifdef ARM_BF16_SUPPORT
-inline void fma(FP32Vec16 &acc, BF16Vec32 &a, BF16Vec32 &b) {
-
+inline void fma(FP32Vec16& acc, BF16Vec32& a, BF16Vec32& b) {
   float32x4_t a0_low = vcvt_f32_bf16(vget_low_bf16(a.reg.val[0]));
   float32x4_t a0_high = vcvt_f32_bf16(vget_high_bf16(a.reg.val[0]));
   float32x4_t a1_low = vcvt_f32_bf16(vget_low_bf16(a.reg.val[1]));
@@ -551,22 +570,22 @@ inline void fma(FP32Vec16 &acc, BF16Vec32 &a, BF16Vec32 &b) {
 #endif
 
 #ifdef ARM_BF16_SUPPORT
-inline BF16Vec8::BF16Vec8(const FP32Vec8 &v) : reg(vcvtq_high_bf16_f32(vcvtq_low_bf16_f32(v.reg.val[0]), v.reg.val[1])) {};
+inline BF16Vec8::BF16Vec8(const FP32Vec8& v)
+    : reg(vcvtq_high_bf16_f32(vcvtq_low_bf16_f32(v.reg.val[0]), v.reg.val[1])) {
+};
 
-inline BF16Vec16::BF16Vec16(const FP32Vec16 &v) : reg({
-  vcvtq_high_bf16_f32(vcvtq_low_bf16_f32(v.reg.val[0]), v.reg.val[1]),
-  vcvtq_high_bf16_f32(vcvtq_low_bf16_f32(v.reg.val[2]), v.reg.val[3])
-}){};
+inline BF16Vec16::BF16Vec16(const FP32Vec16& v)
+    : reg({vcvtq_high_bf16_f32(vcvtq_low_bf16_f32(v.reg.val[0]), v.reg.val[1]),
+           vcvtq_high_bf16_f32(vcvtq_low_bf16_f32(v.reg.val[2]),
+                               v.reg.val[3])}) {};
 #endif
 
-inline void prefetch(const void *addr) {
-  __builtin_prefetch(addr, 0, 1);
-};
+inline void prefetch(const void* addr) { __builtin_prefetch(addr, 0, 1); };
 
 #ifdef ARM_BF16_SUPPORT
 template <>
-inline void storeFP32<c10::BFloat16>(float v, c10::BFloat16 *ptr) {
-  *reinterpret_cast<__bf16 *>(ptr) = vcvth_bf16_f32(v);
+inline void storeFP32<c10::BFloat16>(float v, c10::BFloat16* ptr) {
+  *reinterpret_cast<__bf16*>(ptr) = vcvth_bf16_f32(v);
 };
 #endif
-};
+}; // namespace vec_op

@@ -17,30 +17,32 @@ namespace vec_op {
   AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__))
 
 #ifndef CPU_OP_GUARD
 #define CPU_KERNEL_GUARD_IN(NAME)
 #define CPU_KERNEL_GUARD_OUT(NAME)
 #else
 #define CPU_KERNEL_GUARD_IN(NAME) \
   std::cout << #NAME << " invoked." << std::endl;
-#define CPU_KERNEL_GUARD_OUT(NAME) std::cout << #NAME << " exit." << std::endl;
+#define CPU_KERNEL_GUARD_OUT(NAME) \
+  std::cout << #NAME << " exit." << std::endl;
 #endif
 
 #define FORCE_INLINE __attribute__((always_inline)) inline
 
 namespace {
 template <typename T, T... indexes, typename F>
-constexpr void unroll_loop_item(std::integer_sequence<T, indexes...>, F &&f) {
+constexpr void unroll_loop_item(std::integer_sequence<T, indexes...>, F&& f) {
   (f(std::integral_constant<T, indexes>{}), ...);
 }
 }; // namespace
 
 template <typename T, T count, typename F,
           typename = std::enable_if_t<std::is_invocable_v<F, T>>>
-constexpr void unroll_loop(F &&f) {
+constexpr void unroll_loop(F&& f) {
   unroll_loop_item(std::make_integer_sequence<T, count>{}, std::forward<F>(f));
 }
 
-template <typename T> struct Vec {
+template <typename T>
+struct Vec {
   constexpr static int get_elem_num() { return T::VEC_ELEM_NUM; }
 };
 
@@ -68,12 +70,14 @@ struct BF16Vec8 : public Vec<BF16Vec8> {
 
   __vector signed short reg;
 
-  explicit BF16Vec8(const void *ptr)
-      : reg((__vector signed short)vec_xl(0, (__vector signed short *)ptr)) {}
+  explicit BF16Vec8(const void* ptr)
+      : reg((__vector signed short)vec_xl(0, (__vector signed short*)ptr)) {}
 
-  explicit BF16Vec8(const FP32Vec8 &);
+  explicit BF16Vec8(const FP32Vec8&);
 
-  void save(void *ptr) const { *reinterpret_cast<__vector signed short *>(ptr) = reg; }
+  void save(void* ptr) const {
+    *reinterpret_cast<__vector signed short*>(ptr) = reg;
+  }
 };
 
 struct BF16Vec16 : public Vec<BF16Vec16> {
@@ -81,18 +85,18 @@ struct BF16Vec16 : public Vec<BF16Vec16> {
 
   ss16x8x2_t reg;
 
-  explicit BF16Vec16(const void *ptr) {
+  explicit BF16Vec16(const void* ptr) {
     // Load 256 bits in two parts
-    reg.val[0] = (__vector signed short)vec_xl(0, (signed short *)ptr);
-    reg.val[1] = (__vector signed short)vec_xl(16, (signed short *)ptr);
+    reg.val[0] = (__vector signed short)vec_xl(0, (signed short*)ptr);
+    reg.val[1] = (__vector signed short)vec_xl(16, (signed short*)ptr);
   }
 
-  explicit BF16Vec16(const FP32Vec16 &);
+  explicit BF16Vec16(const FP32Vec16&);
 
-  void save(void *ptr) const {
+  void save(void* ptr) const {
     // Save 256 bits in two parts
-    vec_xst(reg.val[0], 0, (signed short *)ptr);
-    vec_xst(reg.val[1], 16, (signed short *)ptr);
+    vec_xst(reg.val[0], 0, (signed short*)ptr);
+    vec_xst(reg.val[1], 16, (signed short*)ptr);
   }
 };
 
@@ -102,19 +106,15 @@ struct BF16Vec32 : public Vec<BF16Vec32> {
   constexpr static int VEC_ELEM_NUM = 32;
 
   ss16x8x4_t reg;
-  explicit BF16Vec32(const void *ptr)
-      : reg(*reinterpret_cast<const ss16x8x4_t *>(ptr)) {}
+  explicit BF16Vec32(const void* ptr)
+      : reg(*reinterpret_cast<const ss16x8x4_t*>(ptr)) {}
 
   explicit BF16Vec32(ss16x8x4_t data) : reg(data) {}
 
-  explicit BF16Vec32(const BF16Vec8 &vec8_data) : reg({
-    vec8_data.reg,
-    vec8_data.reg,
-    vec8_data.reg,
-    vec8_data.reg
-  }) {}
+  explicit BF16Vec32(const BF16Vec8& vec8_data)
+      : reg({vec8_data.reg, vec8_data.reg, vec8_data.reg, vec8_data.reg}) {}
 
-  void save(void *ptr) const { *reinterpret_cast<ss16x8x4_t *>(ptr) = reg; }
+  void save(void* ptr) const { *reinterpret_cast<ss16x8x4_t*>(ptr) = reg; }
 };
 
 struct FP32Vec4 : public Vec<FP32Vec4> {
@@ -130,11 +130,11 @@ struct FP32Vec4 : public Vec<FP32Vec4> {
 
   explicit FP32Vec4() : reg(vec_splats(0.0f)) {}
 
-  explicit FP32Vec4(const float *ptr) : reg(vec_xl(0, ptr)) {}
+  explicit FP32Vec4(const float* ptr) : reg(vec_xl(0, ptr)) {}
 
   explicit FP32Vec4(__vector float data) : reg(data) {}
 
-  explicit FP32Vec4(const FP32Vec4 &data) : reg(data.reg) {}
+  explicit FP32Vec4(const FP32Vec4& data) : reg(data.reg) {}
 };
 
 struct FP32Vec8 : public Vec<FP32Vec8> {
@@ -156,19 +156,19 @@ struct FP32Vec8 : public Vec<FP32Vec8> {
     reg.val[1] = vec_splats(0.0f);
   }
 
-  explicit FP32Vec8(const float *ptr) {
+  explicit FP32Vec8(const float* ptr) {
     reg.val[0] = vec_xl(0, ptr);
     reg.val[1] = vec_xl(16, ptr);
   }
 
   explicit FP32Vec8(f32x4x2_t data) : reg(data) {}
 
-  explicit FP32Vec8(const FP32Vec8 &data) {
+  explicit FP32Vec8(const FP32Vec8& data) {
     reg.val[0] = data.reg.val[0];
     reg.val[1] = data.reg.val[1];
   }
 
-  explicit FP32Vec8(const BF16Vec8 &v) {
+  explicit FP32Vec8(const BF16Vec8& v) {
     reg.val[0] = (__vector float)vec_mergeh(zero, v.reg);
     reg.val[1] = (__vector float)vec_mergel(zero, v.reg);
   }
@@ -177,7 +177,8 @@ struct FP32Vec8 : public Vec<FP32Vec8> {
     AliasReg ar;
     ar.reg = reg;
     float result = 0;
-    unroll_loop<int, VEC_ELEM_NUM>([&result, &ar](int i) { result += ar.values[i]; });
+    unroll_loop<int, VEC_ELEM_NUM>(
+        [&result, &ar](int i) { result += ar.values[i]; });
 
     return result;
   }
@@ -230,23 +231,27 @@ struct FP32Vec8 : public Vec<FP32Vec8> {
     return FP32Vec8(f32x4x2_t({ret.val[0], ret.val[1]}));
   }
 
-  FP32Vec8 operator*(const FP32Vec8 &b) const {
-    return FP32Vec8({vec_mul(reg.val[0], b.reg.val[0]), vec_mul(reg.val[1], b.reg.val[1])});
+  FP32Vec8 operator*(const FP32Vec8& b) const {
+    return FP32Vec8(
+        {vec_mul(reg.val[0], b.reg.val[0]), vec_mul(reg.val[1], b.reg.val[1])});
   }
 
-  FP32Vec8 operator+(const FP32Vec8 &b) const {
-    return FP32Vec8({vec_add(reg.val[0], b.reg.val[0]), vec_add(reg.val[1], b.reg.val[1])});
+  FP32Vec8 operator+(const FP32Vec8& b) const {
+    return FP32Vec8(
+        {vec_add(reg.val[0], b.reg.val[0]), vec_add(reg.val[1], b.reg.val[1])});
   }
 
-  FP32Vec8 operator-(const FP32Vec8 &b) const {
-    return FP32Vec8({vec_sub(reg.val[0], b.reg.val[0]), vec_sub(reg.val[1], b.reg.val[1])});
+  FP32Vec8 operator-(const FP32Vec8& b) const {
+    return FP32Vec8(
+        {vec_sub(reg.val[0], b.reg.val[0]), vec_sub(reg.val[1], b.reg.val[1])});
   }
 
-  FP32Vec8 operator/(const FP32Vec8 &b) const {
-    return FP32Vec8({vec_div(reg.val[0], b.reg.val[0]), vec_div(reg.val[1], b.reg.val[1])});
+  FP32Vec8 operator/(const FP32Vec8& b) const {
+    return FP32Vec8(
+        {vec_div(reg.val[0], b.reg.val[0]), vec_div(reg.val[1], b.reg.val[1])});
   }
 
-  void save(float *ptr) const {
+  void save(float* ptr) const {
     vec_xst(reg.val[0], 0, ptr);
     vec_xst(reg.val[1], 16, ptr);
   }
@ -275,7 +280,7 @@ struct FP32Vec16 : public Vec<FP32Vec16> {
|
|||||||
reg.val[3] = vec_splats(0.0f);
|
reg.val[3] = vec_splats(0.0f);
|
||||||
}
|
}
|
||||||
|
|
||||||
explicit FP32Vec16(const float *ptr) {
|
explicit FP32Vec16(const float* ptr) {
|
||||||
reg.val[0] = vec_xl(0, ptr);
|
reg.val[0] = vec_xl(0, ptr);
|
||||||
reg.val[1] = vec_xl(16, ptr);
|
reg.val[1] = vec_xl(16, ptr);
|
||||||
reg.val[2] = vec_xl(32, ptr);
|
reg.val[2] = vec_xl(32, ptr);
|
||||||
@ -284,63 +289,59 @@ struct FP32Vec16 : public Vec<FP32Vec16> {
|
|||||||
|
|
||||||
explicit FP32Vec16(f32x4x4_t data) : reg(data) {}
|
explicit FP32Vec16(f32x4x4_t data) : reg(data) {}
|
||||||
|
|
||||||
explicit FP32Vec16(const FP32Vec16 &data) {
|
explicit FP32Vec16(const FP32Vec16& data) {
|
||||||
reg.val[0] = data.reg.val[0];
|
reg.val[0] = data.reg.val[0];
|
||||||
reg.val[1] = data.reg.val[1];
|
reg.val[1] = data.reg.val[1];
|
||||||
reg.val[2] = data.reg.val[2];
|
reg.val[2] = data.reg.val[2];
|
||||||
reg.val[3] = data.reg.val[3];
|
reg.val[3] = data.reg.val[3];
|
||||||
}
|
}
|
||||||
|
|
||||||
explicit FP32Vec16(const FP32Vec4 &data) {
|
explicit FP32Vec16(const FP32Vec4& data) {
|
||||||
reg.val[0] = data.reg;
|
reg.val[0] = data.reg;
|
||||||
reg.val[1] = data.reg;
|
reg.val[1] = data.reg;
|
||||||
reg.val[2] = data.reg;
|
reg.val[2] = data.reg;
|
||||||
reg.val[3] = data.reg;
|
reg.val[3] = data.reg;
|
||||||
}
|
}
|
||||||
|
|
||||||
explicit FP32Vec16(const FP32Vec8 &data) {
|
explicit FP32Vec16(const FP32Vec8& data) {
|
||||||
reg.val[0] = data.reg.val[0];
|
reg.val[0] = data.reg.val[0];
|
||||||
reg.val[1] = data.reg.val[1];
|
reg.val[1] = data.reg.val[1];
|
||||||
reg.val[2] = data.reg.val[0];
|
reg.val[2] = data.reg.val[0];
|
||||||
reg.val[3] = data.reg.val[1];
|
reg.val[3] = data.reg.val[1];
|
||||||
}
|
}
|
||||||
|
|
||||||
explicit FP32Vec16(const BF16Vec16 &v) {
|
explicit FP32Vec16(const BF16Vec16& v) {
|
||||||
reg.val[0] = (__vector float)vec_mergeh(zero, v.reg.val[0]);
|
reg.val[0] = (__vector float)vec_mergeh(zero, v.reg.val[0]);
|
||||||
reg.val[1] = (__vector float)vec_mergel(zero, v.reg.val[0]);
|
reg.val[1] = (__vector float)vec_mergel(zero, v.reg.val[0]);
|
||||||
reg.val[2] = (__vector float)vec_mergeh(zero, v.reg.val[1]);
|
reg.val[2] = (__vector float)vec_mergeh(zero, v.reg.val[1]);
|
||||||
reg.val[3] = (__vector float)vec_mergel(zero, v.reg.val[1]);
|
reg.val[3] = (__vector float)vec_mergel(zero, v.reg.val[1]);
|
||||||
}
|
}
|
||||||
|
|
||||||
explicit FP32Vec16(const BF16Vec8 &v) : FP32Vec16(FP32Vec8(v)) {}
|
explicit FP32Vec16(const BF16Vec8& v) : FP32Vec16(FP32Vec8(v)) {}
|
||||||
|
|
||||||
FP32Vec16 operator*(const FP32Vec16 &b) const {
|
FP32Vec16 operator*(const FP32Vec16& b) const {
|
||||||
return FP32Vec16(f32x4x4_t({
|
return FP32Vec16(f32x4x4_t({vec_mul(reg.val[0], b.reg.val[0]),
|
||||||
vec_mul(reg.val[0], b.reg.val[0]),
|
|
||||||
vec_mul(reg.val[1], b.reg.val[1]),
|
vec_mul(reg.val[1], b.reg.val[1]),
|
||||||
vec_mul(reg.val[2], b.reg.val[2]),
|
vec_mul(reg.val[2], b.reg.val[2]),
|
||||||
vec_mul(reg.val[3], b.reg.val[3])}));
|
vec_mul(reg.val[3], b.reg.val[3])}));
|
||||||
}
|
}
|
||||||
|
|
||||||
FP32Vec16 operator+(const FP32Vec16 &b) const {
|
FP32Vec16 operator+(const FP32Vec16& b) const {
|
||||||
return FP32Vec16(f32x4x4_t({
|
return FP32Vec16(f32x4x4_t({vec_add(reg.val[0], b.reg.val[0]),
|
||||||
vec_add(reg.val[0], b.reg.val[0]),
|
|
||||||
vec_add(reg.val[1], b.reg.val[1]),
|
vec_add(reg.val[1], b.reg.val[1]),
|
||||||
vec_add(reg.val[2], b.reg.val[2]),
|
vec_add(reg.val[2], b.reg.val[2]),
|
||||||
vec_add(reg.val[3], b.reg.val[3])}));
|
vec_add(reg.val[3], b.reg.val[3])}));
|
||||||
}
|
}
|
||||||
|
|
||||||
FP32Vec16 operator-(const FP32Vec16 &b) const {
|
FP32Vec16 operator-(const FP32Vec16& b) const {
|
||||||
return FP32Vec16(f32x4x4_t({
|
return FP32Vec16(f32x4x4_t({vec_sub(reg.val[0], b.reg.val[0]),
|
||||||
vec_sub(reg.val[0], b.reg.val[0]),
|
|
||||||
vec_sub(reg.val[1], b.reg.val[1]),
|
vec_sub(reg.val[1], b.reg.val[1]),
|
||||||
vec_sub(reg.val[2], b.reg.val[2]),
|
vec_sub(reg.val[2], b.reg.val[2]),
|
||||||
vec_sub(reg.val[3], b.reg.val[3])}));
|
vec_sub(reg.val[3], b.reg.val[3])}));
|
||||||
}
|
}
|
||||||
|
|
||||||
FP32Vec16 operator/(const FP32Vec16 &b) const {
|
FP32Vec16 operator/(const FP32Vec16& b) const {
|
||||||
return FP32Vec16(f32x4x4_t({
|
return FP32Vec16(f32x4x4_t({vec_div(reg.val[0], b.reg.val[0]),
|
||||||
vec_div(reg.val[0], b.reg.val[0]),
|
|
||||||
vec_div(reg.val[1], b.reg.val[1]),
|
vec_div(reg.val[1], b.reg.val[1]),
|
||||||
vec_div(reg.val[2], b.reg.val[2]),
|
vec_div(reg.val[2], b.reg.val[2]),
|
||||||
vec_div(reg.val[3], b.reg.val[3])}));
|
vec_div(reg.val[3], b.reg.val[3])}));
|
||||||
@ -350,12 +351,14 @@ struct FP32Vec16 : public Vec<FP32Vec16> {
|
|||||||
AliasReg ar;
|
AliasReg ar;
|
||||||
ar.reg = reg;
|
ar.reg = reg;
|
||||||
float result = 0;
|
float result = 0;
|
||||||
unroll_loop<int, VEC_ELEM_NUM>([&result, &ar](int i) { result += ar.values[i]; });
|
unroll_loop<int, VEC_ELEM_NUM>(
|
||||||
|
[&result, &ar](int i) { result += ar.values[i]; });
|
||||||
|
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
template <int group_size> float reduce_sub_sum(int idx) {
|
template <int group_size>
|
||||||
|
float reduce_sub_sum(int idx) {
|
||||||
static_assert(VEC_ELEM_NUM % group_size == 0);
|
static_assert(VEC_ELEM_NUM % group_size == 0);
|
||||||
|
|
||||||
AliasReg ar;
|
AliasReg ar;
|
||||||
@ -368,7 +371,7 @@ struct FP32Vec16 : public Vec<FP32Vec16> {
|
|||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
void save(float *ptr) const {
|
void save(float* ptr) const {
|
||||||
vec_xst(reg.val[0], 0, ptr);
|
vec_xst(reg.val[0], 0, ptr);
|
||||||
vec_xst(reg.val[1], 16, ptr);
|
vec_xst(reg.val[1], 16, ptr);
|
||||||
vec_xst(reg.val[2], 32, ptr);
|
vec_xst(reg.val[2], 32, ptr);
|
||||||
@ -376,43 +379,62 @@ struct FP32Vec16 : public Vec<FP32Vec16> {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
template <typename T> struct VecType { using vec_type = void; };
|
template <typename T>
|
||||||
|
struct VecType {
|
||||||
|
using vec_type = void;
|
||||||
|
};
|
||||||
|
|
||||||
template <typename T> using vec_t = typename VecType<T>::vec_type;
|
template <typename T>
|
||||||
|
using vec_t = typename VecType<T>::vec_type;
|
||||||
|
|
||||||
template <> struct VecType<float> { using vec_type = FP32Vec8; };
|
template <>
|
||||||
|
struct VecType<float> {
|
||||||
|
using vec_type = FP32Vec8;
|
||||||
|
};
|
||||||
|
|
||||||
template <> struct VecType<c10::BFloat16> { using vec_type = BF16Vec8; };
|
template <>
|
||||||
|
struct VecType<c10::BFloat16> {
|
||||||
|
using vec_type = BF16Vec8;
|
||||||
|
};
|
||||||
|
|
||||||
template <typename T> void storeFP32(float v, T *ptr) { *ptr = v; }
|
template <typename T>
|
||||||
|
void storeFP32(float v, T* ptr) {
|
||||||
|
*ptr = v;
|
||||||
|
}
|
||||||
|
|
||||||
inline void fma(FP32Vec16 &acc, FP32Vec16 &a, FP32Vec16 &b) {
|
inline void fma(FP32Vec16& acc, FP32Vec16& a, FP32Vec16& b) {
|
||||||
acc = acc + a * b;
|
acc = acc + a * b;
|
||||||
}
|
}
|
||||||
|
|
||||||
template <> inline void storeFP32<c10::BFloat16>(float v, c10::BFloat16 *ptr) {
|
template <>
|
||||||
c10::BFloat16 __attribute__((__may_alias__)) *v_ptr =
|
inline void storeFP32<c10::BFloat16>(float v, c10::BFloat16* ptr) {
|
||||||
reinterpret_cast<c10::BFloat16 *>(&v);
|
c10::BFloat16 __attribute__((__may_alias__))* v_ptr =
|
||||||
|
reinterpret_cast<c10::BFloat16*>(&v);
|
||||||
*ptr = *(v_ptr + 1);
|
*ptr = *(v_ptr + 1);
|
||||||
}
|
}
|
||||||
|
|
||||||
#ifndef __VEC_CLASS_FP_NAN
|
#ifndef __VEC_CLASS_FP_NAN
|
||||||
#define __VEC_CLASS_FP_NAN (1 << 6)
|
#define __VEC_CLASS_FP_NAN (1 << 6)
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
const static __vector unsigned char omask = { 0, 1, 4, 5, 8, 9, 12, 13, 16, 17, 20, 21, 24, 25, 28, 29 };
|
const static __vector unsigned char omask = {0, 1, 4, 5, 8, 9, 12, 13,
|
||||||
|
16, 17, 20, 21, 24, 25, 28, 29};
|
||||||
#ifndef _ARCH_PWR10
|
#ifndef _ARCH_PWR10
|
||||||
const static __vector unsigned int bias = { 0x00007fff, 0x00007fff, 0x00007fff, 0x00007fff };
|
const static __vector unsigned int bias = {0x00007fff, 0x00007fff, 0x00007fff,
|
||||||
const static __vector unsigned int nan = { 0x7fc00000, 0x7fc00000, 0x7fc00000, 0x7fc00000 };
|
0x00007fff};
|
||||||
const static __vector unsigned int sh16 = { 16, 16, 16, 16 };
|
const static __vector unsigned int nan = {0x7fc00000, 0x7fc00000, 0x7fc00000,
|
||||||
const static __vector unsigned int one = { 1, 1, 1, 1 };
|
0x7fc00000};
|
||||||
|
const static __vector unsigned int sh16 = {16, 16, 16, 16};
|
||||||
|
const static __vector unsigned int one = {1, 1, 1, 1};
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
inline BF16Vec8::BF16Vec8(const FP32Vec8 &v) {
|
inline BF16Vec8::BF16Vec8(const FP32Vec8& v) {
|
||||||
#ifdef _ARCH_PWR10
|
#ifdef _ARCH_PWR10
|
||||||
__vector signed short ret[2];
|
__vector signed short ret[2];
|
||||||
ret[0] = (__vector signed short)__builtin_vsx_xvcvspbf16((__vector unsigned char)v.reg.val[0]);
|
ret[0] = (__vector signed short)__builtin_vsx_xvcvspbf16(
|
||||||
ret[1] = (__vector signed short)__builtin_vsx_xvcvspbf16((__vector unsigned char)v.reg.val[1]);
|
(__vector unsigned char)v.reg.val[0]);
|
||||||
|
ret[1] = (__vector signed short)__builtin_vsx_xvcvspbf16(
|
||||||
|
(__vector unsigned char)v.reg.val[1]);
|
||||||
reg = vec_perm(ret[0], ret[1], omask);
|
reg = vec_perm(ret[0], ret[1], omask);
|
||||||
#elif defined(_ARCH_PWR9)
|
#elif defined(_ARCH_PWR9)
|
||||||
__vector unsigned int inp0 = (__vector unsigned int)(v.reg.val[0]);
|
__vector unsigned int inp0 = (__vector unsigned int)(v.reg.val[0]);
|
||||||
@ -425,8 +447,10 @@ inline BF16Vec8::BF16Vec8(const FP32Vec8 &v) {
|
|||||||
__vector unsigned int rnd1 = vec_add(lsb1, bias);
|
__vector unsigned int rnd1 = vec_add(lsb1, bias);
|
||||||
inp0 = vec_add(inp0, rnd0);
|
inp0 = vec_add(inp0, rnd0);
|
||||||
inp1 = vec_add(inp1, rnd1);
|
inp1 = vec_add(inp1, rnd1);
|
||||||
__vector __bool int sel0 = vec_test_data_class(v.reg.val[0], __VEC_CLASS_FP_NAN);
|
__vector __bool int sel0 =
|
||||||
__vector __bool int sel1 = vec_test_data_class(v.reg.val[1], __VEC_CLASS_FP_NAN);
|
vec_test_data_class(v.reg.val[0], __VEC_CLASS_FP_NAN);
|
||||||
|
__vector __bool int sel1 =
|
||||||
|
vec_test_data_class(v.reg.val[1], __VEC_CLASS_FP_NAN);
|
||||||
inp0 = vec_sel(inp0, nan, sel0);
|
inp0 = vec_sel(inp0, nan, sel0);
|
||||||
inp1 = vec_sel(inp1, nan, sel1);
|
inp1 = vec_sel(inp1, nan, sel1);
|
||||||
inp0 = vec_sr(inp0, sh16);
|
inp0 = vec_sr(inp0, sh16);
|
||||||
@ -435,13 +459,17 @@ inline BF16Vec8::BF16Vec8(const FP32Vec8 &v) {
|
|||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
inline BF16Vec16::BF16Vec16(const FP32Vec16 &v) {
|
inline BF16Vec16::BF16Vec16(const FP32Vec16& v) {
|
||||||
#ifdef _ARCH_PWR10
|
#ifdef _ARCH_PWR10
|
||||||
__vector signed short ret[4];
|
__vector signed short ret[4];
|
||||||
ret[0] = (__vector signed short)__builtin_vsx_xvcvspbf16((__vector unsigned char)v.reg.val[0]);
|
ret[0] = (__vector signed short)__builtin_vsx_xvcvspbf16(
|
||||||
ret[1] = (__vector signed short)__builtin_vsx_xvcvspbf16((__vector unsigned char)v.reg.val[1]);
|
(__vector unsigned char)v.reg.val[0]);
|
||||||
ret[2] = (__vector signed short)__builtin_vsx_xvcvspbf16((__vector unsigned char)v.reg.val[2]);
|
ret[1] = (__vector signed short)__builtin_vsx_xvcvspbf16(
|
||||||
ret[3] = (__vector signed short)__builtin_vsx_xvcvspbf16((__vector unsigned char)v.reg.val[3]);
|
(__vector unsigned char)v.reg.val[1]);
|
||||||
|
ret[2] = (__vector signed short)__builtin_vsx_xvcvspbf16(
|
||||||
|
(__vector unsigned char)v.reg.val[2]);
|
||||||
|
ret[3] = (__vector signed short)__builtin_vsx_xvcvspbf16(
|
||||||
|
(__vector unsigned char)v.reg.val[3]);
|
||||||
reg.val[0] = vec_perm(ret[0], ret[1], omask);
|
reg.val[0] = vec_perm(ret[0], ret[1], omask);
|
||||||
reg.val[1] = vec_perm(ret[2], ret[3], omask);
|
reg.val[1] = vec_perm(ret[2], ret[3], omask);
|
||||||
#elif defined(_ARCH_PWR9)
|
#elif defined(_ARCH_PWR9)
|
||||||
@ -465,10 +493,14 @@ inline BF16Vec16::BF16Vec16(const FP32Vec16 &v) {
|
|||||||
inp1 = vec_add(inp1, rnd1);
|
inp1 = vec_add(inp1, rnd1);
|
||||||
inp2 = vec_add(inp2, rnd2);
|
inp2 = vec_add(inp2, rnd2);
|
||||||
inp3 = vec_add(inp3, rnd3);
|
inp3 = vec_add(inp3, rnd3);
|
||||||
__vector __bool int sel0 = vec_test_data_class(v.reg.val[0], __VEC_CLASS_FP_NAN);
|
__vector __bool int sel0 =
|
||||||
__vector __bool int sel1 = vec_test_data_class(v.reg.val[1], __VEC_CLASS_FP_NAN);
|
vec_test_data_class(v.reg.val[0], __VEC_CLASS_FP_NAN);
|
||||||
__vector __bool int sel2 = vec_test_data_class(v.reg.val[2], __VEC_CLASS_FP_NAN);
|
__vector __bool int sel1 =
|
||||||
__vector __bool int sel3 = vec_test_data_class(v.reg.val[3], __VEC_CLASS_FP_NAN);
|
vec_test_data_class(v.reg.val[1], __VEC_CLASS_FP_NAN);
|
||||||
|
__vector __bool int sel2 =
|
||||||
|
vec_test_data_class(v.reg.val[2], __VEC_CLASS_FP_NAN);
|
||||||
|
__vector __bool int sel3 =
|
||||||
|
vec_test_data_class(v.reg.val[3], __VEC_CLASS_FP_NAN);
|
||||||
inp0 = vec_sel(inp0, nan, sel0);
|
inp0 = vec_sel(inp0, nan, sel0);
|
||||||
inp1 = vec_sel(inp1, nan, sel1);
|
inp1 = vec_sel(inp1, nan, sel1);
|
||||||
inp2 = vec_sel(inp2, nan, sel2);
|
inp2 = vec_sel(inp2, nan, sel2);
|
||||||
@ -482,7 +514,7 @@ inline BF16Vec16::BF16Vec16(const FP32Vec16 &v) {
|
|||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
inline void prefetch(const void *addr) {
|
inline void prefetch(const void* addr) {
|
||||||
__asm__ __volatile__("dcbt 0, %0" : : "r"(addr) : "memory");
|
__asm__ __volatile__("dcbt 0, %0" : : "r"(addr) : "memory");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@@ -20,30 +20,31 @@ namespace vec_op {
constexpr void unroll_loop_item(std::integer_sequence<T, indexes...>, F&& f) {
constexpr void unroll_loop(F&& f) {
template <typename T>
struct Vec {
@@ -55,12 +56,12 @@ struct FP16Vec8 : public Vec<FP16Vec8> {
explicit FP16Vec8(const void* ptr)
    : reg((__m128i)_mm_loadu_si128((__m128i*)ptr)) {}
explicit FP16Vec8(const FP32Vec8&);
void save(void* ptr) const { *reinterpret_cast<__m128i*>(ptr) = reg; }
@@ -68,12 +69,12 @@ struct FP16Vec16 : public Vec<FP16Vec16> {
explicit FP16Vec16(const void* ptr)
    : reg((__m256i)_mm256_loadu_si256((__m256i*)ptr)) {}
explicit FP16Vec16(const FP32Vec16&);
void save(void* ptr) const { *reinterpret_cast<__m256i*>(ptr) = reg; }
@@ -87,12 +88,12 @@ struct BF16Vec8 : public Vec<BF16Vec8> {
explicit BF16Vec8(const void* ptr)
    : reg((__m128i)_mm_loadu_si128((__m128i*)ptr)) {}
explicit BF16Vec8(const FP32Vec8&);
void save(void* ptr) const { *reinterpret_cast<__m128i*>(ptr) = reg; }
@@ -100,12 +101,12 @@ struct BF16Vec16 : public Vec<BF16Vec16> {
explicit BF16Vec16(const void* ptr)
    : reg((__m256i)_mm256_loadu_si256((__m256i*)ptr)) {}
explicit BF16Vec16(const FP32Vec16&);
void save(void* ptr) const { *reinterpret_cast<__m256i*>(ptr) = reg; }
@@ -120,11 +121,11 @@ struct BF16Vec32 : public Vec<BF16Vec32> {
explicit BF16Vec32(const void* ptr) : reg((__m512i)_mm512_loadu_si512(ptr)) {}
explicit BF16Vec32(BF16Vec8& vec8_data)
@@ -132,7 +133,7 @@ struct BF16Vec32 : public Vec<BF16Vec32> {
void save(void* ptr) const { *reinterpret_cast<__m512i*>(ptr) = reg; }
@@ -141,14 +142,14 @@ struct BF16Vec32 : public Vec<BF16Vec32> {
explicit BF16Vec32(const void* ptr)
    : reg_low(_mm256_loadu_si256((__m256i const*)ptr)),
      reg_high(_mm256_loadu_si256((__m256i const*)ptr + 1)) {}
explicit BF16Vec32(__m256i low, __m256i high)
    : reg_low(low), reg_high(high) {}
explicit BF16Vec32(BF16Vec8& vec8_data)
@@ -156,9 +157,9 @@ struct BF16Vec32 : public Vec<BF16Vec32> {
void save(void* ptr) const {
  *reinterpret_cast<__m256i*>(ptr) = reg_low;
  *reinterpret_cast<__m256i*>((__m256i*)ptr + 1) = reg_high;
@@ -176,11 +177,11 @@ struct FP32Vec4 : public Vec<FP32Vec4> {
explicit FP32Vec4(const float* ptr) : reg(_mm_loadu_ps(ptr)) {}
explicit FP32Vec4(const FP32Vec4& data) : reg(data.reg) {}
@@ -196,15 +197,15 @@ struct FP32Vec8 : public Vec<FP32Vec8> {
explicit FP32Vec8(const float* ptr) : reg(_mm256_loadu_ps(ptr)) {}
explicit FP32Vec8(const FP32Vec8& data) : reg(data.reg) {}
explicit FP32Vec8(const FP16Vec8& v) : reg(_mm256_cvtph_ps(v.reg)) {}
explicit FP32Vec8(const BF16Vec8& v)
@@ -212,7 +213,8 @@ struct FP32Vec8 : public Vec<FP32Vec8> {
unroll_loop<int, VEC_ELEM_NUM>(
    [&result, &ar](int i) { result += ar.values[i]; });
@@ -244,27 +246,27 @@ struct FP32Vec8 : public Vec<FP32Vec8> {
FP32Vec8 operator*(const FP32Vec8& b) const {
FP32Vec8 operator+(const FP32Vec8& b) const {
FP32Vec8 operator-(const FP32Vec8& b) const {
FP32Vec8 operator/(const FP32Vec8& b) const {
void save(float* ptr) const { _mm256_storeu_ps(ptr, reg); }
struct INT32Vec16 : public Vec<INT32Vec16> {
@@ -273,11 +275,10 @@ struct INT32Vec16: public Vec<INT32Vec16> {
explicit INT32Vec16(const void* data_ptr)
    : reg(_mm512_loadu_epi32(data_ptr)) {}
void save(int32_t* ptr) const { _mm512_storeu_epi32(ptr, reg); }
@@ -301,11 +302,11 @@ struct FP32Vec16 : public Vec<FP32Vec16> {
explicit FP32Vec16(const float* ptr) : reg(_mm512_loadu_ps(ptr)) {}
explicit FP32Vec16(const FP32Vec4& data)
@@ -313,36 +314,37 @@ struct FP32Vec16 : public Vec<FP32Vec16> {
explicit FP32Vec16(const FP32Vec8& data)
explicit FP32Vec16(const BF16Vec16& v)
explicit FP32Vec16(const FP16Vec16& v) : reg(_mm512_cvtph_ps(v.reg)) {}
explicit FP32Vec16(const FP16Vec8& v) : FP32Vec16(FP32Vec8(v)) {}
explicit FP32Vec16(const BF16Vec8& v) : FP32Vec16(FP32Vec8(v)) {}
explicit FP32Vec16(const INT32Vec16& v)
    : reg(_mm512_cvt_roundepi32_ps(
          v.reg, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)) {}
FP32Vec16 operator*(const FP32Vec16& b) const {
FP32Vec16 operator+(const FP32Vec16& b) const {
FP32Vec16 operator-(const FP32Vec16& b) const {
FP32Vec16 operator/(const FP32Vec16& b) const {
@@ -370,9 +372,7 @@ struct FP32Vec16 : public Vec<FP32Vec16> {
FP32Vec16 abs() const { return FP32Vec16(_mm512_abs_ps(reg)); }
@@ -380,14 +380,15 @@ struct FP32Vec16 : public Vec<FP32Vec16> {
template <int group_size>
float reduce_sub_sum(int idx) {
void save(float* ptr) const { _mm512_storeu_ps(ptr, reg); }
@@ -407,32 +408,30 @@ struct FP32Vec16 : public Vec<FP32Vec16> {
explicit FP32Vec16(float v)
    : reg_low(_mm256_set1_ps(v)), reg_high(_mm256_set1_ps(v)) {}
explicit FP32Vec16()
    : reg_low(_mm256_set1_ps(0.0)), reg_high(_mm256_set1_ps(0.0)) {}
explicit FP32Vec16(const float* ptr)
    : reg_low(_mm256_loadu_ps(ptr)), reg_high(_mm256_loadu_ps(ptr + 8)) {}
explicit FP32Vec16(const FP32Vec16& data)
    : reg_low(data.reg_low), reg_high(data.reg_high) {}
explicit FP32Vec16(const FP32Vec4& data)
    _mm256_castsi128_si256((__m128i)data.reg), (__m128i)data.reg, 1)),
    _mm256_castsi128_si256((__m128i)data.reg), (__m128i)data.reg, 1)) {}
explicit FP32Vec16(const FP32Vec8& data)
explicit FP32Vec16(const FP16Vec16& v) {
@@ -440,9 +439,9 @@ struct FP32Vec16 : public Vec<FP32Vec16> {
explicit FP32Vec16(const FP16Vec8& v) : FP32Vec16(FP32Vec8(v)) {}
explicit FP32Vec16(const BF16Vec16& v) {
@@ -456,24 +455,24 @@ struct FP32Vec16 : public Vec<FP32Vec16> {
explicit FP32Vec16(const BF16Vec8& v) : FP32Vec16(FP32Vec8(v)) {}
FP32Vec16 operator*(const FP32Vec16& b) const {
FP32Vec16 operator+(const FP32Vec16& b) const {
FP32Vec16 operator-(const FP32Vec16& b) const {
FP32Vec16 operator/(const FP32Vec16& b) const {
@@ -484,7 +483,8 @@ struct FP32Vec16 : public Vec<FP32Vec16> {
template <int group_size>
float reduce_sub_sum(int idx) {
@@ -507,7 +507,7 @@ struct FP32Vec16 : public Vec<FP32Vec16> {
void save(float* ptr) const {
@@ -515,7 +515,7 @@ struct FP32Vec16 : public Vec<FP32Vec16> {
struct INT8Vec16 : public Vec<INT8Vec16> {
@@ -524,13 +524,11 @@ struct INT8Vec16: public Vec<INT8Vec16> {
explicit INT8Vec16(const FP32Vec16& vec)
    : reg(_mm512_cvtepi32_epi8(_mm512_cvt_roundps_epi32(
          vec.reg, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC))) {}
void save(int8_t* ptr) const { _mm_storeu_epi8(ptr, reg); }
@@ -540,71 +538,92 @@ struct INT8Vec16: public Vec<INT8Vec16> {
template <typename T>
struct VecType {
  using vec_type = void;
};
template <typename T>
using vec_t = typename VecType<T>::vec_type;
template <>
struct VecType<float> {
  using vec_type = FP32Vec8;
};
template <>
struct VecType<c10::Half> {
  using vec_type = FP16Vec8;
};
template <>
struct VecType<c10::BFloat16> {
  using vec_type = BF16Vec8;
};
template <typename T>
void storeFP32(float v, T* ptr) {
  *ptr = v;
}
inline void fma(FP32Vec16& acc, FP32Vec16& a, FP32Vec16& b) {
template <>
inline void storeFP32<c10::Half>(float v, c10::Half* ptr) {
  *reinterpret_cast<unsigned short*>(ptr) =
inline FP16Vec8::FP16Vec8(const FP32Vec8& v)
inline FP16Vec16::FP16Vec16(const FP32Vec16& v)
inline FP16Vec16::FP16Vec16(const FP32Vec16& v)
    : reg(_mm256_insertf128_si256(
          _mm256_castsi128_si256(FP16Vec8(FP32Vec8(v.reg_low)).reg),
          FP16Vec8(FP32Vec8(v.reg_low)).reg, 1)) {}
template <>
inline void storeFP32<c10::BFloat16>(float v, c10::BFloat16* ptr) {
  *reinterpret_cast<__bfloat16*>(ptr) = _mm_cvtness_sbh(v);
inline BF16Vec8::BF16Vec8(const FP32Vec8& v)
inline BF16Vec16::BF16Vec16(const FP32Vec16& v)
inline void fma(FP32Vec16& acc, BF16Vec32& a, BF16Vec32& b) {
template <>
inline void storeFP32<c10::BFloat16>(float v, c10::BFloat16* ptr) {
  c10::BFloat16 __attribute__((__may_alias__))* v_ptr =
      reinterpret_cast<c10::BFloat16*>(&v);
inline BF16Vec8::BF16Vec8(const FP32Vec8& v)
inline BF16Vec16::BF16Vec16(const FP32Vec16& v)
namespace {
@@ -612,20 +631,20 @@ __m128i FP32Vec8_to_BF16Vec8_avx2(__m256 a) {
}  // namespace
inline BF16Vec8::BF16Vec8(const FP32Vec8& v)
inline BF16Vec16::BF16Vec16(const FP32Vec16& v) {
inline void prefetch(const void* addr) { _mm_prefetch(addr, _MM_HINT_T1); }
@@ -27,8 +27,7 @@
inline int get_cuda_max_shared_memory_per_block_opt_in(int const device) {
  int max_shared_mem_per_block_opt_in = 0;
  cudaDeviceGetAttribute(&max_shared_mem_per_block_opt_in,
                         cudaDevAttrMaxSharedMemoryPerBlockOptin, device);
  return max_shared_mem_per_block_opt_in;
}
@@ -25,10 +25,12 @@ Check out the [building from source](#build-from-source) documentation for details.
```bash
pip install -r requirements-dev.txt

# Linting, formatting and static type checking
pre-commit install

# You can manually run pre-commit with
pre-commit run --all-files

# Unit tests
pytest tests/
```
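In day-to-day use the hooks mostly run on their own once installed; the sketch below shows a few common invocations. The file path and the `ruff` hook id are only illustrative — use whatever hook ids the repository's `.pre-commit-config.yaml` actually defines.

```bash
# Once installed, the hooks run automatically on `git commit`
# and only check the files staged for that commit.
git add vllm/your_changed_file.py    # illustrative path
git commit -m "Fix foo"              # hooks run here; the commit is blocked if a hook fails

# Run the hooks against specific files without committing
pre-commit run --files vllm/your_changed_file.py

# Run a single hook by id across the whole tree
# ("ruff" is only an example id; see .pre-commit-config.yaml for the real ones)
pre-commit run ruff --all-files
```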
@@ -88,7 +90,8 @@ If the PR spans more than one category, please include all relevant prefixes.
The PR needs to meet the following code quality standards:

- We adhere to [Google Python style guide](https://google.github.io/styleguide/pyguide.html) and [Google C++ style guide](https://google.github.io/styleguide/cppguide.html).
- Pass all linter checks. Please use `pre-commit` to format your code. See
  <https://pre-commit.com/#usage> if `pre-commit` is new to you.
- The code needs to be well-documented to ensure future contributors can easily
  understand the code.
- Include sufficient tests to ensure the project stays correct and robust. This
321 format.sh
@@ -1,321 +0,0 @@
#!/usr/bin/env bash
|
|
||||||
# YAPF formatter, adapted from ray and skypilot.
|
|
||||||
#
|
|
||||||
# Usage:
|
|
||||||
# # Do work and commit your work.
|
|
||||||
|
|
||||||
# # Format files that differ from origin/main.
|
|
||||||
# bash format.sh
|
|
||||||
|
|
||||||
# # Commit changed files with message 'Run yapf and ruff'
|
|
||||||
#
|
|
||||||
#
|
|
||||||
# YAPF + Clang formatter (if installed). This script formats all changed files from the last mergebase.
|
|
||||||
# You are encouraged to run this locally before pushing changes for review.
|
|
||||||
|
|
||||||
# Cause the script to exit if a single command fails
|
|
||||||
set -eo pipefail
|
|
||||||
|
|
||||||
# this stops git rev-parse from failing if we run this from the .git directory
|
|
||||||
builtin cd "$(dirname "${BASH_SOURCE:-$0}")"
|
|
||||||
ROOT="$(git rev-parse --show-toplevel)"
|
|
||||||
builtin cd "$ROOT" || exit 1
|
|
||||||
|
|
||||||
check_command() {
|
|
||||||
if ! command -v "$1" &> /dev/null; then
|
|
||||||
echo "❓❓$1 is not installed, please run \`pip install -r requirements-lint.txt\`"
|
|
||||||
exit 1
|
|
||||||
fi
|
|
||||||
}
|
|
||||||
|
|
||||||
check_command yapf
|
|
||||||
check_command ruff
|
|
||||||
check_command mypy
|
|
||||||
check_command codespell
|
|
||||||
check_command isort
|
|
||||||
check_command clang-format
|
|
||||||
|
|
||||||
YAPF_VERSION=$(yapf --version | awk '{print $2}')
|
|
||||||
RUFF_VERSION=$(ruff --version | awk '{print $2}')
|
|
||||||
MYPY_VERSION=$(mypy --version | awk '{print $2}')
|
|
||||||
CODESPELL_VERSION=$(codespell --version)
|
|
||||||
ISORT_VERSION=$(isort --vn)
|
|
||||||
CLANGFORMAT_VERSION=$(clang-format --version | awk '{print $3}')
|
|
||||||
PYMARKDOWNLNT_VERSION=$(pymarkdownlnt version | awk '{print $1}')
|
|
||||||
|
|
||||||
# # params: tool name, tool version, required version
|
|
||||||
tool_version_check() {
|
|
||||||
expected=$(grep "$1" requirements-lint.txt | cut -d'=' -f3)
|
|
||||||
if [[ "$2" != "$expected" ]]; then
|
|
||||||
echo "❓❓Wrong $1 version installed: $expected is required, not $2."
|
|
||||||
exit 1
|
|
||||||
fi
|
|
||||||
}
|
|
||||||
|
|
||||||
tool_version_check "yapf" "$YAPF_VERSION"
|
|
||||||
tool_version_check "ruff" "$RUFF_VERSION"
|
|
||||||
tool_version_check "mypy" "$MYPY_VERSION"
|
|
||||||
tool_version_check "isort" "$ISORT_VERSION"
|
|
||||||
tool_version_check "codespell" "$CODESPELL_VERSION"
|
|
||||||
tool_version_check "clang-format" "$CLANGFORMAT_VERSION"
|
|
||||||
tool_version_check "pymarkdownlnt" "$PYMARKDOWNLNT_VERSION"
|
|
||||||
|
|
||||||
YAPF_FLAGS=(
|
|
||||||
'--recursive'
|
|
||||||
'--parallel'
|
|
||||||
)
|
|
||||||
|
|
||||||
YAPF_EXCLUDES=(
|
|
||||||
'--exclude' 'build/**'
|
|
||||||
)
|
|
||||||
|
|
||||||
# Format specified files
|
|
||||||
format() {
|
|
||||||
yapf --in-place "${YAPF_FLAGS[@]}" "$@"
|
|
||||||
}
|
|
||||||
|
|
||||||
# Format files that differ from main branch. Ignores dirs that are not slated
|
|
||||||
# for autoformat yet.
|
|
||||||
format_changed() {
|
|
||||||
# The `if` guard ensures that the list of filenames is not empty, which
|
|
||||||
# could cause yapf to receive 0 positional arguments, making it hang
|
|
||||||
# waiting for STDIN.
|
|
||||||
#
|
|
||||||
# `diff-filter=ACM` and $MERGEBASE is to ensure we only format files that
|
|
||||||
# exist on both branches.
|
|
||||||
MERGEBASE="$(git merge-base origin/main HEAD)"
|
|
||||||
|
|
||||||
if ! git diff --diff-filter=ACM --quiet --exit-code "$MERGEBASE" -- '*.py' '*.pyi' &>/dev/null; then
|
|
||||||
git diff --name-only --diff-filter=ACM "$MERGEBASE" -- '*.py' '*.pyi' | xargs -P 5 \
|
|
||||||
yapf --in-place "${YAPF_EXCLUDES[@]}" "${YAPF_FLAGS[@]}"
|
|
||||||
fi
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
# Format all files
|
|
||||||
format_all() {
|
|
||||||
yapf --in-place "${YAPF_FLAGS[@]}" "${YAPF_EXCLUDES[@]}" .
|
|
||||||
}
|
|
||||||
|
|
||||||
## This flag formats individual files. --files *must* be the first command line
|
|
||||||
## arg to use this option.
|
|
||||||
if [[ "$1" == '--files' ]]; then
|
|
||||||
format "${@:2}"
|
|
||||||
# If `--all` is passed, then any further arguments are ignored and the
|
|
||||||
# entire python directory is formatted.
|
|
||||||
elif [[ "$1" == '--all' ]]; then
|
|
||||||
format_all
|
|
||||||
else
|
|
||||||
# Format only the files that changed in last commit.
|
|
||||||
format_changed
|
|
||||||
fi
|
|
||||||
echo 'vLLM yapf: Done'
|
|
||||||
|
|
||||||
# Run mypy
|
|
||||||
echo 'vLLM mypy:'
|
|
||||||
tools/mypy.sh
|
|
||||||
echo 'vLLM mypy: Done'
|
|
||||||
|
|
||||||
|
|
||||||
# If git diff returns a file that is in the skip list, the file may be checked anyway:
|
|
||||||
# https://github.com/codespell-project/codespell/issues/1915
|
|
||||||
# Avoiding the "./" prefix and using "/**" globs for directories appears to solve the problem
|
|
||||||
CODESPELL_EXCLUDES=(
|
|
||||||
'--skip' 'tests/prompts/**,./benchmarks/sonnet.txt,*tests/lora/data/**,build/**'
|
|
||||||
)
|
|
||||||
|
|
||||||
# check spelling of specified files
|
|
||||||
spell_check() {
|
|
||||||
codespell "$@"
|
|
||||||
}
|
|
||||||
|
|
||||||
spell_check_all(){
|
|
||||||
codespell --toml pyproject.toml "${CODESPELL_EXCLUDES[@]}"
|
|
||||||
}
|
|
||||||
|
|
||||||
# Spelling check of files that differ from main branch.
|
|
||||||
spell_check_changed() {
|
|
||||||
# The `if` guard ensures that the list of filenames is not empty, which
|
|
||||||
# could cause ruff to receive 0 positional arguments, making it hang
|
|
||||||
# waiting for STDIN.
|
|
||||||
#
|
|
||||||
# `diff-filter=ACM` and $MERGEBASE is to ensure we only lint files that
|
|
||||||
# exist on both branches.
|
|
||||||
MERGEBASE="$(git merge-base origin/main HEAD)"
|
|
||||||
if ! git diff --diff-filter=ACM --quiet --exit-code "$MERGEBASE" -- '*.py' '*.pyi' &>/dev/null; then
|
|
||||||
git diff --name-only --diff-filter=ACM "$MERGEBASE" -- '*.py' '*.pyi' | xargs \
|
|
||||||
codespell "${CODESPELL_EXCLUDES[@]}"
|
|
||||||
fi
|
|
||||||
}
|
|
||||||
|
|
||||||
# Run Codespell
|
|
||||||
## This flag runs spell check of individual files. --files *must* be the first command line
|
|
||||||
## arg to use this option.
|
|
||||||
if [[ "$1" == '--files' ]]; then
|
|
||||||
spell_check "${@:2}"
|
|
||||||
# If `--all` is passed, then any further arguments are ignored and the
|
|
||||||
# entire python directory is linted.
|
|
||||||
elif [[ "$1" == '--all' ]]; then
|
|
||||||
spell_check_all
|
|
||||||
else
|
|
||||||
# Check spelling only of the files that changed in last commit.
|
|
||||||
spell_check_changed
|
|
||||||
fi
|
|
||||||
echo 'vLLM codespell: Done'
|
|
||||||
|
|
||||||
|
|
||||||
# Lint specified files
|
|
||||||
lint() {
|
|
||||||
ruff check "$@"
|
|
||||||
}
|
|
||||||
|
|
||||||
# Lint files that differ from main branch. Ignores dirs that are not slated
|
|
||||||
# for autolint yet.
|
|
||||||
lint_changed() {
|
|
||||||
# The `if` guard ensures that the list of filenames is not empty, which
|
|
||||||
# could cause ruff to receive 0 positional arguments, making it hang
|
|
||||||
# waiting for STDIN.
|
|
||||||
#
|
|
||||||
# `diff-filter=ACM` and $MERGEBASE is to ensure we only lint files that
|
|
||||||
# exist on both branches.
|
|
||||||
MERGEBASE="$(git merge-base origin/main HEAD)"
|
|
||||||
|
|
||||||
if ! git diff --diff-filter=ACM --quiet --exit-code "$MERGEBASE" -- '*.py' '*.pyi' &>/dev/null; then
|
|
||||||
git diff --name-only --diff-filter=ACM "$MERGEBASE" -- '*.py' '*.pyi' | xargs \
|
|
||||||
ruff check
|
|
||||||
fi
|
|
||||||
|
|
||||||
}

# Run Ruff
### This flag lints individual files. --files *must* be the first command line
### arg to use this option.
if [[ "$1" == '--files' ]]; then
    lint "${@:2}"
    # If `--all` is passed, then any further arguments are ignored and
    # the `vllm` and `tests` directories are linted in full.
elif [[ "$1" == '--all' ]]; then
    lint vllm tests
else
    # Lint only the files that differ from the main branch.
    lint_changed
fi
echo 'vLLM ruff: Done'

# Sort imports in the specified files
isort_check() {
    isort "$@"
}

isort_check_all() {
    isort .
}

# Sort imports in files that differ from the main branch.
isort_check_changed() {
    # The `if` guard ensures that the list of filenames is not empty, which
    # could cause isort to receive 0 positional arguments, making it hang
    # waiting for STDIN.
    #
    # `diff-filter=ACM` and $MERGEBASE are used to ensure we only sort files that
    # exist on both branches.
    MERGEBASE="$(git merge-base origin/main HEAD)"

    if ! git diff --diff-filter=ACM --quiet --exit-code "$MERGEBASE" -- '*.py' '*.pyi' &>/dev/null; then
        git diff --name-only --diff-filter=ACM "$MERGEBASE" -- '*.py' '*.pyi' | xargs \
            isort
    fi
}

# Run Isort
# This flag sorts imports in individual files. --files *must* be the first command line
# arg to use this option.
if [[ "$1" == '--files' ]]; then
    isort_check "${@:2}"
    # If `--all` is passed, then any further arguments are ignored and
    # isort runs over the entire repository.
elif [[ "$1" == '--all' ]]; then
    isort_check_all
else
    # Sort imports only in files that differ from the main branch.
    isort_check_changed
fi
echo 'vLLM isort: Done'

# Clang-format section
# Exclude some files for formatting because they are vendored
# NOTE: Keep up to date with .github/workflows/clang-format.yml
CLANG_FORMAT_EXCLUDES=(
    'csrc/moe/topk_softmax_kernels.cu'
    'csrc/quantization/gguf/ggml-common.h'
    'csrc/quantization/gguf/dequantize.cuh'
    'csrc/quantization/gguf/vecdotq.cuh'
    'csrc/quantization/gguf/mmq.cuh'
    'csrc/quantization/gguf/mmvq.cuh'
)

# Format specified files with clang-format
clang_format() {
    clang-format -i "$@"
}

# Format files that differ from main branch with clang-format.
clang_format_changed() {
    # The `if` guard ensures that the list of filenames is not empty, which
    # could cause clang-format to receive 0 positional arguments, making it hang
    # waiting for STDIN.
    #
    # `diff-filter=ACM` and $MERGEBASE are used to ensure we only format files that
    # exist on both branches.
    MERGEBASE="$(git merge-base origin/main HEAD)"

    # Get the list of changed files, excluding the specified ones
    changed_files=$(git diff --name-only --diff-filter=ACM "$MERGEBASE" -- '*.h' '*.cpp' '*.cu' '*.cuh' | (grep -vFf <(printf "%s\n" "${CLANG_FORMAT_EXCLUDES[@]}") || echo -e))
    if [ -n "$changed_files" ]; then
        echo "$changed_files" | xargs -P 5 clang-format -i
    fi
}
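The exclusion plumbing above is dense, so here is a standalone sketch of the same `grep -vFf` filter with hypothetical file names; the `|| echo -e` fallback in the function presumably just prevents grep's non-zero exit status (when every file is excluded) from aborting the script.

# Hypothetical changed-file list piped through the same fixed-string exclusion filter.
printf '%s\n' 'csrc/attention/attention_kernels.cu' 'csrc/quantization/gguf/mmq.cuh' \
    | grep -vFf <(printf '%s\n' "${CLANG_FORMAT_EXCLUDES[@]}")
# Only csrc/attention/attention_kernels.cu is printed; the vendored gguf file is dropped.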

# Format all files with clang-format
clang_format_all() {
    find csrc/ \( -name '*.h' -o -name '*.cpp' -o -name '*.cu' -o -name '*.cuh' \) -print \
        | grep -vFf <(printf "%s\n" "${CLANG_FORMAT_EXCLUDES[@]}") \
        | xargs clang-format -i
}

# Run clang-format
if [[ "$1" == '--files' ]]; then
    clang_format "${@:2}"
elif [[ "$1" == '--all' ]]; then
    clang_format_all
else
    clang_format_changed
fi
echo 'vLLM clang-format: Done'

echo 'vLLM actionlint:'
tools/actionlint.sh -color
echo 'vLLM actionlint: Done'

echo 'vLLM shellcheck:'
tools/shellcheck.sh
echo 'vLLM shellcheck: Done'

echo 'excalidraw png check:'
tools/png-lint.sh
echo 'excalidraw png check: Done'

if ! git diff --quiet &>/dev/null; then
    echo
    echo "🔍🔍There are files changed by the format checker or by you that are not added and committed:"
    git --no-pager diff --name-only
    echo "🔍🔍Format checker passed, but please add, commit and push all the files above to include changes made by the format checker."

    exit 1
else
    echo "✨🎉 Format check passed! Congratulations! 🎉✨"
fi
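For clarity, the guard above keys off the exit status of `git diff --quiet`; a minimal sketch of the same check in isolation:

# `git diff --quiet` exits 0 on a clean working tree and non-zero when any tracked
# file has unstaged modifications, e.g. after one of the formatters rewrote it.
if ! git diff --quiet; then
    echo "unstaged changes present"
fi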

echo 'vLLM doc-lint:'
tools/doc-lint.sh
echo 'vLLM doc-lint: Done'

@ -15,6 +15,11 @@ build-backend = "setuptools.build_meta"
 [tool.setuptools_scm]
 # version_file = "vllm/_version.py" # currently handled by `setup.py:get_version()`
 
+[tool.yapfignore]
+ignore_patterns = [
+    "build/**",
+]
+
 [tool.ruff]
 # Allow lines to be as long as 80.
 line-length = 80
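As I understand it, yapf reads this new [tool.yapfignore] table from pyproject.toml directly, so a recursive run would skip anything under build/. A minimal sketch (yapf itself is no longer pinned in the lint requirements below, so the version comes from whatever environment or hook runs it):

# yapf honours [tool.yapfignore].ignore_patterns, so files under build/ are skipped.
yapf --in-place --recursive .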

@ -52,6 +57,9 @@ ignore = [
     "B007",
     # f-string format
     "UP032",
+    # Python 3.8 typing
+    "UP006", "UP035",
+
 ]
 
 [tool.mypy]

@ -1,15 +1,2 @@
 # formatting
-yapf==0.32.0
-toml==0.10.2
-tomli==2.0.2
-ruff==0.6.5
-codespell==2.3.0
-isort==5.13.2
-clang-format==18.1.5
-pymarkdownlnt==0.9.26
-
-# type checking
-mypy==1.11.1
-types-PyYAML
-types-requests
-types-setuptools
+pre-commit==4.0.1
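Consistent with the commit title, the per-tool pins are replaced by a single pre-commit dependency; the hook configuration itself is not shown in this hunk. Assuming this is the lint requirements file (requirements-lint.txt), the developer workflow would look roughly like:

pip install -r requirements-lint.txt   # now effectively just pre-commit
pre-commit install                     # register the git hook once per clone
pre-commit run --all-files             # run every configured hook across the repo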

@ -1,13 +0,0 @@
#!/bin/bash

if command -v actionlint &> /dev/null; then
    actionlint "$@"
    exit 0
elif [ -x ./actionlint ]; then
    ./actionlint "$@"
    exit 0
fi

# download a binary to the current directory - v1.7.3
bash <(curl https://raw.githubusercontent.com/rhysd/actionlint/aa0a7be8e566b096e64a5df8ff290ec24fa58fbc/scripts/download-actionlint.bash)
./actionlint "$@"

@ -1,3 +0,0 @@
#!/bin/bash

pymarkdownlnt scan docs -r