vLLM Paged Attention
====================

- Currently, vLLM utilizes its own implementation of a multi-head query
  attention kernel (``csrc/attention/attention_kernels.cu``).
  This kernel is designed to be compatible with
  vLLM's paged KV caches, where the key and value cache are stored in
  separate blocks (note that this block concept differs from the GPU
  thread block. So in later documents, I will refer to the vLLM paged
  attention block as "block", while referring to the GPU thread block as
  "thread block").
- To achieve high performance, this kernel relies on a specially
  designed memory layout and access method, specifically when threads
  read data from global memory to shared memory. The purpose of this
  document is to provide a high-level explanation of the kernel
  implementation, step by step, aiding those who wish to learn about the
  vLLM multi-head query attention kernel. After going through this
  document, users will likely have a better understanding and find it
  easier to follow the actual implementation.
- Please note that this document may not cover all details, such as how
  to calculate the correct index for the corresponding data or the dot
  multiplication implementation. However, after reading this document
  and becoming familiar with the high-level logic flow, it should be
  easier for you to read the actual code and understand the details.

Inputs
------

- The kernel function takes a list of arguments for the current thread
  to perform its assigned work. The three most important arguments are
  the input pointers ``q``, ``k_cache``, and ``v_cache``, which point
  to query, key, and value data in global memory that need to be read
  and processed. The output pointer ``out`` points to global memory
  where the result should be written. These four pointers actually
  refer to multi-dimensional arrays, but each thread only accesses the
  portion of data assigned to it. I have omitted all other runtime
  parameters here for simplicity.

.. code:: cpp

   template<
     typename scalar_t,
     int HEAD_SIZE,
     int BLOCK_SIZE,
     int NUM_THREADS,
     int PARTITION_SIZE = 0>
   __device__ void paged_attention_kernel(
     ... // Other side args.
     scalar_t* __restrict__ out,              // [num_seqs, num_heads, max_num_partitions, head_size]
     const scalar_t* __restrict__ q,          // [num_seqs, num_heads, head_size]
     const scalar_t* __restrict__ k_cache,    // [num_blocks, num_kv_heads, head_size/x, block_size, x]
     const scalar_t* __restrict__ v_cache,    // [num_blocks, num_kv_heads, head_size, block_size]
     ... // Other side args.
   )

- There is also a list of template arguments above the function
  signature that are determined at compile time. ``scalar_t``
  represents the data type of the query, key, and value data elements,
  such as FP16. ``HEAD_SIZE`` indicates the number of elements in each
  head. ``BLOCK_SIZE`` refers to the number of tokens in each block.
  ``NUM_THREADS`` denotes the number of threads in each thread block.
  ``PARTITION_SIZE`` represents the number of tensor parallel GPUs (for
  simplicity, we assume this is 0 and tensor parallelism is disabled).
- With these arguments, we need to perform a sequence of preparations.
  This includes calculating the current head index, block index, and
  other necessary variables. However, for now, we can ignore these
  preparations and proceed directly to the actual calculations. It will
  be easier to understand them once we grasp the entire flow.
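- For readers who want a concrete picture of those preparations anyway, the
  sketch below shows the kind of index bookkeeping involved. It is an
  illustration based on the grid shape and constants described in this
  document, not a verbatim copy of the kernel; names such as
  ``thread_group_idx`` are assumed for the sketch.

.. code:: cpp

   // Sketch only: deriving the per-thread-block and per-thread indices.
   // The grid shape is (num_heads, num_seqs, max_num_partitions), so:
   const int head_idx = blockIdx.x;       // which attention head this thread block handles
   const int seq_idx = blockIdx.y;        // which sequence (request) this thread block handles
   const int partition_idx = blockIdx.z;  // which partition of the context (0 if not partitioned)

   const int thread_idx = threadIdx.x;            // thread index within the thread block
   const int warp_idx = thread_idx / WARP_SIZE;   // which warp this thread belongs to
   const int lane = thread_idx % WARP_SIZE;       // lane index within the warp

   // Each thread group cooperates on one query/key token at a time.
   const int thread_group_idx = thread_idx / THREAD_GROUP_SIZE;
   const int thread_group_offset = thread_idx % THREAD_GROUP_SIZE;
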
Concepts
--------

- Just before we dive into the calculation flow, I want to describe a
  few concepts that are needed for later sections. However, you may
  skip this section and return later if you encounter any confusing
  terminology.
- **Sequence**: A sequence represents a client request. For example,
  the data pointed to by ``q`` has a shape of
  ``[num_seqs, num_heads, head_size]``, which means there are a total of
  ``num_seqs`` query sequences pointed to by ``q``. Since this
  kernel is a single-query attention kernel, each sequence only has one
  query token. Hence, ``num_seqs`` equals the total number of tokens
  that are processed in the batch.
- **Context**: The context consists of the generated tokens from the
  sequence. For instance, ``["What", "is", "your"]`` are the context
  tokens, and the input query token is ``"name"``. The model might
  generate the token ``"?"``.
- **Vec**: A vec is a list of elements that are fetched and
  calculated together. For query and key data, the vec size
  (``VEC_SIZE``) is determined so that each thread group can fetch and
  calculate 16 bytes of data at a time. For value data, the vec size
  (``V_VEC_SIZE``) is determined so that each thread can fetch and
  calculate 16 bytes of data at a time. For example, if
  ``scalar_t`` is FP16 (2 bytes) and ``THREAD_GROUP_SIZE`` is 2, then
  ``VEC_SIZE`` will be 4, while ``V_VEC_SIZE`` will be 8 (see the
  sketch at the end of this section).
- **Thread group**: The thread group is a small group of
  threads(\ ``THREAD_GROUP_SIZE``) that fetches and calculates one
  query token and one key token at a time. Each thread handles only a
  portion of the token data. The total number of elements processed by
  one thread group is referred to as ``x``. For example, if the thread
  group contains 2 threads and the head size is 8, then thread 0
  handles the query and key elements at indices 0, 2, 4, 6, while thread
  1 handles the elements at indices 1, 3, 5, 7.
- **Block**: The key and value cache data in vLLM are split into
  blocks. Each block stores data for a fixed number(\ ``BLOCK_SIZE``)
  of tokens at one head. Each block may contain only a portion of the
  whole context tokens. For example, if the block size is 16 and the
  head size is 128, then for one head, one block can store 16 \* 128 =
  2048 elements.
- **Warp**: A warp is a group of 32 threads(\ ``WARP_SIZE``) that
  execute simultaneously on a streaming multiprocessor (SM). In this
  kernel, each warp processes the calculation between one query token
  and the key tokens of one entire block at a time (it may process multiple
  blocks in multiple iterations). For example, if there are 4 warps and
  6 blocks for one context, warp 0 handles
  the 0th and 4th blocks, warp 1 handles the 1st and 5th blocks, warp 2
  handles the 2nd block, and warp 3 handles the 3rd block.
- **Thread block**: A thread block is a group of
  threads(\ ``NUM_THREADS``) that can access the same shared memory.
  Each thread block contains multiple warps(\ ``NUM_WARPS``), and in
  this kernel, each thread block processes the calculation between one
  query token and key tokens of a whole context.
- **Grid**: A grid is a collection of thread blocks and defines the
  shape of the collection. In this kernel, the shape is
  ``(num_heads, num_seqs, max_num_partitions)``. Therefore, each thread
  block only handles the calculation for one head, one sequence, and
  one partition.
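- To make the vec sizes and related constants above concrete, here is a
  sketch of the arithmetic behind them, assuming FP16 elements,
  ``THREAD_GROUP_SIZE = 2``, and ``HEAD_SIZE = 128``. The names mirror the
  kernel's constants, but the definitions are illustrative rather than the
  kernel's actual code.

.. code:: cpp

   // Sketch of the "16 bytes per fetch" rule with illustrative constants.
   constexpr int THREAD_GROUP_SIZE = 2;
   constexpr int HEAD_SIZE = 128;
   constexpr int SCALAR_BYTES = 2;  // sizeof(scalar_t) for FP16

   // Query/key: one thread group fetches 16 bytes per step, split across its threads.
   constexpr int VEC_SIZE = 16 / (THREAD_GROUP_SIZE * SCALAR_BYTES);              // = 4
   // Value: one thread fetches 16 bytes per step on its own.
   constexpr int V_VEC_SIZE = 16 / SCALAR_BYTES;                                  // = 8
   // Number of query/key vecs each thread owns for one token.
   constexpr int NUM_VECS_PER_THREAD = HEAD_SIZE / (THREAD_GROUP_SIZE * VEC_SIZE); // = 16
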
Query
-----

- This section will introduce how query data is stored in memory and
  fetched by each thread. As mentioned above, each thread group fetches
  one query token's data, while each thread itself only handles a part of
  one query token's data. Within each warp, every thread group will fetch
  the same query token data, but will multiply it with different key
  token data.

.. code:: cpp

   const scalar_t* q_ptr = q + seq_idx * q_stride + head_idx * HEAD_SIZE;

.. figure:: ../../assets/kernel/query.png
   :alt: query
   :width: 70%
   :align: center

   Query data of one token at one head

- Each thread defines its own ``q_ptr``, which points to the assigned
  query token data in global memory. For example, if ``VEC_SIZE`` is 4
  and ``HEAD_SIZE`` is 128, the ``q_ptr`` points to data that contains
  a total of 128 elements divided into 128 / 4 = 32 vecs.

.. figure:: ../../assets/kernel/q_vecs.png
   :alt: q_vecs
   :width: 70%
   :align: center

   ``q_vecs`` for one thread group

.. code:: cpp

   __shared__ Q_vec q_vecs[THREAD_GROUP_SIZE][NUM_VECS_PER_THREAD];

- Next, we need to read the global memory data pointed to by ``q_ptr``
  into shared memory as ``q_vecs``. It is important to note that each
  thread's vecs are assigned to a different row. For example, if
  ``THREAD_GROUP_SIZE`` is 2, thread 0 will handle the 0th row of vecs,
  while thread 1 handles the 1st row. By reading the query data in
  this way, neighboring threads like thread 0 and thread 1 can read
  neighboring memory, achieving memory coalescing to improve
  performance. A sketch of this load is shown below.
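- The following is a simplified sketch of that strided load. It follows the
  access pattern described above; ``NUM_THREAD_GROUPS`` and
  ``thread_group_idx`` are assumed helper names for this sketch, so treat it
  as an illustration rather than the verbatim kernel code.

.. code:: cpp

   // Sketch: each thread group cooperatively loads one query token into shared memory.
   // Threads of the same group write different rows of q_vecs; consecutive threads
   // read consecutive VEC_SIZE-element chunks, so the accesses coalesce.
   #pragma unroll
   for (int i = thread_group_idx; i < NUM_VECS_PER_THREAD; i += NUM_THREAD_GROUPS) {
     const int vec_idx = thread_group_offset + i * THREAD_GROUP_SIZE;
     q_vecs[thread_group_offset][i] =
         *reinterpret_cast<const Q_vec*>(q_ptr + vec_idx * VEC_SIZE);
   }
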
Key
---

- Similar to the "Query" section, this section introduces the memory layout
  and assignment for keys. While each thread group only handles one
  query token in one kernel run, it may handle multiple key tokens across
  multiple iterations. Meanwhile, each warp will process multiple blocks
  of key tokens in multiple iterations, ensuring that all context
  tokens are processed by the entire thread group after the kernel run.
  In this context, "handle" refers to performing the dot multiplication
  between query data and key data.

.. code:: cpp

   const scalar_t* k_ptr = k_cache + physical_block_number * kv_block_stride
                         + kv_head_idx * kv_head_stride
                         + physical_block_offset * x;

- Unlike ``q_ptr``, ``k_ptr`` in each thread will point to a different
  key token at different iterations. As shown above, ``k_ptr``
  points to key token data based on ``k_cache``, at the assigned block,
  assigned head, and assigned token.

.. figure:: ../../assets/kernel/key.png
   :alt: key
   :width: 70%
   :align: center

   Key data of all context tokens at one head

- The diagram above illustrates the memory layout for key data. It
  assumes that ``BLOCK_SIZE`` is 16, ``HEAD_SIZE`` is 128, ``x`` is
  8, ``THREAD_GROUP_SIZE`` is 2, and there are a total of 4 warps. Each
  rectangle represents all the elements for one key token at one head,
  which will be processed by one thread group. The left half shows the
  total 16 blocks of key token data for warp 0, while the right half
  represents the remaining key token data for other warps or
  iterations. Inside each rectangle, there are a total of 32 vecs (128
  elements for one token) that will be processed by 2 threads (one
  thread group) separately.

.. figure:: ../../assets/kernel/k_vecs.png
   :alt: k_vecs
   :width: 70%
   :align: center

   ``k_vecs`` for one thread

.. code:: cpp

   K_vec k_vecs[NUM_VECS_PER_THREAD];

- Next, we need to read the key token data from ``k_ptr`` and store
  it in register memory as ``k_vecs``. We use register memory for
  ``k_vecs`` because it will only be accessed by one thread once,
  whereas ``q_vecs`` will be accessed by multiple threads multiple
  times. Each ``k_vecs`` will contain multiple vectors for the later
  calculation. Each vec will be set at each inner iteration. The
  assignment of vecs allows neighboring threads in a warp to read
  neighboring memory together, which again promotes memory
  coalescing. For instance, thread 0 will read vec 0, while thread 1
  will read vec 1. In the next inner loop, thread 0 will read vec 2,
  while thread 1 will read vec 3, and so on. A sketch of this inner
  loop is shown below.
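- Here is a simplified sketch of that inner loop. The offsets reflect the
  ``[num_blocks, num_kv_heads, head_size/x, block_size, x]`` layout of
  ``k_cache`` shown in the Inputs section; treat it as an illustration of
  the access pattern rather than the exact kernel code.

.. code:: cpp

   // Sketch: each thread of a thread group gathers its share of one key token.
   // vec_idx strides by THREAD_GROUP_SIZE so that neighboring threads read
   // neighboring VEC_SIZE-element chunks (coalesced accesses).
   #pragma unroll
   for (int j = 0; j < NUM_VECS_PER_THREAD; j++) {
     const int vec_idx = thread_group_offset + j * THREAD_GROUP_SIZE;
     const int offset1 = (vec_idx * VEC_SIZE) / x;  // which group of x elements along head_size
     const int offset2 = (vec_idx * VEC_SIZE) % x;  // offset inside that group of x elements
     k_vecs[j] = *reinterpret_cast<const K_vec*>(
         k_ptr + offset1 * BLOCK_SIZE * x + offset2);
   }
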
- You may still be a little confused about the overall flow. Don't
  worry, please keep reading the next "QK" section. It will illustrate
  the query and key calculation flow in a clearer and higher-level
  manner.

QK
---

- As shown in the pseudo code below, before the entire for loop block, we
  fetch the query data for one token and store it in ``q_vecs``. Then,
  in the outer for loop, we iterate through different ``k_ptrs`` that
  point to different tokens and prepare the ``k_vecs`` in the inner for
  loop. Finally, we perform the dot multiplication between the
  ``q_vecs`` and each ``k_vecs``.

.. code:: cpp

   q_vecs = ...
   for ... {
      k_ptr = ...
      for ... {
         k_vecs[i] = ...
      }
      ...
      float qk = scale * Qk_dot<scalar_t, THREAD_GROUP_SIZE>::dot(q_vecs[thread_group_offset], k_vecs);
   }

- As mentioned before, each thread only fetches part of the
  query and key token data at a time. However, there is a cross
  thread group reduction happening in ``Qk_dot<>::dot``. So the ``qk``
  returned here is not just the dot multiplication between parts of the
  query and key token data, but actually the full result between the
  entire query and key token data.
- For example, if the value of ``HEAD_SIZE`` is 128 and
  ``THREAD_GROUP_SIZE`` is 2, each thread's ``k_vecs`` will contain a
  total of 64 elements. However, the returned ``qk`` is actually the
  result of the dot multiplication between 128 query elements and 128 key
  elements. If you want to learn more about the details of the dot
  multiplication and reduction, you may refer to the implementation of
  ``Qk_dot<>::dot``. However, for the sake of simplicity, I will not
  cover it in this document. A sketch of the reduction idea follows.
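- To give an intuition for that cross-thread-group reduction, the sketch
  below shows how the partial dot products held by the threads of one group
  can be combined with warp shuffles so that every thread ends up with the
  full ``qk``. This illustrates the idea only; the actual
  ``Qk_dot<>::dot`` implementation is more involved.

.. code:: cpp

   // Sketch: each thread first accumulates the dot product over its own vecs,
   // then the partial sums of the THREAD_GROUP_SIZE threads in a group are
   // combined with XOR shuffles so every thread ends up with the full qk.
   float qk = 0.f;
   #pragma unroll
   for (int i = 0; i < NUM_VECS_PER_THREAD; i++) {
     qk += dot(q_vecs[thread_group_offset][i], k_vecs[i]);  // partial sum (64 of 128 elements)
   }
   #pragma unroll
   for (int mask = THREAD_GROUP_SIZE / 2; mask >= 1; mask /= 2) {
     qk += __shfl_xor_sync(0xffffffff, qk, mask);           // combine partials across the group
   }
   // qk now holds the dot product over all 128 elements of the head.
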
Softmax
-------

- Next, we need to calculate the normalized softmax for all ``qk``\ s,
  as shown below, where each :math:`x` represents a ``qk``. To do this,
  we must obtain the reduced value of ``qk_max``\ (:math:`m(x)`) and
  the ``exp_sum``\ (:math:`\ell(x)`) of all ``qk``\ s. The reduction
  should be performed across the entire thread block, encompassing
  results between the query token and all context key tokens.

.. math::
   :nowrap:

   \begin{gather*}
   m(x):=\max _i \quad x_i \\ \quad f(x):=\left[\begin{array}{lll}e^{x_1-m(x)} & \ldots & e^{x_B-m(x)}\end{array}\right]\\ \quad \ell(x):=\sum_i f(x)_i \\
   \quad \operatorname{softmax}(x):=\frac{f(x)}{\ell(x)}
   \end{gather*}

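- As a small worked example of these formulas, suppose the scores are
  :math:`x = [2, 4]`. Then :math:`m(x) = 4`,
  :math:`f(x) = [e^{-2}, e^{0}] \approx [0.135, 1]`,
  :math:`\ell(x) \approx 1.135`, and
  :math:`\operatorname{softmax}(x) \approx [0.119, 0.881]`. Subtracting the
  maximum before exponentiating keeps the exponentials numerically safe
  without changing the result.
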
``qk_max`` and ``logits``
~~~~~~~~~~~~~~~~~~~~~~~~~

- Right after we get the ``qk`` result, we can set the temporary
  ``logits`` result with ``qk`` (in the end, ``logits`` should
  store the normalized softmax result). Also, we can compare and collect
  the ``qk_max`` for all ``qk``\ s that are calculated by the current
  thread group.

.. code:: cpp

   if (thread_group_offset == 0) {
      const bool mask = token_idx >= context_len;
      logits[token_idx - start_token_idx] = mask ? 0.f : qk;
      qk_max = mask ? qk_max : fmaxf(qk_max, qk);
   }

- Please note that ``logits`` here is in shared memory, so each
  thread group will set the fields for its own assigned context tokens.
  Overall, the size of ``logits`` should be the number of context tokens.

.. code:: cpp

   for (int mask = WARP_SIZE / 2; mask >= THREAD_GROUP_SIZE; mask /= 2) {
      qk_max = fmaxf(qk_max, VLLM_SHFL_XOR_SYNC(qk_max, mask));
   }

   if (lane == 0) {
      red_smem[warp_idx] = qk_max;
   }

- Then we need to get the reduced ``qk_max`` within each warp. The main
  idea is to let the threads in a warp communicate with each other and
  get the final max ``qk``.

.. code:: cpp

   for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) {
      qk_max = fmaxf(qk_max, VLLM_SHFL_XOR_SYNC(qk_max, mask));
   }
   qk_max = VLLM_SHFL_SYNC(qk_max, 0);

- Finally, we can get the reduced ``qk_max`` for the whole thread block by
  comparing the ``qk_max`` from all warps in this thread block. Then we
  need to broadcast the final result to each thread.
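- ``VLLM_SHFL_XOR_SYNC`` and ``VLLM_SHFL_SYNC`` are thin wrappers around the
  CUDA warp shuffle intrinsics. Conceptually, on the CUDA path they behave
  like the sketch below; this is an assumption-level illustration, as the
  real definitions live in vLLM's compatibility headers and also cover other
  backends.

.. code:: cpp

   // Conceptual sketch of the shuffle wrappers used above (CUDA path only).
   // XOR shuffle: exchange values between lanes whose indices differ by lane_mask,
   // which yields a butterfly reduction when lane_mask is halved each iteration.
   #define VLLM_SHFL_XOR_SYNC(var, lane_mask) __shfl_xor_sync(0xffffffff, var, lane_mask)
   // Plain shuffle: broadcast the value held by src_lane (here lane 0) to all lanes.
   #define VLLM_SHFL_SYNC(var, src_lane) __shfl_sync(0xffffffff, var, src_lane)
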
``exp_sum``
~~~~~~~~~~~

- Similar to ``qk_max``, we need to get the reduced sum value from the
  entire thread block too.

.. code:: cpp

   for (int i = thread_idx; i < num_tokens; i += NUM_THREADS) {
      float val = __expf(logits[i] - qk_max);
      logits[i] = val;
      exp_sum += val;
   }
   ...
   exp_sum = block_sum<NUM_WARPS>(&red_smem[NUM_WARPS], exp_sum);

- First, sum all the exp values from each thread group and, meanwhile,
  convert each entry of ``logits`` from ``qk`` to ``exp(qk - qk_max)``.
  Please note that the ``qk_max`` here is already the max ``qk`` across the
  whole thread block. Then we can do the reduction for ``exp_sum``
  across the whole thread block, just like for ``qk_max``.
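- The ``block_sum`` helper used above follows the same two-stage pattern as
  the ``qk_max`` reduction: reduce within each warp via shuffles, stash one
  partial per warp in shared memory, reduce those partials, and broadcast.
  A self-contained sketch of that pattern (not the exact vLLM helper) looks
  like this:

.. code:: cpp

   // Sketch of a block-wide sum reduction in the style of block_sum<NUM_WARPS>.
   // red_smem must provide at least NUM_WARPS floats of shared memory;
   // WARP_SIZE is 32, as described in the Concepts section.
   template <int NUM_WARPS>
   __device__ float block_sum_sketch(float* red_smem, float sum) {
     const int warp_idx = threadIdx.x / WARP_SIZE;
     const int lane = threadIdx.x % WARP_SIZE;

     // 1) Reduce within each warp.
     for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) {
       sum += __shfl_xor_sync(0xffffffff, sum, mask);
     }
     if (lane == 0) {
       red_smem[warp_idx] = sum;  // one partial sum per warp
     }
     __syncthreads();

     // 2) Every warp reads the per-warp partials and reduces them in its first lanes.
     if (lane < NUM_WARPS) {
       sum = red_smem[lane];
     }
     for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) {
       sum += __shfl_xor_sync(0xffffffff, sum, mask);
     }

     // 3) Broadcast the final sum from lane 0 so every thread holds it.
     return __shfl_sync(0xffffffff, sum, 0);
   }
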
.. code:: cpp

   const float inv_sum = __fdividef(1.f, exp_sum + 1e-6f);
   for (int i = thread_idx; i < num_tokens; i += NUM_THREADS) {
      logits[i] *= inv_sum;
   }

- Finally, with the reduced ``qk_max`` and ``exp_sum``, we can obtain
  the final normalized softmax result as ``logits``. This ``logits``
  variable will be used for dot multiplication with the value data in
  later steps. Now, it should store the normalized softmax result of
  ``qk`` for all assigned context tokens.

Value
-----

.. figure:: ../../assets/kernel/value.png
   :alt: value
   :width: 70%
   :align: center

   Value data of all context tokens at one head

.. figure:: ../../assets/kernel/logits_vec.png
   :alt: logits_vec
   :width: 50%
   :align: center

   ``logits_vec`` for one thread

.. figure:: ../../assets/kernel/v_vec.png
   :alt: v_vec
   :width: 70%
   :align: center

   List of ``v_vec`` for one thread

- Now we need to retrieve the value data and perform the dot multiplication
  with ``logits``. Unlike query and key, there is no thread group
  concept for value data. As shown in the diagram, unlike the key token
  memory layout, elements from the same column correspond to the same
  value token. For one block of value data, there are ``HEAD_SIZE``
  rows and ``BLOCK_SIZE`` columns that are split into multiple
  ``v_vecs``.
- Each thread always fetches ``V_VEC_SIZE`` elements from the same
  ``V_VEC_SIZE`` tokens at a time. As a result, a single thread
  retrieves multiple ``v_vec``\ s from different rows and the same
  columns through multiple inner iterations. For each ``v_vec``, it
  needs to be dot multiplied with the corresponding ``logits_vec``,
  which is also ``V_VEC_SIZE`` elements from ``logits``. Overall, with
  multiple inner iterations, each warp will process one block of value
  tokens. And with multiple outer iterations, the value tokens of the
  whole context are processed.

.. code:: cpp

   float accs[NUM_ROWS_PER_THREAD];
   for ... { // Iteration over different blocks.
       logits_vec = ...
       for ... { // Iteration over different rows.
           v_vec = ...
           ...
           accs[i] += dot(logits_vec, v_vec);
       }
   }

- As shown in the above pseudo code, in the outer loop, similar to
  ``k_ptr``, ``logits_vec`` iterates over different blocks and reads
  ``V_VEC_SIZE`` elements from ``logits``. In the inner loop, each
  thread reads ``V_VEC_SIZE`` elements from the same tokens as a
  ``v_vec`` and performs dot multiplication. It is important to note
  that in each inner iteration, the thread fetches different head
  position elements for the same tokens. The dot result is then
  accumulated in ``accs``. Therefore, each entry of ``accs`` is mapped
  to a head position assigned to the current thread.
- For example, if ``BLOCK_SIZE`` is 16 and ``V_VEC_SIZE`` is 8, each
  thread fetches 8 value elements for 8 tokens at a time. Each element
  is from a different token at the same head position. If ``HEAD_SIZE``
  is 128 and ``WARP_SIZE`` is 32, for each inner loop, a warp needs to
  fetch ``WARP_SIZE * V_VEC_SIZE = 256`` elements. This means there are
  a total of 128 \* 16 / 256 = 8 inner iterations for a warp to handle
  a whole block of value tokens. And each ``accs`` in each thread
  contains 8 elements that are accumulated at 8 different head positions.
  For thread 0, the ``accs`` variable will have 8 elements, which
  are the 0th, 32nd … 224th elements of a value head, accumulated
  from all 8 assigned tokens. A sketch of the addressing involved is
  shown below.
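- The sketch below shows the kind of addressing this implies for one thread.
  It mirrors the ``[num_blocks, num_kv_heads, head_size, block_size]`` layout
  of ``v_cache`` from the Inputs section, but simplifies details such as the
  masking of out-of-range rows, and assumes ``physical_block_offset`` marks
  the starting token of this lane's ``v_vec`` within the block.

.. code:: cpp

   // Sketch: one thread accumulates its assigned head positions (rows) for
   // V_VEC_SIZE tokens of the current value block.
   const scalar_t* v_ptr = v_cache + physical_block_number * kv_block_stride
                                   + kv_head_idx * kv_head_stride;
   #pragma unroll
   for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
     // Row = head position handled by this lane in this iteration.
     const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
     // Column = token offset of this lane's v_vec inside the block.
     const int offset = row_idx * BLOCK_SIZE + physical_block_offset;
     V_vec v_vec = *reinterpret_cast<const V_vec*>(v_ptr + offset);
     accs[i] += dot(logits_vec, v_vec);  // V_VEC_SIZE tokens contribute to one head position
   }
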
LV
---

- Now, we need to perform reduction for ``accs`` within each warp. This
  process allows each thread to accumulate the ``accs`` for the
  assigned head positions of all tokens in one block.

.. code:: cpp

   for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
      float acc = accs[i];
      for (int mask = NUM_V_VECS_PER_ROW / 2; mask >= 1; mask /= 2) {
         acc += VLLM_SHFL_XOR_SYNC(acc, mask);
      }
      accs[i] = acc;
   }

- Next, we perform reduction for ``accs`` across all warps, allowing
  each thread to have the accumulation of ``accs`` for the assigned
  head positions of all context tokens. Please note that each ``accs``
  in every thread only stores the accumulation for a portion of the
  elements of the entire head for all context tokens. However, overall,
  all results for the output have been calculated; they are just stored in
  different threads' register memory.

.. code:: cpp

   float* out_smem = reinterpret_cast<float*>(shared_mem);
   for (int i = NUM_WARPS; i > 1; i /= 2) {
       // Upper warps write to shared memory.
       ...
       float* dst = &out_smem[(warp_idx - mid) * HEAD_SIZE];
       for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
           ...
           dst[row_idx] = accs[i];
       }

       // Lower warps update the output.
       const float* src = &out_smem[warp_idx * HEAD_SIZE];
       for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
           ...
           accs[i] += src[row_idx];
       }

       // Write out the accs.
   }

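- In the loop above, the parts elided by ``...`` are mainly the guards that
  decide which warps write and which warps read in each round; ``mid`` is
  simply ``i / 2``. A sketch of that structure (an illustration, not the
  exact code):

.. code:: cpp

   // Sketch of the guards elided above: each round halves the number of warps
   // that still hold partial results in their registers.
   int mid = i / 2;
   if (warp_idx >= mid && warp_idx < i) {
     // Upper half of the remaining warps: spill accs rows to shared memory.
   }
   __syncthreads();
   if (warp_idx < mid) {
     // Lower half: add the spilled rows into this warp's own accs.
   }
   __syncthreads();
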
Output
------

- Now we can write all of the calculated results from local register memory
  to the final output in global memory.

.. code:: cpp

   scalar_t* out_ptr = out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE
                     + head_idx * max_num_partitions * HEAD_SIZE
                     + partition_idx * HEAD_SIZE;

- First, we need to define the ``out_ptr`` variable, which points to
  the start address of the assigned sequence, assigned head, and assigned
  partition.

.. code:: cpp

   for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
      const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
      if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) {
         from_float(*(out_ptr + row_idx), accs[i]);
      }
   }

- Finally, we need to iterate over the different assigned head positions
  and write out the corresponding accumulated result based on
  ``out_ptr``. Only one lane per row (``lane % NUM_V_VECS_PER_ROW == 0``)
  performs the write, since after the in-warp reduction every lane in that
  row group holds the same accumulated value.