Support FP8-E5M2 KV Cache (#2279)

Co-authored-by: zhaoyang <zhao.yang16@zte.com.cn>
Co-authored-by: Zhuohan Li <zhuohan123@gmail.com>

commit 9090bf02e7
parent 7d648418b8
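This change threads a new `kv_cache_dtype` option ("auto" or "fp8_e5m2") from the Python entry points down to the CUDA paged-attention and cache kernels, so the KV cache can be stored as 8-bit FP8-E5M2 values instead of the model dtype. A minimal usage sketch, mirroring the documentation added at the end of this diff (the model name is only an example):

```python
from vllm import LLM, SamplingParams

# "auto" keeps the previous behavior (KV cache stored in the model dtype);
# "fp8_e5m2" stores the KV cache as 8-bit FP8-E5M2 values.
llm = LLM(model="facebook/opt-125m", kv_cache_dtype="fp8_e5m2")

outputs = llm.generate(["The capital of France is"],
                       SamplingParams(temperature=0.8, top_p=0.95))
print(outputs[0].outputs[0].text)
```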
@@ -24,6 +24,7 @@ def main(args: argparse.Namespace):
         trust_remote_code=args.trust_remote_code,
         dtype=args.dtype,
         enforce_eager=args.enforce_eager,
+        kv_cache_dtype=args.kv_cache_dtype,
     )

     sampling_params = SamplingParams(
@@ -117,6 +118,13 @@ if __name__ == '__main__':
     parser.add_argument('--enforce-eager',
                         action='store_true',
                         help='enforce eager mode and disable CUDA graph')
+    parser.add_argument(
+        "--kv-cache-dtype",
+        type=str,
+        choices=['auto', 'fp8_e5m2'],
+        default='auto',
+        help=
+        'Data type for kv cache storage. If "auto", will use model data type.')
     parser.add_argument(
         '--profile',
         action='store_true',
@@ -71,6 +71,7 @@ def run_vllm(
     dtype: str,
     max_model_len: Optional[int],
     enforce_eager: bool,
+    kv_cache_dtype: str,
 ) -> float:
     from vllm import LLM, SamplingParams
     llm = LLM(
@@ -83,6 +84,7 @@ def run_vllm(
         dtype=dtype,
         max_model_len=max_model_len,
         enforce_eager=enforce_eager,
+        kv_cache_dtype=kv_cache_dtype,
     )

     # Add the requests to the engine.
@@ -206,7 +208,8 @@ def main(args: argparse.Namespace):
                                 args.quantization, args.tensor_parallel_size,
                                 args.seed, args.n, args.use_beam_search,
                                 args.trust_remote_code, args.dtype,
-                                args.max_model_len, args.enforce_eager)
+                                args.max_model_len, args.enforce_eager,
+                                args.kv_cache_dtype)
     elif args.backend == "hf":
         assert args.tensor_parallel_size == 1
         elapsed_time = run_hf(requests, args.model, tokenizer, args.n,
@@ -284,6 +287,13 @@ if __name__ == "__main__":
     parser.add_argument("--enforce-eager",
                         action="store_true",
                         help="enforce eager execution")
+    parser.add_argument(
+        "--kv-cache-dtype",
+        type=str,
+        choices=["auto", "fp8_e5m2"],
+        default="auto",
+        help=
+        'Data type for kv cache storage. If "auto", will use model data type.')
     args = parser.parse_args()
     if args.tokenizer is None:
         args.tokenizer = args.model
@@ -1,9 +1,11 @@
+from typing import Optional
 import argparse
 import random
 import time

 import torch

+from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, create_kv_caches_with_random
 from vllm._C import ops

 NUM_BLOCKS = 1024
@@ -23,6 +25,7 @@ def main(
     dtype: torch.dtype,
     seed: int,
     do_profile: bool,
+    kv_cache_dtype: Optional[str] = None,
 ) -> None:
     random.seed(seed)
     torch.random.manual_seed(seed)
@@ -59,15 +62,10 @@ def main(
     block_tables = torch.tensor(block_tables, dtype=torch.int, device="cuda")

     # Create the KV cache.
-    x = 16 // torch.tensor([], dtype=dtype).element_size()
-    key_cache_shape = (NUM_BLOCKS, num_kv_heads, head_size // x, block_size, x)
-    key_cache = torch.empty(size=key_cache_shape, dtype=dtype, device="cuda")
-    key_cache.uniform_(-scale, scale)
-    value_cache_shape = (NUM_BLOCKS, num_kv_heads, head_size, block_size)
-    value_cache = torch.empty(size=value_cache_shape,
-                              dtype=dtype,
-                              device="cuda")
-    value_cache.uniform_(-scale, scale)
+    key_caches, value_caches = create_kv_caches_with_random(
+        NUM_BLOCKS, block_size, 1, num_kv_heads, head_size, kv_cache_dtype,
+        dtype)
+    key_cache, value_cache = key_caches[0], value_caches[0]

     # Prepare for the paged attention kernel.
     output = torch.empty_like(query)
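The hunk above fixes the argument order of the new `create_kv_caches_with_random` helper. A minimal standalone sketch of the same call (the sizes are illustrative, and a CUDA device is assumed since the helper allocates the caches on GPU):

```python
import torch
from vllm.utils import create_kv_caches_with_random

# One layer of paged KV cache. With kv_cache_dtype="fp8_e5m2" the caches are
# presumably returned as uint8 tensors holding FP8-E5M2 bytes; with "auto"
# they simply use the model dtype passed as the last argument.
key_caches, value_caches = create_kv_caches_with_random(
    1024,           # num_blocks
    16,             # block_size
    1,              # num_layers
    8,              # num_kv_heads
    128,            # head_size
    "fp8_e5m2",     # kv_cache_dtype
    torch.float16)  # model dtype
key_cache, value_cache = key_caches[0], value_caches[0]
print(key_cache.dtype, key_cache.shape)
```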
@@ -106,6 +104,7 @@ def main(
             block_size,
             max_context_len,
             alibi_slopes,
+            kv_cache_dtype,
         )
     elif version == "v2":
         ops.paged_attention_v2(
@@ -123,6 +122,7 @@ def main(
             block_size,
             max_context_len,
             alibi_slopes,
+            kv_cache_dtype,
         )
     else:
         raise ValueError(f"Invalid version: {version}")
@@ -168,16 +168,18 @@ if __name__ == '__main__':
                         default="half")
     parser.add_argument("--seed", type=int, default=0)
     parser.add_argument("--profile", action="store_true")
+    parser.add_argument(
+        "--kv-cache-dtype",
+        type=str,
+        choices=["auto", "fp8_e5m2"],
+        default="auto",
+        help=
+        'Data type for kv cache storage. If "auto", will use model data type.')
     args = parser.parse_args()
     print(args)

     if args.num_query_heads % args.num_kv_heads != 0:
         raise ValueError("num_query_heads must be divisible by num_kv_heads")
-    dtype_to_torch_dtype = {
-        "half": torch.half,
-        "bfloat16": torch.bfloat16,
-        "float": torch.float,
-    }
     main(
         version=args.version,
         num_seqs=args.batch_size,
@@ -187,7 +189,8 @@ if __name__ == '__main__':
         head_size=args.head_size,
         block_size=args.block_size,
         use_alibi=args.use_alibi,
-        dtype=dtype_to_torch_dtype[args.dtype],
+        dtype=STR_DTYPE_TO_TORCH_DTYPE[args.dtype],
         seed=args.seed,
         do_profile=args.profile,
+        kv_cache_dtype=args.kv_cache_dtype,
     )
@@ -4,3 +4,4 @@
 #include "dtype_float16.cuh"
 #include "dtype_float32.cuh"
 #include "dtype_bfloat16.cuh"
+#include "dtype_fp8_e5m2.cuh"
@@ -25,6 +25,7 @@

 #include "attention_dtypes.h"
 #include "attention_utils.cuh"
+#include "../quantization/fp8_e5m2_kvcache/quant_utils.cuh"

 #include <algorithm>

@@ -79,17 +80,19 @@ inline __device__ float block_sum(float* red_smem, float sum) {
 // Grid: (num_heads, num_seqs, max_num_partitions).
 template<
   typename scalar_t,
+  typename cache_t,
   int HEAD_SIZE,
   int BLOCK_SIZE,
   int NUM_THREADS,
+  bool IS_FP8_E5M2_KV_CACHE,
   int PARTITION_SIZE = 0> // Zero means no partitioning.
 __device__ void paged_attention_kernel(
   float* __restrict__ exp_sums,           // [num_seqs, num_heads, max_num_partitions]
   float* __restrict__ max_logits,         // [num_seqs, num_heads, max_num_partitions]
   scalar_t* __restrict__ out,             // [num_seqs, num_heads, max_num_partitions, head_size]
   const scalar_t* __restrict__ q,         // [num_seqs, num_heads, head_size]
-  const scalar_t* __restrict__ k_cache,   // [num_blocks, num_kv_heads, head_size/x, block_size, x]
-  const scalar_t* __restrict__ v_cache,   // [num_blocks, num_kv_heads, head_size, block_size]
+  const cache_t* __restrict__ k_cache,    // [num_blocks, num_kv_heads, head_size/x, block_size, x]
+  const cache_t* __restrict__ v_cache,    // [num_blocks, num_kv_heads, head_size, block_size]
   const int num_kv_heads,                 // [num_heads]
   const float scale,
   const int* __restrict__ block_tables,   // [num_seqs, max_num_blocks_per_seq]
@@ -145,6 +148,9 @@ __device__ void paged_attention_kernel(
   constexpr int VEC_SIZE = MAX(16 / (THREAD_GROUP_SIZE * sizeof(scalar_t)), 1);
   using K_vec = typename Vec<scalar_t, VEC_SIZE>::Type;
   using Q_vec = typename Vec<scalar_t, VEC_SIZE>::Type;
+#ifdef ENABLE_FP8_E5M2
+  using Quant_vec = typename Vec<cache_t, VEC_SIZE>::Type;
+#endif

   constexpr int NUM_ELEMS_PER_THREAD = HEAD_SIZE / THREAD_GROUP_SIZE;
   constexpr int NUM_VECS_PER_THREAD = NUM_ELEMS_PER_THREAD / VEC_SIZE;
@@ -176,7 +182,7 @@ __device__ void paged_attention_kernel(

   // x == THREAD_GROUP_SIZE * VEC_SIZE
   // Each thread group fetches x elements from the key at a time.
-  constexpr int x = 16 / sizeof(scalar_t);
+  constexpr int x = 16 / sizeof(cache_t);
   float qk_max = -FLT_MAX;

   // Iterate over the key blocks.
@@ -202,13 +208,23 @@ __device__ void paged_attention_kernel(

 #pragma unroll
     for (int j = 0; j < NUM_VECS_PER_THREAD; j++) {
-      const scalar_t* k_ptr = k_cache + physical_block_number * kv_block_stride
+      const cache_t* k_ptr = k_cache + physical_block_number * kv_block_stride
                                       + kv_head_idx * kv_head_stride
                                       + physical_block_offset * x;
       const int vec_idx = thread_group_offset + j * THREAD_GROUP_SIZE;
       const int offset1 = (vec_idx * VEC_SIZE) / x;
       const int offset2 = (vec_idx * VEC_SIZE) % x;
-      k_vecs[j] = *reinterpret_cast<const K_vec*>(k_ptr + offset1 * BLOCK_SIZE * x + offset2);
+      if constexpr (IS_FP8_E5M2_KV_CACHE) {
+#ifdef ENABLE_FP8_E5M2
+        Quant_vec k_vec_quant = *reinterpret_cast<const Quant_vec*>(k_ptr + offset1 * BLOCK_SIZE * x + offset2);
+        // Vector conversion from Quant_vec to K_vec.
+        k_vecs[j] = fp8_e5m2_unscaled::vec_conversion<K_vec, Quant_vec>(k_vec_quant);
+#else
+        assert(false);
+#endif
+      } else {
+        k_vecs[j] = *reinterpret_cast<const K_vec*>(k_ptr + offset1 * BLOCK_SIZE * x + offset2);
+      }
     }

     // Compute dot product.
@@ -282,6 +298,9 @@ __device__ void paged_attention_kernel(
   constexpr int V_VEC_SIZE = MIN(16 / sizeof(scalar_t), BLOCK_SIZE);
   using V_vec = typename Vec<scalar_t, V_VEC_SIZE>::Type;
   using L_vec = typename Vec<scalar_t, V_VEC_SIZE>::Type;
+#ifdef ENABLE_FP8_E5M2
+  using V_quant_vec = typename Vec<cache_t, V_VEC_SIZE>::Type;
+#endif
   using Float_L_vec = typename FloatVec<L_vec>::Type;

   constexpr int NUM_V_VECS_PER_ROW = BLOCK_SIZE / V_VEC_SIZE;
@@ -307,14 +326,25 @@ __device__ void paged_attention_kernel(
     L_vec logits_vec;
     from_float(logits_vec, *reinterpret_cast<Float_L_vec*>(logits + token_idx - start_token_idx));

-    const scalar_t* v_ptr = v_cache + physical_block_number * kv_block_stride
+    const cache_t* v_ptr = v_cache + physical_block_number * kv_block_stride
                                     + kv_head_idx * kv_head_stride;
 #pragma unroll
     for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
       const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
       if (row_idx < HEAD_SIZE) {
         const int offset = row_idx * BLOCK_SIZE + physical_block_offset;
-        V_vec v_vec = *reinterpret_cast<const V_vec*>(v_ptr + offset);
+        V_vec v_vec;
+        if constexpr (IS_FP8_E5M2_KV_CACHE) {
+#ifdef ENABLE_FP8_E5M2
+          V_quant_vec v_quant_vec = *reinterpret_cast<const V_quant_vec*>(v_ptr + offset);
+          // Vector conversion from V_quant_vec to V_vec.
+          v_vec = fp8_e5m2_unscaled::vec_conversion<V_vec, V_quant_vec>(v_quant_vec);
+#else
+          assert(false);
+#endif
+        } else {
+          v_vec = *reinterpret_cast<const V_vec*>(v_ptr + offset);
+        }
         if (block_idx == num_context_blocks - 1) {
           // NOTE(woosuk): When v_vec contains the tokens that are out of the context,
           // we should explicitly zero out the values since they may contain NaNs.
@@ -395,14 +425,16 @@ __device__ void paged_attention_kernel(
 // Grid: (num_heads, num_seqs, 1).
 template<
   typename scalar_t,
+  typename cache_t,
   int HEAD_SIZE,
   int BLOCK_SIZE,
-  int NUM_THREADS>
+  int NUM_THREADS,
+  bool IS_FP8_E5M2_KV_CACHE>
 __global__ void paged_attention_v1_kernel(
   scalar_t* __restrict__ out,             // [num_seqs, num_heads, head_size]
   const scalar_t* __restrict__ q,         // [num_seqs, num_heads, head_size]
-  const scalar_t* __restrict__ k_cache,   // [num_blocks, num_kv_heads, head_size/x, block_size, x]
-  const scalar_t* __restrict__ v_cache,   // [num_blocks, num_kv_heads, head_size, block_size]
+  const cache_t* __restrict__ k_cache,    // [num_blocks, num_kv_heads, head_size/x, block_size, x]
+  const cache_t* __restrict__ v_cache,    // [num_blocks, num_kv_heads, head_size, block_size]
   const int num_kv_heads,                 // [num_heads]
   const float scale,
   const int* __restrict__ block_tables,   // [num_seqs, max_num_blocks_per_seq]
@@ -412,7 +444,7 @@ __global__ void paged_attention_v1_kernel(
   const int q_stride,
   const int kv_block_stride,
   const int kv_head_stride) {
-  paged_attention_kernel<scalar_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS>(
+  paged_attention_kernel<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, IS_FP8_E5M2_KV_CACHE>(
     /* exp_sums */ nullptr, /* max_logits */ nullptr,
     out, q, k_cache, v_cache, num_kv_heads, scale, block_tables, context_lens,
     max_num_blocks_per_seq, alibi_slopes, q_stride, kv_block_stride, kv_head_stride);
@@ -421,17 +453,19 @@ __global__ void paged_attention_v1_kernel(
 // Grid: (num_heads, num_seqs, max_num_partitions).
 template<
   typename scalar_t,
+  typename cache_t,
   int HEAD_SIZE,
   int BLOCK_SIZE,
   int NUM_THREADS,
+  bool IS_FP8_E5M2_KV_CACHE,
   int PARTITION_SIZE>
 __global__ void paged_attention_v2_kernel(
   float* __restrict__ exp_sums,           // [num_seqs, num_heads, max_num_partitions]
   float* __restrict__ max_logits,         // [num_seqs, num_heads, max_num_partitions]
   scalar_t* __restrict__ tmp_out,         // [num_seqs, num_heads, max_num_partitions, head_size]
   const scalar_t* __restrict__ q,         // [num_seqs, num_heads, head_size]
-  const scalar_t* __restrict__ k_cache,   // [num_blocks, num_kv_heads, head_size/x, block_size, x]
-  const scalar_t* __restrict__ v_cache,   // [num_blocks, num_kv_heads, head_size, block_size]
+  const cache_t* __restrict__ k_cache,    // [num_blocks, num_kv_heads, head_size/x, block_size, x]
+  const cache_t* __restrict__ v_cache,    // [num_blocks, num_kv_heads, head_size, block_size]
   const int num_kv_heads,                 // [num_heads]
   const float scale,
   const int* __restrict__ block_tables,   // [num_seqs, max_num_blocks_per_seq]
@@ -441,7 +475,7 @@ __global__ void paged_attention_v2_kernel(
   const int q_stride,
   const int kv_block_stride,
   const int kv_head_stride) {
-  paged_attention_kernel<scalar_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, PARTITION_SIZE>(
+  paged_attention_kernel<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, IS_FP8_E5M2_KV_CACHE, PARTITION_SIZE>(
     exp_sums, max_logits, tmp_out, q, k_cache, v_cache, num_kv_heads, scale,
     block_tables, context_lens, max_num_blocks_per_seq, alibi_slopes,
     q_stride, kv_block_stride, kv_head_stride);
@@ -550,10 +584,10 @@ __global__ void paged_attention_v2_reduce_kernel(

 #define LAUNCH_PAGED_ATTENTION_V1(HEAD_SIZE)                                               \
   VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(                                    \
-    ((void*)vllm::paged_attention_v1_kernel<T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS>),       \
-    shared_mem_size);                                                                      \
-  vllm::paged_attention_v1_kernel<T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS>                   \
-  <<<grid, block, shared_mem_size, stream>>>(                                              \
+    ((void*)vllm::paged_attention_v1_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, \
+      IS_FP8_E5M2_KV_CACHE>), shared_mem_size);                                            \
+  vllm::paged_attention_v1_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS,          \
+  IS_FP8_E5M2_KV_CACHE><<<grid, block, shared_mem_size, stream>>>(                         \
     out_ptr,                                                                               \
     query_ptr,                                                                             \
     key_cache_ptr,                                                                         \
@@ -571,7 +605,9 @@ __global__ void paged_attention_v2_reduce_kernel(
 // TODO(woosuk): Tune NUM_THREADS.
 template<
   typename T,
+  typename CACHE_T,
   int BLOCK_SIZE,
+  bool IS_FP8_E5M2_KV_CACHE,
   int NUM_THREADS = 128>
 void paged_attention_v1_launcher(
   torch::Tensor& out,
@@ -602,8 +638,8 @@ void paged_attention_v1_launcher(

   T* out_ptr = reinterpret_cast<T*>(out.data_ptr());
   T* query_ptr = reinterpret_cast<T*>(query.data_ptr());
-  T* key_cache_ptr = reinterpret_cast<T*>(key_cache.data_ptr());
-  T* value_cache_ptr = reinterpret_cast<T*>(value_cache.data_ptr());
+  CACHE_T* key_cache_ptr = reinterpret_cast<CACHE_T*>(key_cache.data_ptr());
+  CACHE_T* value_cache_ptr = reinterpret_cast<CACHE_T*>(value_cache.data_ptr());
   int* block_tables_ptr = block_tables.data_ptr<int>();
   int* context_lens_ptr = context_lens.data_ptr<int>();

@@ -647,35 +683,35 @@ void paged_attention_v1_launcher(
   }
 }

-#define CALL_V1_LAUNCHER(T, BLOCK_SIZE)                                     \
-  paged_attention_v1_launcher<T, BLOCK_SIZE>(                               \
+#define CALL_V1_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_E5M2_KV_CACHE)      \
+  paged_attention_v1_launcher<T, CACHE_T, BLOCK_SIZE, IS_FP8_E5M2_KV_CACHE>( \
     out,                                                                    \
     query,                                                                  \
     key_cache,                                                              \
     value_cache,                                                            \
     num_kv_heads,                                                           \
     scale,                                                                  \
     block_tables,                                                           \
     context_lens,                                                           \
     max_context_len,                                                        \
     alibi_slopes);

 // NOTE(woosuk): To reduce the compilation time, we omitted block sizes
 // 1, 2, 4, 64, 128, 256.
-#define CALL_V1_LAUNCHER_BLOCK_SIZE(T)                                      \
+#define CALL_V1_LAUNCHER_BLOCK_SIZE(T, CACHE_T, IS_FP8_E5M2_KV_CACHE)       \
   switch (block_size) {                                                     \
     case 8:                                                                 \
-      CALL_V1_LAUNCHER(T, 8);                                               \
+      CALL_V1_LAUNCHER(T, CACHE_T, 8, IS_FP8_E5M2_KV_CACHE);                \
       break;                                                                \
     case 16:                                                                \
-      CALL_V1_LAUNCHER(T, 16);                                              \
+      CALL_V1_LAUNCHER(T, CACHE_T, 16, IS_FP8_E5M2_KV_CACHE);               \
       break;                                                                \
     case 32:                                                                \
-      CALL_V1_LAUNCHER(T, 32);                                              \
+      CALL_V1_LAUNCHER(T, CACHE_T, 32, IS_FP8_E5M2_KV_CACHE);               \
       break;                                                                \
     default:                                                                \
       TORCH_CHECK(false, "Unsupported block size: ", block_size);           \
       break;                                                                \
   }

 void paged_attention_v1(
@@ -689,20 +725,36 @@ void paged_attention_v1(
   torch::Tensor& context_lens,    // [num_seqs]
   int block_size,
   int max_context_len,
-  const c10::optional<torch::Tensor>& alibi_slopes) {
-  if (query.dtype() == at::ScalarType::Float) {
-    CALL_V1_LAUNCHER_BLOCK_SIZE(float);
-  } else if (query.dtype() == at::ScalarType::Half) {
-    CALL_V1_LAUNCHER_BLOCK_SIZE(uint16_t);
-  } else if (query.dtype() == at::ScalarType::BFloat16) {
-    CALL_V1_LAUNCHER_BLOCK_SIZE(__nv_bfloat16);
+  const c10::optional<torch::Tensor>& alibi_slopes,
+  const std::string& kv_cache_dtype) {
+  if (kv_cache_dtype == "auto") {
+    if (query.dtype() == at::ScalarType::Float) {
+      CALL_V1_LAUNCHER_BLOCK_SIZE(float, float, false);
+    } else if (query.dtype() == at::ScalarType::Half) {
+      CALL_V1_LAUNCHER_BLOCK_SIZE(uint16_t, uint16_t, false);
+    } else if (query.dtype() == at::ScalarType::BFloat16) {
+      CALL_V1_LAUNCHER_BLOCK_SIZE(__nv_bfloat16, __nv_bfloat16, false);
+    } else {
+      TORCH_CHECK(false, "Unsupported data type: ", query.dtype());
+    }
+  } else if (kv_cache_dtype == "fp8_e5m2") {
+    if (query.dtype() == at::ScalarType::Float) {
+      CALL_V1_LAUNCHER_BLOCK_SIZE(float, uint8_t, true);
+    } else if (query.dtype() == at::ScalarType::Half) {
+      CALL_V1_LAUNCHER_BLOCK_SIZE(uint16_t, uint8_t, true);
+    } else if (query.dtype() == at::ScalarType::BFloat16) {
+      CALL_V1_LAUNCHER_BLOCK_SIZE(__nv_bfloat16, uint8_t, true);
+    } else {
+      TORCH_CHECK(false, "Unsupported data type: ", query.dtype());
+    }
   } else {
-    TORCH_CHECK(false, "Unsupported data type: ", query.dtype());
+    TORCH_CHECK(false, "Unsupported data type of kv cache: ", kv_cache_dtype);
   }
 }

 #define LAUNCH_PAGED_ATTENTION_V2(HEAD_SIZE)                                        \
-  vllm::paged_attention_v2_kernel<T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, PARTITION_SIZE> \
+  vllm::paged_attention_v2_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS,   \
+    IS_FP8_E5M2_KV_CACHE, PARTITION_SIZE>                                           \
   <<<grid, block, shared_mem_size, stream>>>(                                       \
     exp_sums_ptr,                                                                   \
     max_logits_ptr,                                                                 \
@@ -730,7 +782,9 @@ void paged_attention_v1(

 template<
   typename T,
+  typename CACHE_T,
   int BLOCK_SIZE,
+  bool IS_FP8_E5M2_KV_CACHE,
   int NUM_THREADS = 128,
   int PARTITION_SIZE = 512>
 void paged_attention_v2_launcher(
@@ -768,8 +822,8 @@ void paged_attention_v2_launcher(
   float* max_logits_ptr = reinterpret_cast<float*>(max_logits.data_ptr());
   T* tmp_out_ptr = reinterpret_cast<T*>(tmp_out.data_ptr());
   T* query_ptr = reinterpret_cast<T*>(query.data_ptr());
-  T* key_cache_ptr = reinterpret_cast<T*>(key_cache.data_ptr());
-  T* value_cache_ptr = reinterpret_cast<T*>(value_cache.data_ptr());
+  CACHE_T* key_cache_ptr = reinterpret_cast<CACHE_T*>(key_cache.data_ptr());
+  CACHE_T* value_cache_ptr = reinterpret_cast<CACHE_T*>(value_cache.data_ptr());
   int* block_tables_ptr = block_tables.data_ptr<int>();
   int* context_lens_ptr = context_lens.data_ptr<int>();

@@ -816,38 +870,38 @@ void paged_attention_v2_launcher(
   }
 }

-#define CALL_V2_LAUNCHER(T, BLOCK_SIZE)                                     \
-  paged_attention_v2_launcher<T, BLOCK_SIZE>(                               \
+#define CALL_V2_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_E5M2_KV_CACHE)      \
+  paged_attention_v2_launcher<T, CACHE_T, BLOCK_SIZE, IS_FP8_E5M2_KV_CACHE>( \
     out,                                                                    \
     exp_sums,                                                               \
     max_logits,                                                             \
     tmp_out,                                                                \
     query,                                                                  \
     key_cache,                                                              \
     value_cache,                                                            \
     num_kv_heads,                                                           \
     scale,                                                                  \
     block_tables,                                                           \
     context_lens,                                                           \
     max_context_len,                                                        \
     alibi_slopes);

 // NOTE(woosuk): To reduce the compilation time, we omitted block sizes
 // 1, 2, 4, 64, 128, 256.
-#define CALL_V2_LAUNCHER_BLOCK_SIZE(T)                                      \
+#define CALL_V2_LAUNCHER_BLOCK_SIZE(T, CACHE_T, IS_FP8_E5M2_KV_CACHE)       \
   switch (block_size) {                                                     \
     case 8:                                                                 \
-      CALL_V2_LAUNCHER(T, 8);                                               \
+      CALL_V2_LAUNCHER(T, CACHE_T, 8, IS_FP8_E5M2_KV_CACHE);                \
       break;                                                                \
     case 16:                                                                \
-      CALL_V2_LAUNCHER(T, 16);                                              \
+      CALL_V2_LAUNCHER(T, CACHE_T, 16, IS_FP8_E5M2_KV_CACHE);               \
       break;                                                                \
     case 32:                                                                \
-      CALL_V2_LAUNCHER(T, 32);                                              \
+      CALL_V2_LAUNCHER(T, CACHE_T, 32, IS_FP8_E5M2_KV_CACHE);               \
       break;                                                                \
     default:                                                                \
       TORCH_CHECK(false, "Unsupported block size: ", block_size);           \
       break;                                                                \
   }

 void paged_attention_v2(
@@ -864,15 +918,30 @@ void paged_attention_v2(
   torch::Tensor& context_lens,    // [num_seqs]
   int block_size,
   int max_context_len,
-  const c10::optional<torch::Tensor>& alibi_slopes) {
-  if (query.dtype() == at::ScalarType::Float) {
-    CALL_V2_LAUNCHER_BLOCK_SIZE(float);
-  } else if (query.dtype() == at::ScalarType::Half) {
-    CALL_V2_LAUNCHER_BLOCK_SIZE(uint16_t);
-  } else if (query.dtype() == at::ScalarType::BFloat16) {
-    CALL_V2_LAUNCHER_BLOCK_SIZE(__nv_bfloat16);
+  const c10::optional<torch::Tensor>& alibi_slopes,
+  const std::string& kv_cache_dtype) {
+  if (kv_cache_dtype == "auto") {
+    if (query.dtype() == at::ScalarType::Float) {
+      CALL_V2_LAUNCHER_BLOCK_SIZE(float, float, false);
+    } else if (query.dtype() == at::ScalarType::Half) {
+      CALL_V2_LAUNCHER_BLOCK_SIZE(uint16_t, uint16_t, false);
+    } else if (query.dtype() == at::ScalarType::BFloat16) {
+      CALL_V2_LAUNCHER_BLOCK_SIZE(__nv_bfloat16, __nv_bfloat16, false);
+    } else {
+      TORCH_CHECK(false, "Unsupported data type: ", query.dtype());
+    }
+  } else if (kv_cache_dtype == "fp8_e5m2") {
+    if (query.dtype() == at::ScalarType::Float) {
+      CALL_V2_LAUNCHER_BLOCK_SIZE(float, uint8_t, true);
+    } else if (query.dtype() == at::ScalarType::Half) {
+      CALL_V2_LAUNCHER_BLOCK_SIZE(uint16_t, uint8_t, true);
+    } else if (query.dtype() == at::ScalarType::BFloat16) {
+      CALL_V2_LAUNCHER_BLOCK_SIZE(__nv_bfloat16, uint8_t, true);
+    } else {
+      TORCH_CHECK(false, "Unsupported data type: ", query.dtype());
+    }
   } else {
-    TORCH_CHECK(false, "Unsupported data type: ", query.dtype());
+    TORCH_CHECK(false, "Unsupported data type of kv cache: ", kv_cache_dtype);
   }
 }

csrc/attention/dtype_fp8_e5m2.cuh (new file, +35 lines)
@@ -0,0 +1,35 @@
+#pragma once
+
+#include "attention_generic.cuh"
+
+#include <stdint.h>
+#ifdef ENABLE_FP8_E5M2
+#include <cuda_fp8.h>
+#endif
+
+namespace vllm {
+#ifdef ENABLE_FP8_E5M2
+// fp8 vector types for quantization of kv cache
+
+template<>
+struct Vec<uint8_t, 1> {
+  using Type = uint8_t;
+};
+
+template<>
+struct Vec<uint8_t, 2> {
+  using Type = uint16_t;
+};
+
+template<>
+struct Vec<uint8_t, 4> {
+  using Type = uint32_t;
+};
+
+template<>
+struct Vec<uint8_t, 8> {
+  using Type = uint2;
+};
+#endif // ENABLE_FP8_E5M2
+
+} // namespace vllm
@@ -20,7 +20,8 @@ void reshape_and_cache(
   torch::Tensor& value,
   torch::Tensor& key_cache,
   torch::Tensor& value_cache,
-  torch::Tensor& slot_mapping);
+  torch::Tensor& slot_mapping,
+  const std::string& kv_cache_dtype);

 void gather_cached_kv(
   torch::Tensor& key,
@@ -28,3 +29,8 @@ void gather_cached_kv(
   torch::Tensor& key_cache,
   torch::Tensor& value_cache,
   torch::Tensor& slot_mapping);
+
+// Just for unittest
+void convert_fp8_e5m2(
+  torch::Tensor& src_cache,
+  torch::Tensor& dst_cache);
@@ -4,6 +4,7 @@

 #include "cuda_compat.h"
 #include "dispatch_utils.h"
+#include "quantization/fp8_e5m2_kvcache/quant_utils.cuh"

 #include <algorithm>
 #include <cassert>
@@ -131,7 +132,7 @@ void copy_blocks(
   dim3 block(std::min(1024, numel_per_block));
   const at::cuda::OptionalCUDAGuard device_guard(cache_device);
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
-  VLLM_DISPATCH_FLOATING_TYPES(
+  VLLM_DISPATCH_FLOATING_AND_BYTE_TYPES(
     key_caches[0].scalar_type(), "copy_blocks_kernel", ([&] {
       vllm::copy_blocks_kernel<scalar_t><<<grid, block, 0, stream>>>(
         key_cache_ptrs_tensor.data_ptr<int64_t>(),
@@ -143,12 +144,12 @@ void copy_blocks(

 namespace vllm {

-template<typename scalar_t>
+template<typename scalar_t, typename cache_t, bool is_fp8_e5m2_kv_cache>
 __global__ void reshape_and_cache_kernel(
   const scalar_t* __restrict__ key,           // [num_tokens, num_heads, head_size]
   const scalar_t* __restrict__ value,         // [num_tokens, num_heads, head_size]
-  scalar_t* __restrict__ key_cache,           // [num_blocks, num_heads, head_size/x, block_size, x]
-  scalar_t* __restrict__ value_cache,         // [num_blocks, num_heads, head_size, block_size]
+  cache_t* __restrict__ key_cache,            // [num_blocks, num_heads, head_size/x, block_size, x]
+  cache_t* __restrict__ value_cache,          // [num_blocks, num_heads, head_size, block_size]
   const int64_t* __restrict__ slot_mapping,   // [num_tokens]
   const int key_stride,
   const int value_stride,
@@ -185,19 +186,45 @@ __global__ void reshape_and_cache_kernel(
                                   + head_idx * head_size * block_size
                                   + head_offset * block_size
                                   + block_offset;
-    key_cache[tgt_key_idx] = key[src_key_idx];
-    value_cache[tgt_value_idx] = value[src_value_idx];
+    scalar_t tgt_key = key[src_key_idx];
+    scalar_t tgt_value = value[src_value_idx];
+    if constexpr (is_fp8_e5m2_kv_cache) {
+#ifdef ENABLE_FP8_E5M2
+      key_cache[tgt_key_idx] = fp8_e5m2_unscaled::vec_conversion<uint8_t, scalar_t>(tgt_key);
+      value_cache[tgt_value_idx] = fp8_e5m2_unscaled::vec_conversion<uint8_t, scalar_t>(tgt_value);
+#else
+      assert(false);
+#endif
+    } else {
+      key_cache[tgt_key_idx] = tgt_key;
+      value_cache[tgt_value_idx] = tgt_value;
+    }
   }
 }

 } // namespace vllm

+#define CALL_RESHAPE_AND_CACHE(KV_T, CACHE_T, IS_FP8_E5M2_KV_CACHE)                                  \
+  vllm::reshape_and_cache_kernel<KV_T, CACHE_T, IS_FP8_E5M2_KV_CACHE><<<grid, block, 0, stream>>>(   \
+    reinterpret_cast<KV_T*>(key.data_ptr()),                                                         \
+    reinterpret_cast<KV_T*>(value.data_ptr()),                                                       \
+    reinterpret_cast<CACHE_T*>(key_cache.data_ptr()),                                                \
+    reinterpret_cast<CACHE_T*>(value_cache.data_ptr()),                                              \
+    slot_mapping.data_ptr<int64_t>(),                                                                \
+    key_stride,                                                                                      \
+    value_stride,                                                                                    \
+    num_heads,                                                                                       \
+    head_size,                                                                                       \
+    block_size,                                                                                      \
+    x);
+
 void reshape_and_cache(
   torch::Tensor& key,           // [num_tokens, num_heads, head_size]
   torch::Tensor& value,         // [num_tokens, num_heads, head_size]
   torch::Tensor& key_cache,     // [num_blocks, num_heads, head_size/x, block_size, x]
   torch::Tensor& value_cache,   // [num_blocks, num_heads, head_size, block_size]
-  torch::Tensor& slot_mapping)  // [num_tokens]
+  torch::Tensor& slot_mapping,  // [num_tokens]
+  const std::string& kv_cache_dtype)
 {
   int num_tokens = key.size(0);
   int num_heads = key.size(1);
@@ -212,23 +239,25 @@ void reshape_and_cache(
   dim3 block(std::min(num_heads * head_size, 512));
   const at::cuda::OptionalCUDAGuard device_guard(device_of(key));
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
-  VLLM_DISPATCH_FLOATING_TYPES(
-    key.scalar_type(),
-    "reshape_and_cache_kernel",
-    [&] {
-      vllm::reshape_and_cache_kernel<scalar_t><<<grid, block, 0, stream>>>(
-        key.data_ptr<scalar_t>(),
-        value.data_ptr<scalar_t>(),
-        key_cache.data_ptr<scalar_t>(),
-        value_cache.data_ptr<scalar_t>(),
-        slot_mapping.data_ptr<int64_t>(),
-        key_stride,
-        value_stride,
-        num_heads,
-        head_size,
-        block_size,
-        x);
-    });
+  if (kv_cache_dtype == "auto") {
+    if (key.dtype() == at::ScalarType::Float) {
+      CALL_RESHAPE_AND_CACHE(float, float, false);
+    } else if (key.dtype() == at::ScalarType::Half) {
+      CALL_RESHAPE_AND_CACHE(uint16_t, uint16_t, false);
+    } else if (key.dtype() == at::ScalarType::BFloat16) {
+      CALL_RESHAPE_AND_CACHE(__nv_bfloat16, __nv_bfloat16, false);
+    }
+  } else if (kv_cache_dtype == "fp8_e5m2") {
+    if (key.dtype() == at::ScalarType::Float) {
+      CALL_RESHAPE_AND_CACHE(float, uint8_t, true);
+    } else if (key.dtype() == at::ScalarType::Half) {
+      CALL_RESHAPE_AND_CACHE(uint16_t, uint8_t, true);
+    } else if (key.dtype() == at::ScalarType::BFloat16) {
+      CALL_RESHAPE_AND_CACHE(__nv_bfloat16, uint8_t, true);
+    }
+  } else {
+    TORCH_CHECK(false, "Unsupported data type of kv cache: ", kv_cache_dtype);
+  }
 }

 namespace vllm {
@@ -373,7 +402,7 @@ void gather_cached_kv(
   dim3 block(std::min(num_heads * head_size, 512));
   const at::cuda::OptionalCUDAGuard device_guard(device_of(key));
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
-  VLLM_DISPATCH_FLOATING_TYPES(
+  VLLM_DISPATCH_FLOATING_AND_BYTE_TYPES(
     key.scalar_type(),
     "gather_cached_kv_kernel_optimized",
     [&] {
@@ -391,3 +420,55 @@ void gather_cached_kv(
         x);
     });
 }
+
+namespace vllm {
+
+template<typename Tout, typename Tin>
+__global__ void convert_fp8_e5m2_kernel(
+  const Tin* __restrict__ src_cache,
+  Tout* __restrict__ dst_cache,
+  const int64_t block_stride) {
+  const int64_t block_idx = blockIdx.x;
+  for (int i = threadIdx.x; i < block_stride; i += blockDim.x) {
+    int64_t idx = block_idx * block_stride + i;
+#ifdef ENABLE_FP8_E5M2
+    dst_cache[idx] = fp8_e5m2_unscaled::vec_conversion<Tout, Tin>(src_cache[idx]);
+#else
+    assert(false);
+#endif
+  }
+}
+
+} // namespace vllm
+
+#define CALL_CONVERT_FP8_E5M2(Tout, Tin)                                 \
+  vllm::convert_fp8_e5m2_kernel<Tout, Tin><<<grid, block, 0, stream>>>(  \
+    reinterpret_cast<Tin*>(src_cache.data_ptr()),                        \
+    reinterpret_cast<Tout*>(dst_cache.data_ptr()),                       \
+    block_stride);
+
+void convert_fp8_e5m2(
+  torch::Tensor& src_cache,
+  torch::Tensor& dst_cache)
+{
+  int64_t num_blocks = src_cache.size(0);
+  int64_t block_stride = src_cache.stride(0);
+
+  dim3 grid(num_blocks);
+  dim3 block(std::min(block_stride, int64_t(512)));
+  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+
+  if (src_cache.dtype() == at::ScalarType::Float) {
+    CALL_CONVERT_FP8_E5M2(uint8_t, float);
+  } else if (src_cache.dtype() == at::ScalarType::Half) {
+    CALL_CONVERT_FP8_E5M2(uint8_t, uint16_t);
+  } else if (src_cache.dtype() == at::ScalarType::BFloat16) {
+    CALL_CONVERT_FP8_E5M2(uint8_t, __nv_bfloat16);
+  } else if (dst_cache.dtype() == at::ScalarType::Float) {
+    CALL_CONVERT_FP8_E5M2(float, uint8_t);
+  } else if (dst_cache.dtype() == at::ScalarType::Half) {
+    CALL_CONVERT_FP8_E5M2(uint16_t, uint8_t);
+  } else if (dst_cache.dtype() == at::ScalarType::BFloat16) {
+    CALL_CONVERT_FP8_E5M2(__nv_bfloat16, uint8_t);
+  }
+}
@@ -14,3 +14,13 @@
 #define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...)             \
   AT_DISPATCH_SWITCH(                                             \
     TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__))
+
+#define VLLM_DISPATCH_CASE_FLOATING_AND_BYTE_TYPES(...)           \
+  AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__)            \
+  AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__)             \
+  AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__)         \
+  AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__)
+
+#define VLLM_DISPATCH_FLOATING_AND_BYTE_TYPES(TYPE, NAME, ...)    \
+  AT_DISPATCH_SWITCH(                                             \
+    TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_AND_BYTE_TYPES(__VA_ARGS__))
@@ -13,7 +13,8 @@ void paged_attention_v1(
   torch::Tensor& context_lens,
   int block_size,
   int max_context_len,
-  const c10::optional<torch::Tensor>& alibi_slopes);
+  const c10::optional<torch::Tensor>& alibi_slopes,
+  const std::string& kv_cache_dtype);

 void paged_attention_v2(
   torch::Tensor& out,
@@ -29,7 +30,8 @@ void paged_attention_v2(
   torch::Tensor& context_lens,
   int block_size,
   int max_context_len,
-  const c10::optional<torch::Tensor>& alibi_slopes);
+  const c10::optional<torch::Tensor>& alibi_slopes,
+  const std::string& kv_cache_dtype);

 void rms_norm(
   torch::Tensor& out,
@@ -75,6 +75,10 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
     "gather_cached_kv",
     &gather_cached_kv,
     "Gather key and value from the cache into contiguous QKV tensors");
+  cache_ops.def(
+    "convert_fp8_e5m2",
+    &convert_fp8_e5m2,
+    "Convert the key and value cache to fp8_e5m2 data type");

   // Cuda utils
   pybind11::module cuda_utils = m.def_submodule("cuda_utils", "vLLM cuda utils");
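The `convert_fp8_e5m2` binding registered above is, per the header comment, intended for unit tests. A minimal sketch of exercising it from Python, assuming the compiled extension exposes the `cache_ops` submodule as `vllm._C.cache_ops` and that the build enabled `ENABLE_FP8_E5M2` (CUDA >= 11.8):

```python
import torch
from vllm._C import cache_ops  # assumed import path for the extension shown above

# Round-trip an fp16 "cache" tensor through the FP8-E5M2 representation (lossy).
src = torch.randn(16, 1024, dtype=torch.float16, device="cuda")
fp8 = torch.empty_like(src, dtype=torch.uint8)
cache_ops.convert_fp8_e5m2(src, fp8)    # fp16 -> fp8 bytes (stored as uint8)
back = torch.empty_like(src)
cache_ops.convert_fp8_e5m2(fp8, back)   # fp8 bytes -> fp16
print((src - back).abs().max())         # expect a small quantization error
```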
278
csrc/quantization/fp8_e5m2_kvcache/quant_utils.cuh
Normal file
278
csrc/quantization/fp8_e5m2_kvcache/quant_utils.cuh
Normal file
@ -0,0 +1,278 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <assert.h>
|
||||||
|
#include <stdint.h>
|
||||||
|
#include <float.h>
|
||||||
|
#include <type_traits>
|
||||||
|
#include "../../attention/attention_dtypes.h"
|
||||||
|
#include "../../attention/dtype_float32.cuh"
|
||||||
|
#include "../../attention/dtype_float16.cuh"
|
||||||
|
#include "../../attention/dtype_bfloat16.cuh"
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
namespace vllm {
|
||||||
|
#ifdef ENABLE_FP8_E5M2
|
||||||
|
namespace fp8_e5m2_unscaled {
|
||||||
|
|
||||||
|
template<typename Tout, typename Tin>
|
||||||
|
__inline__ __device__ Tout vec_conversion(const Tin& x)
|
||||||
|
{
|
||||||
|
return x;
|
||||||
|
}
|
||||||
|
|
||||||
|
// fp8 -> half
|
||||||
|
template<>
|
||||||
|
__inline__ __device__ uint16_t vec_conversion<uint16_t, uint8_t>(const uint8_t& a)
|
||||||
|
{
|
||||||
|
__half_raw res = __nv_cvt_fp8_to_halfraw(a, __NV_E5M2);
|
||||||
|
return res.x;
|
||||||
|
}
|
||||||
|
|
||||||
|
// fp8x2 -> half2
|
||||||
|
template<>
|
||||||
|
__inline__ __device__ uint32_t vec_conversion<uint32_t, uint16_t>(const uint16_t& a)
|
||||||
|
{
|
||||||
|
union {
|
||||||
|
uint16_t u16[2];
|
||||||
|
uint32_t u32;
|
||||||
|
} tmp;
|
||||||
|
__half2_raw res = __nv_cvt_fp8x2_to_halfraw2(a, __NV_E5M2);
|
||||||
|
tmp.u16[0] = res.x;
|
||||||
|
tmp.u16[1] = res.y;
|
||||||
|
return tmp.u32;
|
||||||
|
}
|
||||||
|
|
||||||
|
// fp8x4 -> half2x2
|
||||||
|
template<>
|
||||||
|
__inline__ __device__ uint2 vec_conversion<uint2, uint32_t>(const uint32_t& a)
|
||||||
|
{
|
||||||
|
union {
|
||||||
|
uint2 u32x2;
|
||||||
|
uint32_t u32[2];
|
||||||
|
} tmp;
|
||||||
|
tmp.u32[0] = vec_conversion<uint32_t, uint16_t>((uint16_t)a);
|
||||||
|
tmp.u32[1] = vec_conversion<uint32_t, uint16_t>((uint16_t)(a >> 16U));
|
||||||
|
return tmp.u32x2;
|
||||||
|
}
|
||||||
|
|
||||||
|
// fp8x8 -> half2x4
|
||||||
|
template<>
|
||||||
|
__inline__ __device__ uint4 vec_conversion<uint4, uint2>(const uint2& a)
|
||||||
|
{
|
||||||
|
union {
|
||||||
|
uint4 u64x2;
|
||||||
|
uint2 u64[2];
|
||||||
|
} tmp;
|
||||||
|
tmp.u64[0] = vec_conversion<uint2, uint32_t>(a.x);
|
||||||
|
tmp.u64[1] = vec_conversion<uint2, uint32_t>(a.y);
|
||||||
|
return tmp.u64x2;
|
||||||
|
}
|
||||||
|
|
||||||
|
// fp8 -> __nv_bfloat16
|
||||||
|
template<>
|
||||||
|
__inline__ __device__ __nv_bfloat16 vec_conversion<__nv_bfloat16, uint8_t>(const uint8_t& a)
|
||||||
|
{
|
||||||
|
// Note there is no direct convert function from fp8 to bf16.
|
||||||
|
// fp8 -> half
|
||||||
|
__half_raw res = __nv_cvt_fp8_to_halfraw(a, __NV_E5M2);
|
||||||
|
// half -> float -> bf16
|
||||||
|
float tmp = half_to_float(res.x);
|
||||||
|
return __float2bfloat16(tmp);
|
||||||
|
}
|
||||||
|
|
||||||
|
// fp8x2 -> __nv_bfloat162
|
||||||
|
template<>
|
||||||
|
__inline__ __device__ __nv_bfloat162 vec_conversion<__nv_bfloat162, uint16_t>(const uint16_t& a)
|
||||||
|
{
|
||||||
|
__nv_bfloat162 res;
|
||||||
|
res.x = vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)a);
|
||||||
|
res.y = vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)(a >> 8U));
|
||||||
|
return res;
|
||||||
|
}
|
||||||
|
|
||||||
|
// fp8x4 -> bf16_4_t
|
||||||
|
template<>
|
||||||
|
__inline__ __device__ bf16_4_t vec_conversion<bf16_4_t, uint32_t>(const uint32_t& a)
|
||||||
|
{
|
||||||
|
bf16_4_t res;
|
||||||
|
res.x = vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)a);
|
||||||
|
res.y = vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)(a >> 16U));
|
||||||
|
return res;
|
||||||
|
}
|
||||||
|
|
||||||
|
// fp8x8 -> bf16_8_t
|
||||||
|
template<>
|
||||||
|
__inline__ __device__ bf16_8_t vec_conversion<bf16_8_t, uint2>(const uint2& a)
|
||||||
|
{
|
||||||
|
bf16_4_t tmp1, tmp2;
|
||||||
|
tmp1 = vec_conversion<bf16_4_t, uint32_t>(a.x);
|
||||||
|
tmp2 = vec_conversion<bf16_4_t, uint32_t>(a.y);
|
||||||
|
bf16_8_t res;
|
||||||
|
res.x = tmp1.x;
|
||||||
|
res.y = tmp1.y;
|
||||||
|
res.z = tmp2.x;
|
||||||
|
res.w = tmp2.y;
|
||||||
|
return res;
|
||||||
|
}
|
||||||
|
|
||||||
|
// fp8 -> float
|
||||||
|
template<>
|
||||||
|
__inline__ __device__ float vec_conversion<float, uint8_t>(const uint8_t& a)
|
||||||
|
{
|
||||||
|
// fp8 -> half
|
||||||
|
uint16_t tmp = vec_conversion<uint16_t, uint8_t>(a);
|
||||||
|
// half -> float
|
||||||
|
return half_to_float(tmp);
|
||||||
|
}
|
||||||
|
|
||||||
|
// fp8x2 -> float2
|
||||||
|
template<>
|
||||||
|
__inline__ __device__ float2 vec_conversion<float2, uint16_t>(const uint16_t& a)
|
||||||
|
{
|
||||||
|
// fp8x2 -> half2
|
||||||
|
uint32_t tmp = vec_conversion<uint32_t, uint16_t>(a);
|
||||||
|
// half2 -> float2
|
||||||
|
return half2_to_float2(tmp);
|
||||||
|
}
|
||||||
|
|
||||||
|
// fp8x4 -> float4
|
||||||
|
template<>
|
||||||
|
__inline__ __device__ Float4_ vec_conversion<Float4_, uint32_t>(const uint32_t& a)
{
  Float4_ res;
  res.x = vec_conversion<float2, uint16_t>((uint16_t)a);
  res.y = vec_conversion<float2, uint16_t>((uint16_t)(a >> 16U));
  return res;
}

// fp8x8 -> float8
template<>
__inline__ __device__ Float8_ vec_conversion<Float8_, uint2>(const uint2& a)
{
  Float4_ tmp1, tmp2;
  tmp1 = vec_conversion<Float4_, uint32_t>(a.x);
  tmp2 = vec_conversion<Float4_, uint32_t>(a.y);
  Float8_ res;
  res.x = tmp1.x;
  res.y = tmp1.y;
  res.z = tmp2.x;
  res.w = tmp2.y;
  return res;
}

// half -> fp8
template<>
__inline__ __device__ uint8_t vec_conversion<uint8_t, uint16_t>(const uint16_t& a)
{
  __half_raw tmp;
  tmp.x = a;
  __nv_fp8_storage_t res = __nv_cvt_halfraw_to_fp8(tmp, __NV_SATFINITE, __NV_E5M2);
  return (uint8_t)res;
}

// bf16 -> fp8
template<>
__inline__ __device__ uint8_t vec_conversion<uint8_t, __nv_bfloat16>(const __nv_bfloat16& a)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
  assert(false);
#else
  __nv_fp8_storage_t res = __nv_cvt_bfloat16raw_to_fp8(__nv_bfloat16_raw(a), __NV_SATFINITE, __NV_E5M2);
  return (uint8_t)res;
#endif
}

// float -> fp8
template<>
__inline__ __device__ uint8_t vec_conversion<uint8_t, float>(const float& a)
{
  __nv_fp8_storage_t res = __nv_cvt_float_to_fp8(a, __NV_SATFINITE, __NV_E5M2);
  return (uint8_t)res;
}

// fp8x4 -> float4
template<>
__inline__ __device__ float4 vec_conversion<float4, uint32_t>(const uint32_t& a)
{
  Float4_ tmp = vec_conversion<Float4_, uint32_t>(a);
  float4 res = make_float4(tmp.x.x, tmp.x.y, tmp.y.x, tmp.y.y);
  return res;
}

template<>
__inline__ __device__ uint32_t vec_conversion<uint32_t, float2>(const float2& a)
{
  union {
    half2 float16;
    uint32_t uint32;
  };

  float16 = __float22half2_rn(a);
  return uint32;
}

template<>
__inline__ __device__ uint2 vec_conversion<uint2, Float4_>(const Float4_& a)
{
  uint2 b;
  float2 val;
  val.x = a.x.x;
  val.y = a.x.y;
  b.x = vec_conversion<uint32_t, float2>(val);

  val.x = a.y.x;
  val.y = a.y.y;
  b.y = vec_conversion<uint32_t, float2>(val);

  return b;
}

template<>
__inline__ __device__ float4 vec_conversion<float4, Float4_>(const Float4_& a)
{
  float4 b;
  b.x = a.x.x;
  b.y = a.x.y;
  b.z = a.y.x;
  b.w = a.y.y;
  return b;
}

template<>
__inline__ __device__ uint4 vec_conversion<uint4, Float8_>(const Float8_& a)
{
  uint4 b;
  b.x = vec_conversion<uint32_t, float2>(a.x);
  b.y = vec_conversion<uint32_t, float2>(a.y);
  b.z = vec_conversion<uint32_t, float2>(a.z);
  b.w = vec_conversion<uint32_t, float2>(a.w);
  return b;
}

template<>
__inline__ __device__ __nv_bfloat162 vec_conversion<__nv_bfloat162, float2>(const float2 &a) {
  __nv_bfloat162 b;
  from_float(b, a);
  return b;
}

template<>
__inline__ __device__ bf16_4_t vec_conversion<bf16_4_t, Float4_>(const Float4_ &a) {
  bf16_4_t b;
  from_float(b, a);
  return b;
}

template<>
__inline__ __device__ bf16_8_t vec_conversion<bf16_8_t, Float8_>(const Float8_ &a) {
  bf16_8_t b;
  from_float(b, a);
  return b;
}

}  // namespace fp8_e5m2_unscaled
#endif  // ENABLE_FP8_E5M2
}  // namespace vllm
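The specializations above are the device-side building blocks: fp8 values are stored as raw bytes (uint8_t) and widened to half/bf16/float on load via the __nv_cvt_* intrinsics, with __NV_SATFINITE saturating out-of-range values instead of producing Inf. For a feel of the rounding this introduces, here is a small host-side sketch using PyTorch's own E5M2 type (an assumption: it needs a PyTorch build that provides torch.float8_e5m2, roughly 2.1 or newer; it does not exercise the CUDA kernels above):

    import torch

    # Round-trip a few fp16 values through FP8 E5M2 (5 exponent bits, 2 mantissa bits).
    x = torch.tensor([0.1, 1.0, 3.14159, 100.0], dtype=torch.float16)
    x_fp8 = x.to(torch.float8_e5m2)      # quantize: keeps ~2 significant bits
    x_back = x_fp8.to(torch.float16)     # dequantize for comparison
    print(x_back)                        # e.g. values rounded to 0.0938, 1.0, 3.0, 96.0
    print((x - x_back).abs() / x.abs())  # relative error on the order of a few percent

This is the same error budget that the relaxed test tolerance later in this change accounts for.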
32  docs/source/quantization/fp8_e5m2_kv_cache.rst  Normal file
@ -0,0 +1,32 @@
.. _fp8_e5m2_kv_cache:

FP8 E5M2 KV Cache
==================

The int8/int4 quantization scheme requires additional GPU memory to store the scaling factors, which reduces the expected memory benefit.
The FP8 data format, by contrast, retains 2 or 3 mantissa bits (2 for E5M2, 3 for E4M3) and converts directly to and from float/fp16/bfloat16.

Here is an example of how to enable this feature:

.. code-block:: python

    from vllm import LLM, SamplingParams
    # Sample prompts.
    prompts = [
        "Hello, my name is",
        "The president of the United States is",
        "The capital of France is",
        "The future of AI is",
    ]
    # Create a sampling params object.
    sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
    # Create an LLM.
    llm = LLM(model="facebook/opt-125m", kv_cache_dtype="fp8_e5m2")
    # Generate texts from the prompts. The output is a list of RequestOutput objects
    # that contain the prompt, generated text, and other information.
    outputs = llm.generate(prompts, sampling_params)
    # Print the outputs.
    for output in outputs:
        prompt = output.prompt
        generated_text = output.outputs[0].text
        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
3  setup.py
@ -253,6 +253,9 @@ if _is_cuda():
         num_threads = min(os.cpu_count(), nvcc_threads)
         NVCC_FLAGS += ["--threads", str(num_threads)]
 
+    if nvcc_cuda_version >= Version("11.8"):
+        NVCC_FLAGS += ["-DENABLE_FP8_E5M2"]
+
 # changes for punica kernels
 NVCC_FLAGS += torch_cpp_ext.COMMON_NVCC_FLAGS
 REMOVE_NVCC_FLAGS = [
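The -DENABLE_FP8_E5M2 define is only added when nvcc reports CUDA 11.8 or newer, so wheels built against older toolkits simply omit the fp8 cache kernels. A best-effort runtime probe for the installed build is to attempt a tiny conversion; this is a sketch under the assumption that builds without the flag fail the call or lack the op (the exact failure mode is not shown in this diff):

    import torch

    def has_fp8_e5m2_support() -> bool:
        """Probe the compiled extension for the fp8_e5m2 cache conversion op."""
        if not torch.cuda.is_available():
            return False
        try:
            from vllm._C import cache_ops
            src = torch.zeros(16, dtype=torch.float16, device="cuda")
            dst = torch.empty(16, dtype=torch.uint8, device="cuda")
            cache_ops.convert_fp8_e5m2(src, dst)  # op added by this change
            return True
        except Exception:
            return False

    print("fp8_e5m2 kv cache kernels available:", has_fp8_e5m2_support())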
@ -1,44 +1,7 @@
-from typing import List, Tuple
-
 import pytest
-import torch
+from vllm.utils import create_kv_caches_with_random
-
-
-def create_kv_caches(
-    num_blocks: int,
-    block_size: int,
-    num_layers: int,
-    num_heads: int,
-    head_size: int,
-    dtype: torch.dtype,
-    seed: int,
-    device: str,
-) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
-    torch.random.manual_seed(seed)
-    torch.cuda.manual_seed(seed)
-
-    scale = head_size**-0.5
-    x = 16 // torch.tensor([], dtype=dtype).element_size()
-    key_cache_shape = (num_blocks, num_heads, head_size // x, block_size, x)
-    key_caches = []
-    for _ in range(num_layers):
-        key_cache = torch.empty(size=key_cache_shape,
-                                dtype=dtype,
-                                device=device)
-        key_cache.uniform_(-scale, scale)
-        key_caches.append(key_cache)
-
-    value_cache_shape = (num_blocks, num_heads, head_size, block_size)
-    value_caches = []
-    for _ in range(num_layers):
-        value_cache = torch.empty(size=value_cache_shape,
-                                  dtype=dtype,
-                                  device=device)
-        value_cache.uniform_(-scale, scale)
-        value_caches.append(value_cache)
-    return key_caches, value_caches
-
-
 @pytest.fixture()
 def kv_cache_factory():
-    return create_kv_caches
+    return create_kv_caches_with_random
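Because the fixture now simply returns create_kv_caches_with_random, every kernel test picks up the new (cache_dtype, model_dtype) argument pair. A hypothetical test using the fixture (not part of this change) could check the storage saving directly:

    import torch

    def test_fp8_cache_halves_storage(kv_cache_factory):
        # Positional order: num_blocks, block_size, num_layers, num_heads,
        # head_size, cache_dtype, model_dtype, seed, device.
        fp16_keys, _ = kv_cache_factory(64, 16, 1, 8, 128, "auto",
                                        torch.half, 0, "cuda")
        fp8_keys, _ = kv_cache_factory(64, 16, 1, 8, 128, "fp8_e5m2",
                                       torch.half, 0, "cuda")
        fp16_bytes = fp16_keys[0].numel() * fp16_keys[0].element_size()
        fp8_bytes = fp8_keys[0].numel() * fp8_keys[0].element_size()
        # One byte per element instead of two for an fp16 model.
        assert fp8_bytes * 2 == fp16_bytes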
@ -6,14 +6,16 @@ import torch
 from xformers import ops as xops
 from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask
 
-from vllm._C import ops
+from vllm._C import ops, cache_ops
 from vllm.utils import get_max_shared_memory_bytes
 
 FLOAT32_BYTES = torch.finfo(torch.float).bits // 8
 # This will change depending on the compute capability.
 # - 512 as a buffer
 MAX_SEQ_LEN = get_max_shared_memory_bytes() // FLOAT32_BYTES - 512
-NUM_BLOCKS = 12000  # Arbitrary values for testing
+# There may not be enough gpu memory due to large NUM_BLOCKS.
+# Reduce NUM_BLOCKS when it happens.
+NUM_BLOCKS = 4321  # Arbitrary values for testing
 PARTITION_SIZE = 512
 
 DTYPES = [torch.half, torch.bfloat16, torch.float]
@ -23,6 +25,7 @@ NUM_HEADS = [(40, 40), (64, 8)]  # Arbitrary values for testing
 HEAD_SIZES = [64, 80, 96, 112, 128, 256]
 BLOCK_SIZES = [16, 32]
 USE_ALIBI = [False, True]
+KV_CACHE_DTYPE = ["auto", "fp8_e5m2"]
 SEEDS = [0]
 DEVICES = [i for i in range(1 if torch.cuda.device_count() == 1 else 2)]
@ -105,6 +108,7 @@ def ref_single_query_cached_kv_attention(
 @pytest.mark.parametrize("use_alibi", USE_ALIBI)
 @pytest.mark.parametrize("block_size", BLOCK_SIZES)
 @pytest.mark.parametrize("dtype", DTYPES)
+@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE)
 @pytest.mark.parametrize("seed", SEEDS)
 @pytest.mark.parametrize("device", DEVICES)
 def test_paged_attention(
@ -116,6 +120,7 @@ def test_paged_attention(
     use_alibi: bool,
     block_size: int,
     dtype: torch.dtype,
+    kv_cache_dtype: str,
     seed: int,
     device: int,
 ) -> None:
@ -158,8 +163,9 @@ def test_paged_attention(
 
     # Create the KV caches.
     key_caches, value_caches = kv_cache_factory(NUM_BLOCKS, block_size, 1,
-                                                num_kv_heads, head_size, dtype,
-                                                seed, gpu_id)
+                                                num_kv_heads, head_size,
+                                                kv_cache_dtype, dtype, seed,
+                                                gpu_id)
     key_cache, value_cache = key_caches[0], value_caches[0]
 
     # Call the paged attention kernel.
@ -177,6 +183,7 @@ def test_paged_attention(
             block_size,
             max_context_len,
             alibi_slopes,
+            kv_cache_dtype,
         )
     elif version == "v2":
         num_partitions = ((max_context_len + PARTITION_SIZE - 1) //
@ -209,11 +216,30 @@ def test_paged_attention(
             block_size,
             max_context_len,
             alibi_slopes,
+            kv_cache_dtype,
         )
     else:
         raise AssertionError(f"Unknown version: {version}")
 
     # Run the reference implementation.
+    if kv_cache_dtype == "fp8_e5m2":
+        # Convert cache data back to dtype.
+        x = 16 // torch.tensor([], dtype=dtype).element_size()
+        key_cache_shape = (NUM_BLOCKS, num_kv_heads, head_size // x,
+                           block_size, x)
+        dequantized_key_cache = torch.empty(size=key_cache_shape,
+                                            dtype=dtype,
+                                            device=gpu_id)
+        cache_ops.convert_fp8_e5m2(key_cache, dequantized_key_cache)
+        key_cache = dequantized_key_cache
+
+        value_cache_shape = value_cache.shape
+        dequantized_value_cache = torch.empty(size=value_cache_shape,
+                                              dtype=dtype,
+                                              device=gpu_id)
+        cache_ops.convert_fp8_e5m2(value_cache, dequantized_value_cache)
+        value_cache = dequantized_value_cache
+
     ref_output = torch.empty_like(query)
     ref_single_query_cached_kv_attention(
         ref_output,
@ -230,7 +256,12 @@ def test_paged_attention(
     # NOTE(woosuk): Due to the kernel-level differences in the two
     # implementations, there is a small numerical difference in the two
     # outputs. Thus, we use a relaxed tolerance for the test.
-    assert torch.allclose(output, ref_output, atol=1e-3, rtol=1e-5)
+    # NOTE(zhaoyang): FP8 KV Cache will introduce quantization error,
+    # so we use a relaxed tolerance for the test.
+    atol, rtol = 1e-3, 1e-5
+    if kv_cache_dtype == "fp8_e5m2":
+        atol, rtol = 1e-2, 1e-5
+    assert torch.allclose(output, ref_output, atol=atol, rtol=rtol)
 
 
 def ref_multi_query_kv_attention(
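The reference path above has to undo the quantization before comparing: the kernel reads fp8 bytes directly, while the Python reference attention works in the model dtype, so the caches are converted back with cache_ops.convert_fp8_e5m2. The same pattern can be factored into a small helper; a sketch (the helper name is ours, not part of the diff, and the compiled extension is assumed):

    import torch
    from vllm._C import cache_ops

    def dequantize_fp8_e5m2_cache(cache: torch.Tensor,
                                  dtype: torch.dtype) -> torch.Tensor:
        # Convert a uint8-backed fp8_e5m2 cache back to the model dtype,
        # preserving shape and device.
        out = torch.empty_like(cache, dtype=dtype)
        cache_ops.convert_fp8_e5m2(cache, out)
        return out

The widened tolerance (atol=1e-2 instead of 1e-3) then absorbs the two-mantissa-bit rounding of the stored keys and values.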
@ -15,6 +15,7 @@ NUM_BLOCKS = [1024, 3600]  # Arbitrary values for testing
 NUM_MAPPINGS = [256]  # Arbitrary values for testing
 SEEDS = [0]
 DEVICES = [i for i in range(1 if torch.cuda.device_count() == 1 else 2)]
+KV_CACHE_DTYPE = ["auto", "fp8_e5m2"]
 
 
 @pytest.mark.parametrize("num_mappings", NUM_MAPPINGS)
@ -26,6 +27,7 @@ DEVICES = [i for i in range(1 if torch.cuda.device_count() == 1 else 2)]
 @pytest.mark.parametrize("dtype", DTYPES)
 @pytest.mark.parametrize("seed", SEEDS)
 @pytest.mark.parametrize("device", DEVICES)
+@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE)
 @torch.inference_mode()
 def test_copy_blocks(
     kv_cache_factory,
@ -38,6 +40,7 @@ def test_copy_blocks(
     dtype: torch.dtype,
     seed: int,
     device: int,
+    kv_cache_dtype: str,
 ) -> None:
     random.seed(seed)
     torch.random.manual_seed(seed)
@ -59,7 +62,8 @@ def test_copy_blocks(
     # Create the KV caches.
     key_caches, value_caches = kv_cache_factory(num_blocks, block_size,
                                                 num_layers, num_heads,
-                                                head_size, dtype, seed, gpu_id)
+                                                head_size, kv_cache_dtype,
+                                                dtype, seed, gpu_id)
 
     # Clone the KV caches.
     cloned_key_caches = [key_cache.clone() for key_cache in key_caches]
@ -124,7 +128,7 @@ def test_reshape_and_cache(
     # Create the KV caches.
     key_caches, value_caches = kv_cache_factory(num_blocks, block_size, 1,
                                                 num_heads, head_size, dtype,
-                                                seed, gpu_id)
+                                                None, seed, gpu_id)
     key_cache, value_cache = key_caches[0], value_caches[0]
 
     # Clone the KV caches.
@ -133,7 +137,7 @@ def test_reshape_and_cache(
 
     # Call the reshape_and_cache kernel.
     cache_ops.reshape_and_cache(key, value, key_cache, value_cache,
-                                slot_mapping)
+                                slot_mapping, "auto")
 
     # Run the reference implementation.
     reshaped_key = key.reshape(num_tokens, *key_cache[0, :, :, 0, :].shape)
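reshape_and_cache now takes the cache dtype string as its final argument; "auto" preserves the old behaviour, while "fp8_e5m2" quantizes on store. A standalone sketch of the call (shapes follow the key/value cache layout used in these tests; the slot_mapping dtype and a CUDA device with the compiled extension are assumptions):

    import torch
    from vllm._C import cache_ops

    num_tokens, num_heads, head_size, block_size, num_blocks = 8, 4, 128, 16, 32
    x = 16 // torch.tensor([], dtype=torch.float16).element_size()  # 8 for fp16

    key = torch.randn(num_tokens, num_heads, head_size,
                      dtype=torch.float16, device="cuda")
    value = torch.randn(num_tokens, num_heads, head_size,
                        dtype=torch.float16, device="cuda")
    key_cache = torch.zeros(num_blocks, num_heads, head_size // x, block_size, x,
                            dtype=torch.float16, device="cuda")
    value_cache = torch.zeros(num_blocks, num_heads, head_size, block_size,
                              dtype=torch.float16, device="cuda")
    slot_mapping = torch.arange(num_tokens, dtype=torch.long, device="cuda")

    # "auto" writes the tensors in their own dtype; "fp8_e5m2" would quantize.
    cache_ops.reshape_and_cache(key, value, key_cache, value_cache,
                                slot_mapping, "auto")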
@ -1,13 +1,14 @@
 from typing import Optional, Union, ClassVar
 from dataclasses import dataclass
 import os
+from packaging.version import Version
 
 import torch
 from transformers import PretrainedConfig
 
 from vllm.logger import init_logger
 from vllm.transformers_utils.config import get_config
-from vllm.utils import get_cpu_memory, is_hip
+from vllm.utils import get_cpu_memory, is_hip, get_nvcc_cuda_version
 
 logger = init_logger(__name__)
 
@ -275,6 +276,7 @@ class CacheConfig:
         gpu_memory_utilization: Fraction of GPU memory to use for the
             vLLM execution.
         swap_space: Size of the CPU swap space per GPU (in GiB).
+        cache_dtype: Data type for kv cache storage.
     """
 
     def __init__(
@ -282,13 +284,16 @@ class CacheConfig:
         block_size: int,
         gpu_memory_utilization: float,
         swap_space: int,
+        cache_dtype: str,
         sliding_window: Optional[int] = None,
     ) -> None:
         self.block_size = block_size
         self.gpu_memory_utilization = gpu_memory_utilization
         self.swap_space_bytes = swap_space * _GB
+        self.cache_dtype = cache_dtype
         self.sliding_window = sliding_window
         self._verify_args()
+        self._verify_cache_dtype()
 
         # Will be set after profiling.
         self.num_gpu_blocks = None
@ -300,6 +305,28 @@ class CacheConfig:
                 "GPU memory utilization must be less than 1.0. Got "
                 f"{self.gpu_memory_utilization}.")
 
+    def _verify_cache_dtype(self) -> None:
+        if self.cache_dtype == "auto":
+            pass
+        elif self.cache_dtype == "fp8_e5m2":
+            nvcc_cuda_version = get_nvcc_cuda_version()
+            if nvcc_cuda_version < Version("11.8"):
+                raise ValueError(
+                    "FP8 is not supported when cuda version is lower than 11.8."
+                )
+            device_name = torch.cuda.get_device_name()
+            if "AMD" in device_name:
+                raise NotImplementedError(
+                    "FP8_E5M2 KV Cache on AMD GPU has not been supported yet.")
+            logger.info(
+                "Using fp8_e5m2 data type to store kv cache. It reduces "
+                "the GPU memory footprint and boosts the performance. "
+                "But it may cause slight accuracy drop. "
+                "Currently we only support fp8 without scaling factors and "
+                "use e5m2 as the default format.")
+        else:
+            raise ValueError(f"Unknown kv cache dtype: {self.cache_dtype}")
+
     def verify_with_parallel_config(
         self,
         parallel_config: "ParallelConfig",
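Validation now happens when CacheConfig is constructed, so an unsupported setup fails before any model weights are loaded. A sketch of what callers see (argument values are illustrative; the check needs nvcc on the path and a visible GPU):

    from vllm.config import CacheConfig

    try:
        cache_config = CacheConfig(block_size=16,
                                   gpu_memory_utilization=0.9,
                                   swap_space=4,          # GiB
                                   cache_dtype="fp8_e5m2")
        print("kv cache dtype:", cache_config.cache_dtype)
    except ValueError as err:
        # Raised for CUDA < 11.8 or an unknown dtype string.
        print("fp8_e5m2 unavailable:", err)
    except NotImplementedError as err:
        # Raised on AMD GPUs, where the fp8_e5m2 kv cache is not supported yet.
        print(err)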
@ -17,6 +17,7 @@ class EngineArgs:
     download_dir: Optional[str] = None
     load_format: str = 'auto'
     dtype: str = 'auto'
+    kv_cache_dtype: str = 'auto'
     seed: int = 0
     max_model_len: Optional[int] = None
     worker_use_ray: bool = False
@ -122,6 +123,14 @@ class EngineArgs:
             'The "auto" option will use FP16 precision '
             'for FP32 and FP16 models, and BF16 precision '
             'for BF16 models.')
+        parser.add_argument(
+            '--kv-cache-dtype',
+            type=str,
+            choices=['auto', 'fp8_e5m2'],
+            default='auto',
+            help='Data type for kv cache storage. If "auto", will use model '
+            'data type. Note FP8 is not supported when cuda version is '
+            'lower than 11.8.')
         parser.add_argument('--max-model-len',
                             type=int,
                             default=None,
@ -269,7 +278,7 @@ class EngineArgs:
                                    self.max_context_len_to_capture)
         cache_config = CacheConfig(self.block_size,
                                    self.gpu_memory_utilization,
-                                   self.swap_space,
+                                   self.swap_space, self.kv_cache_dtype,
                                    model_config.get_sliding_window())
         parallel_config = ParallelConfig(self.pipeline_parallel_size,
                                          self.tensor_parallel_size,
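The same knob is exposed on the dataclass and, through it, on every CLI entry point as --kv-cache-dtype {auto,fp8_e5m2}. A minimal sketch of the programmatic path (constructing EngineArgs does not touch the GPU, so this is safe to run anywhere):

    from vllm.engine.arg_utils import EngineArgs

    # Equivalent to passing --kv-cache-dtype fp8_e5m2 on the command line.
    engine_args = EngineArgs(model="facebook/opt-125m",
                             kv_cache_dtype="fp8_e5m2")
    print(engine_args.kv_cache_dtype)  # 'fp8_e5m2'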
@ -85,6 +85,7 @@ class LLMEngine:
             f"disable_custom_all_reduce={parallel_config.disable_custom_all_reduce}, "
             f"quantization={model_config.quantization}, "
             f"enforce_eager={model_config.enforce_eager}, "
+            f"kv_cache_dtype={cache_config.cache_dtype}, "
             f"seed={model_config.seed})")
         # TODO(woosuk): Print more configs in debug mode.
 
@ -144,6 +145,7 @@ class LLMEngine:
             rank=0,
             distributed_init_method=distributed_init_method,
             lora_config=self.lora_config,
+            kv_cache_dtype=self.cache_config.cache_dtype,
             is_driver_worker=True,
         )
         self._run_workers("init_model")
@ -234,6 +236,7 @@ class LLMEngine:
         model_config = copy.deepcopy(self.model_config)
         parallel_config = copy.deepcopy(self.parallel_config)
         scheduler_config = copy.deepcopy(self.scheduler_config)
+        cache_config = copy.deepcopy(self.cache_config)
 
         for rank, (worker, (node_id,
                             _)) in enumerate(zip(self.workers,
@ -249,6 +252,7 @@ class LLMEngine:
                     rank,
                     distributed_init_method,
                     lora_config=self.lora_config,
+                    cache_config=cache_config,
                 ))
 
         driver_rank = 0
@ -261,6 +265,7 @@ class LLMEngine:
             driver_rank,
             distributed_init_method,
             lora_config=self.lora_config,
+            cache_config=cache_config,
             is_driver_worker=True,
         )
 
@ -306,6 +311,7 @@ class LLMEngine:
             block_size=self.cache_config.block_size,
             gpu_memory_utilization=self.cache_config.gpu_memory_utilization,
             cpu_swap_space=self.cache_config.swap_space_bytes,
+            cache_dtype=self.cache_config.cache_dtype,
         )
 
         # Since we use a shared centralized controller, we take the minimum
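Once the engine comes up, the chosen dtype is echoed in the startup log line above and recorded on the cache config. A quick introspection sketch (reading llm_engine.cache_config is an internal detail and may change between releases; a CUDA GPU and the model download are required):

    from vllm import LLM

    llm = LLM(model="facebook/opt-125m", kv_cache_dtype="fp8_e5m2")
    print(llm.llm_engine.cache_config.cache_dtype)  # 'fp8_e5m2'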
@ -12,6 +12,7 @@ class InputMetadata:
         max_context_len: The maximum context length.
         context_lens: the length of attention context for each sequence.
         block_tables: The block tables. (Seq id -> list of physical block)
+        kv_cache_dtype: Data type to store kv cache.
     """
 
     def __init__(
@ -25,6 +26,7 @@ class InputMetadata:
         context_lens: Optional[torch.Tensor],
         block_tables: Optional[torch.Tensor],
         use_cuda_graph: bool,
+        kv_cache_dtype: str,
     ) -> None:
         self.is_prompt = is_prompt
         self.prompt_lens = prompt_lens
@ -35,6 +37,7 @@ class InputMetadata:
         self.context_lens = context_lens
         self.block_tables = block_tables
         self.use_cuda_graph = use_cuda_graph
+        self.kv_cache_dtype = kv_cache_dtype
 
         # Set during the execution of the first attention op.
         # FIXME(woosuk): This is a hack.
@ -47,4 +50,5 @@ class InputMetadata:
                 f"slot_mapping={self.slot_mapping}, "
                 f"context_lens={self.context_lens}, "
                 f"block_tables={self.block_tables}, "
-                f"use_cuda_graph={self.use_cuda_graph})")
+                f"use_cuda_graph={self.use_cuda_graph}, "
+                f"kv_cache_dtype={self.kv_cache_dtype})")
@ -98,6 +98,7 @@ class PagedAttention(nn.Module):
                 key_cache,
                 value_cache,
                 input_metadata.slot_mapping.flatten(),
+                input_metadata.kv_cache_dtype,
             )
 
         if input_metadata.is_prompt:
@ -265,6 +266,7 @@ def _paged_attention(
             block_size,
             input_metadata.max_context_len,
             alibi_slopes,
+            input_metadata.kv_cache_dtype,
         )
     else:
         # Run PagedAttention V2.
@ -295,5 +297,6 @@ def _paged_attention(
             block_size,
             input_metadata.max_context_len,
             alibi_slopes,
+            input_metadata.kv_cache_dtype,
         )
     return output
110  vllm/utils.py
@ -1,9 +1,11 @@
 import enum
 import os
 import socket
+import subprocess
 import uuid
 from platform import uname
-from typing import List
+from typing import List, Tuple, Union
+from packaging.version import parse, Version
 
 import psutil
 import torch
@ -17,7 +19,17 @@ from typing import (
 from collections import OrderedDict
 from typing import Any, Hashable, Optional
 
+from vllm.logger import init_logger
+
 T = TypeVar("T")
+logger = init_logger(__name__)
+
+STR_DTYPE_TO_TORCH_DTYPE = {
+    "half": torch.half,
+    "bfloat16": torch.bfloat16,
+    "float": torch.float,
+    "fp8_e5m2": torch.uint8,
+}
 
 
 class Device(enum.Enum):
@ -167,3 +179,99 @@ def get_open_port() -> int:
 
 def set_cuda_visible_devices(device_ids: List[int]) -> None:
     os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(map(str, device_ids))
+
+
+def get_nvcc_cuda_version() -> Version:
+    cuda_home = os.environ.get('CUDA_HOME')
+    if not cuda_home:
+        cuda_home = '/usr/local/cuda'
+        logger.info(
+            f'CUDA_HOME is not found in the environment. Using {cuda_home} as CUDA_HOME.'
+        )
+    nvcc_output = subprocess.check_output([cuda_home + "/bin/nvcc", "-V"],
+                                          universal_newlines=True)
+    output = nvcc_output.split()
+    release_idx = output.index("release") + 1
+    nvcc_cuda_version = parse(output[release_idx].split(",")[0])
+    return nvcc_cuda_version
+
+
+def _generate_random_fp8_e5m2(
+    tensor: torch.tensor,
+    low: float,
+    high: float,
+) -> None:
+    # NOTE(zhaoyang): Due to NaN and Inf representation for fp8 data type,
+    # Inf or NaN may occur if we directly use torch.randint
+    # to generate random data for fp8 data.
+    # For example, s.11111.00 in fp8e5m2 format represents Inf.
+    #     | E4M3        | E5M2
+    # ----|-------------|-------------------
+    # Inf | N/A         | s.11111.00
+    # NaN | s.1111.111  | s.11111.{01,10,11}
+    from vllm._C import cache_ops
+    tensor_tmp = torch.empty_like(tensor, dtype=torch.float16)
+    tensor_tmp.uniform_(low, high)
+    cache_ops.convert_fp8_e5m2(tensor_tmp, tensor)
+    del tensor_tmp
+
+
+def create_kv_caches_with_random(
+    num_blocks: int,
+    block_size: int,
+    num_layers: int,
+    num_heads: int,
+    head_size: int,
+    cache_dtype: Optional[Union[str, torch.dtype]],
+    model_dtype: Optional[Union[str, torch.dtype]] = None,
+    seed: Optional[int] = 0,
+    device: Optional[str] = "cuda",
+) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
+    torch.random.manual_seed(seed)
+    torch.cuda.manual_seed(seed)
+
+    if isinstance(cache_dtype, str):
+        if cache_dtype == "auto":
+            if isinstance(model_dtype, str):
+                torch_dtype = STR_DTYPE_TO_TORCH_DTYPE[model_dtype]
+            elif isinstance(model_dtype, torch.dtype):
+                torch_dtype = model_dtype
+            else:
+                raise ValueError(f"Invalid model dtype: {model_dtype}")
+        elif cache_dtype in ["half", "bfloat16", "float"]:
+            torch_dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_dtype]
+        elif cache_dtype == "fp8_e5m2":
+            torch_dtype = torch.uint8
+        else:
+            raise ValueError(f"Invalid kv cache dtype: {cache_dtype}")
+    elif isinstance(cache_dtype, torch.dtype):
+        torch_dtype = cache_dtype
+    else:
+        raise ValueError(f"Invalid kv cache dtype: {cache_dtype}")
+
+    scale = head_size**-0.5
+    x = 16 // torch.tensor([], dtype=torch_dtype).element_size()
+    key_cache_shape = (num_blocks, num_heads, head_size // x, block_size, x)
+    key_caches = []
+    for _ in range(num_layers):
+        key_cache = torch.empty(size=key_cache_shape,
+                                dtype=torch_dtype,
+                                device=device)
+        if cache_dtype in ["auto", "half", "bfloat16", "float"]:
+            key_cache.uniform_(-scale, scale)
+        elif cache_dtype == 'fp8_e5m2':
+            _generate_random_fp8_e5m2(key_cache, -scale, scale)
+        key_caches.append(key_cache)
+
+    value_cache_shape = (num_blocks, num_heads, head_size, block_size)
+    value_caches = []
+    for _ in range(num_layers):
+        value_cache = torch.empty(size=value_cache_shape,
+                                  dtype=torch_dtype,
+                                  device=device)
+        if cache_dtype in ["auto", "half", "bfloat16", "float"]:
+            value_cache.uniform_(-scale, scale)
+        elif cache_dtype == 'fp8_e5m2':
+            _generate_random_fp8_e5m2(value_cache, -scale, scale)
+        value_caches.append(value_cache)
+    return key_caches, value_caches
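create_kv_caches_with_random is now the single helper shared by the kernel tests. A short usage sketch (CUDA assumed; for "fp8_e5m2" the returned tensors are uint8-backed and filled via the fp8 conversion op rather than uniform_):

    import torch
    from vllm.utils import create_kv_caches_with_random

    key_caches, value_caches = create_kv_caches_with_random(
        num_blocks=64,
        block_size=16,
        num_layers=2,
        num_heads=8,
        head_size=128,
        cache_dtype="fp8_e5m2",
        model_dtype=torch.half,  # only consulted when cache_dtype == "auto"
        seed=0,
        device="cuda",
    )
    assert all(kc.dtype == torch.uint8 for kc in key_caches)
    # value cache layout: (num_blocks, num_heads, head_size, block_size)
    assert value_caches[0].shape == (64, 8, 128, 16)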
@ -6,7 +6,7 @@ import torch
 from vllm._C import cache_ops
 from vllm.config import CacheConfig, ModelConfig, ParallelConfig
 from vllm.logger import init_logger
-from vllm.utils import in_wsl
+from vllm.utils import in_wsl, STR_DTYPE_TO_TORCH_DTYPE
 
 logger = init_logger(__name__)
 
@ -34,12 +34,16 @@ class CacheEngine:
         self.head_size = model_config.get_head_size()
         self.num_layers = model_config.get_num_layers(parallel_config)
         self.num_heads = model_config.get_num_kv_heads(parallel_config)
-        self.dtype = model_config.dtype
 
         self.block_size = cache_config.block_size
         self.num_gpu_blocks = cache_config.num_gpu_blocks
         self.num_cpu_blocks = cache_config.num_cpu_blocks
 
+        if cache_config.cache_dtype == "auto":
+            self.dtype = model_config.dtype
+        else:
+            self.dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_config.cache_dtype]
+
         # Initialize the cache.
         self.gpu_cache = self.allocate_gpu_cache()
         self.cpu_cache = self.allocate_cpu_cache()
@ -142,6 +146,7 @@ class CacheEngine:
     @staticmethod
     def get_cache_block_size(
         block_size: int,
+        cache_dtype: str,
         model_config: ModelConfig,
         parallel_config: ParallelConfig,
     ) -> int:
@ -152,7 +157,11 @@ class CacheEngine:
         key_cache_block = block_size * num_heads * head_size
         value_cache_block = key_cache_block
         total = num_layers * (key_cache_block + value_cache_block)
-        dtype_size = _get_dtype_size(model_config.dtype)
+        if cache_dtype == "auto":
+            dtype = model_config.dtype
+        else:
+            dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_dtype]
+        dtype_size = _get_dtype_size(dtype)
         return dtype_size * total
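get_cache_block_size is where the dtype choice turns into extra cache blocks: the per-block byte count is halved relative to an fp16 model, so roughly twice as many blocks fit in the same gpu_memory_utilization budget. The arithmetic, sketched for a hypothetical 32-layer model with 32 KV heads of size 128 and a block size of 16 (illustrative numbers, not taken from a specific model):

    block_size, num_heads, head_size, num_layers = 16, 32, 128, 32

    key_cache_block = block_size * num_heads * head_size  # elements per block
    value_cache_block = key_cache_block
    total_elems = num_layers * (key_cache_block + value_cache_block)

    bytes_fp16 = 2 * total_elems  # cache_dtype="auto" with an fp16 model
    bytes_fp8 = 1 * total_elems   # cache_dtype="fp8_e5m2": one byte per element
    print(bytes_fp16, bytes_fp8)  # 8388608 vs 4194304 bytes per block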
@ -36,6 +36,7 @@ class ModelRunner:
         parallel_config: ParallelConfig,
         scheduler_config: SchedulerConfig,
         lora_config: Optional[LoRAConfig],
+        kv_cache_dtype: Optional[str] = "auto",
         is_driver_worker: bool = False,
     ):
         self.model_config = model_config
@ -68,6 +69,7 @@ class ModelRunner:
         self.graph_block_tables = None  # Set after initial profiling.
         # cache in_wsl result
         self.in_wsl = in_wsl()
+        self.kv_cache_dtype = kv_cache_dtype
 
     def load_model(self) -> None:
         self.model = get_model(self.model_config, self.lora_config)
@ -223,6 +225,7 @@ class ModelRunner:
             context_lens=context_lens_tensor,
             block_tables=block_tables,
             use_cuda_graph=False,
+            kv_cache_dtype=self.kv_cache_dtype,
         )
         return (input_tokens, input_positions, input_metadata, prompt_lens,
                 subquery_lens, lora_index_mapping, lora_prompt_mapping,
@ -350,6 +353,7 @@ class ModelRunner:
             context_lens=context_lens,
             block_tables=block_tables,
             use_cuda_graph=use_captured_graph,
+            kv_cache_dtype=self.kv_cache_dtype,
         )
         return input_tokens, input_positions, input_metadata, lora_index_mapping, lora_prompt_mapping, lora_requests
 
@ -473,6 +477,7 @@ class ModelRunner:
                 "context_lens": input_metadata.context_lens,
                 "block_tables": input_metadata.block_tables,
                 "use_cuda_graph": input_metadata.use_cuda_graph,
+                "kv_cache_dtype": input_metadata.kv_cache_dtype,
                 "selected_token_indices":
                 sampling_metadata.selected_token_indices,
                 "lora_requests": lora_requests,
@ -495,6 +500,7 @@ class ModelRunner:
                 context_lens=metadata_dict["context_lens"],
                 block_tables=metadata_dict["block_tables"],
                 use_cuda_graph=metadata_dict["use_cuda_graph"],
+                kv_cache_dtype=metadata_dict["kv_cache_dtype"],
             )
             sampling_metadata = SamplingMetadata(
                 seq_groups=None,
@ -665,6 +671,7 @@ class ModelRunner:
             context_lens=context_lens[:batch_size],
             block_tables=block_tables[:batch_size],
             use_cuda_graph=True,
+            kv_cache_dtype=self.kv_cache_dtype,
         )
 
         if self.lora_config:
@ -37,6 +37,7 @@ class Worker:
         rank: int,
         distributed_init_method: str,
         lora_config: Optional[LoRAConfig] = None,
+        kv_cache_dtype: Optional[str] = "auto",
         is_driver_worker: bool = False,
     ) -> None:
         self.model_config = model_config
@ -54,6 +55,7 @@ class Worker:
                                         parallel_config,
                                         scheduler_config,
                                         lora_config=self.lora_config,
+                                        kv_cache_dtype=kv_cache_dtype,
                                         is_driver_worker=is_driver_worker)
         # Uninitialized cache engine. Will be initialized by
         # self.init_cache_engine().
@ -95,6 +97,7 @@ class Worker:
         block_size: int,
         gpu_memory_utilization: float,
         cpu_swap_space: int,
+        cache_dtype: str,
     ) -> Tuple[int, int]:
         """Profiles the peak memory usage of the model and returns the maximum
         number of GPU and CPU cache blocks that can be allocated.
@ -119,7 +122,7 @@ class Worker:
         peak_memory = total_gpu_memory - free_gpu_memory
 
         cache_block_size = CacheEngine.get_cache_block_size(
-            block_size, self.model_config, self.parallel_config)
+            block_size, cache_dtype, self.model_config, self.parallel_config)
         num_gpu_blocks = int(
             (total_gpu_memory * gpu_memory_utilization - peak_memory) //
             cache_block_size)