Support FP8-E5M2 KV Cache (#2279)
Co-authored-by: zhaoyang <zhao.yang16@zte.com.cn> Co-authored-by: Zhuohan Li <zhuohan123@gmail.com>
This commit is contained in:
parent
7d648418b8
commit
9090bf02e7
@ -24,6 +24,7 @@ def main(args: argparse.Namespace):
|
||||
trust_remote_code=args.trust_remote_code,
|
||||
dtype=args.dtype,
|
||||
enforce_eager=args.enforce_eager,
|
||||
kv_cache_dtype=args.kv_cache_dtype,
|
||||
)
|
||||
|
||||
sampling_params = SamplingParams(
|
||||
@ -117,6 +118,13 @@ if __name__ == '__main__':
|
||||
parser.add_argument('--enforce-eager',
|
||||
action='store_true',
|
||||
help='enforce eager mode and disable CUDA graph')
|
||||
parser.add_argument(
|
||||
"--kv-cache-dtype",
|
||||
type=str,
|
||||
choices=['auto', 'fp8_e5m2'],
|
||||
default='auto',
|
||||
help=
|
||||
'Data type for kv cache storage. If "auto", will use model data type.')
|
||||
parser.add_argument(
|
||||
'--profile',
|
||||
action='store_true',
|
||||
|
@ -71,6 +71,7 @@ def run_vllm(
|
||||
dtype: str,
|
||||
max_model_len: Optional[int],
|
||||
enforce_eager: bool,
|
||||
kv_cache_dtype: str,
|
||||
) -> float:
|
||||
from vllm import LLM, SamplingParams
|
||||
llm = LLM(
|
||||
@ -83,6 +84,7 @@ def run_vllm(
|
||||
dtype=dtype,
|
||||
max_model_len=max_model_len,
|
||||
enforce_eager=enforce_eager,
|
||||
kv_cache_dtype=kv_cache_dtype,
|
||||
)
|
||||
|
||||
# Add the requests to the engine.
|
||||
@ -206,7 +208,8 @@ def main(args: argparse.Namespace):
|
||||
args.quantization, args.tensor_parallel_size,
|
||||
args.seed, args.n, args.use_beam_search,
|
||||
args.trust_remote_code, args.dtype,
|
||||
args.max_model_len, args.enforce_eager)
|
||||
args.max_model_len, args.enforce_eager,
|
||||
args.kv_cache_dtype)
|
||||
elif args.backend == "hf":
|
||||
assert args.tensor_parallel_size == 1
|
||||
elapsed_time = run_hf(requests, args.model, tokenizer, args.n,
|
||||
@ -284,6 +287,13 @@ if __name__ == "__main__":
|
||||
parser.add_argument("--enforce-eager",
|
||||
action="store_true",
|
||||
help="enforce eager execution")
|
||||
parser.add_argument(
|
||||
"--kv-cache-dtype",
|
||||
type=str,
|
||||
choices=["auto", "fp8_e5m2"],
|
||||
default="auto",
|
||||
help=
|
||||
'Data type for kv cache storage. If "auto", will use model data type.')
|
||||
args = parser.parse_args()
|
||||
if args.tokenizer is None:
|
||||
args.tokenizer = args.model
|
||||
|
@ -1,9 +1,11 @@
|
||||
from typing import Optional
|
||||
import argparse
|
||||
import random
|
||||
import time
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, create_kv_caches_with_random
|
||||
from vllm._C import ops
|
||||
|
||||
NUM_BLOCKS = 1024
|
||||
@ -23,6 +25,7 @@ def main(
|
||||
dtype: torch.dtype,
|
||||
seed: int,
|
||||
do_profile: bool,
|
||||
kv_cache_dtype: Optional[str] = None,
|
||||
) -> None:
|
||||
random.seed(seed)
|
||||
torch.random.manual_seed(seed)
|
||||
@ -59,15 +62,10 @@ def main(
|
||||
block_tables = torch.tensor(block_tables, dtype=torch.int, device="cuda")
|
||||
|
||||
# Create the KV cache.
|
||||
x = 16 // torch.tensor([], dtype=dtype).element_size()
|
||||
key_cache_shape = (NUM_BLOCKS, num_kv_heads, head_size // x, block_size, x)
|
||||
key_cache = torch.empty(size=key_cache_shape, dtype=dtype, device="cuda")
|
||||
key_cache.uniform_(-scale, scale)
|
||||
value_cache_shape = (NUM_BLOCKS, num_kv_heads, head_size, block_size)
|
||||
value_cache = torch.empty(size=value_cache_shape,
|
||||
dtype=dtype,
|
||||
device="cuda")
|
||||
value_cache.uniform_(-scale, scale)
|
||||
key_caches, value_caches = create_kv_caches_with_random(
|
||||
NUM_BLOCKS, block_size, 1, num_kv_heads, head_size, kv_cache_dtype,
|
||||
dtype)
|
||||
key_cache, value_cache = key_caches[0], value_caches[0]
|
||||
|
||||
# Prepare for the paged attention kernel.
|
||||
output = torch.empty_like(query)
|
||||
@ -106,6 +104,7 @@ def main(
|
||||
block_size,
|
||||
max_context_len,
|
||||
alibi_slopes,
|
||||
kv_cache_dtype,
|
||||
)
|
||||
elif version == "v2":
|
||||
ops.paged_attention_v2(
|
||||
@ -123,6 +122,7 @@ def main(
|
||||
block_size,
|
||||
max_context_len,
|
||||
alibi_slopes,
|
||||
kv_cache_dtype,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Invalid version: {version}")
|
||||
@ -168,16 +168,18 @@ if __name__ == '__main__':
|
||||
default="half")
|
||||
parser.add_argument("--seed", type=int, default=0)
|
||||
parser.add_argument("--profile", action="store_true")
|
||||
parser.add_argument(
|
||||
"--kv-cache-dtype",
|
||||
type=str,
|
||||
choices=["auto", "fp8_e5m2"],
|
||||
default="auto",
|
||||
help=
|
||||
'Data type for kv cache storage. If "auto", will use model data type.')
|
||||
args = parser.parse_args()
|
||||
print(args)
|
||||
|
||||
if args.num_query_heads % args.num_kv_heads != 0:
|
||||
raise ValueError("num_query_heads must be divisible by num_kv_heads")
|
||||
dtype_to_torch_dtype = {
|
||||
"half": torch.half,
|
||||
"bfloat16": torch.bfloat16,
|
||||
"float": torch.float,
|
||||
}
|
||||
main(
|
||||
version=args.version,
|
||||
num_seqs=args.batch_size,
|
||||
@ -187,7 +189,8 @@ if __name__ == '__main__':
|
||||
head_size=args.head_size,
|
||||
block_size=args.block_size,
|
||||
use_alibi=args.use_alibi,
|
||||
dtype=dtype_to_torch_dtype[args.dtype],
|
||||
dtype=STR_DTYPE_TO_TORCH_DTYPE[args.dtype],
|
||||
seed=args.seed,
|
||||
do_profile=args.profile,
|
||||
kv_cache_dtype=args.kv_cache_dtype,
|
||||
)
|
||||
|
@ -4,3 +4,4 @@
|
||||
#include "dtype_float16.cuh"
|
||||
#include "dtype_float32.cuh"
|
||||
#include "dtype_bfloat16.cuh"
|
||||
#include "dtype_fp8_e5m2.cuh"
|
||||
|
@ -25,6 +25,7 @@
|
||||
|
||||
#include "attention_dtypes.h"
|
||||
#include "attention_utils.cuh"
|
||||
#include "../quantization/fp8_e5m2_kvcache/quant_utils.cuh"
|
||||
|
||||
#include <algorithm>
|
||||
|
||||
@ -79,17 +80,19 @@ inline __device__ float block_sum(float* red_smem, float sum) {
|
||||
// Grid: (num_heads, num_seqs, max_num_partitions).
|
||||
template<
|
||||
typename scalar_t,
|
||||
typename cache_t,
|
||||
int HEAD_SIZE,
|
||||
int BLOCK_SIZE,
|
||||
int NUM_THREADS,
|
||||
bool IS_FP8_E5M2_KV_CACHE,
|
||||
int PARTITION_SIZE = 0> // Zero means no partitioning.
|
||||
__device__ void paged_attention_kernel(
|
||||
float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions]
|
||||
float* __restrict__ max_logits, // [num_seqs, num_heads, max_num_partitions]
|
||||
scalar_t* __restrict__ out, // [num_seqs, num_heads, max_num_partitions, head_size]
|
||||
const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size]
|
||||
const scalar_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, head_size/x, block_size, x]
|
||||
const scalar_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, head_size, block_size]
|
||||
const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, head_size/x, block_size, x]
|
||||
const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, head_size, block_size]
|
||||
const int num_kv_heads, // [num_heads]
|
||||
const float scale,
|
||||
const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq]
|
||||
@ -145,6 +148,9 @@ __device__ void paged_attention_kernel(
|
||||
constexpr int VEC_SIZE = MAX(16 / (THREAD_GROUP_SIZE * sizeof(scalar_t)), 1);
|
||||
using K_vec = typename Vec<scalar_t, VEC_SIZE>::Type;
|
||||
using Q_vec = typename Vec<scalar_t, VEC_SIZE>::Type;
|
||||
#ifdef ENABLE_FP8_E5M2
|
||||
using Quant_vec = typename Vec<cache_t, VEC_SIZE>::Type;
|
||||
#endif
|
||||
|
||||
constexpr int NUM_ELEMS_PER_THREAD = HEAD_SIZE / THREAD_GROUP_SIZE;
|
||||
constexpr int NUM_VECS_PER_THREAD = NUM_ELEMS_PER_THREAD / VEC_SIZE;
|
||||
@ -176,7 +182,7 @@ __device__ void paged_attention_kernel(
|
||||
|
||||
// x == THREAD_GROUP_SIZE * VEC_SIZE
|
||||
// Each thread group fetches x elements from the key at a time.
|
||||
constexpr int x = 16 / sizeof(scalar_t);
|
||||
constexpr int x = 16 / sizeof(cache_t);
|
||||
float qk_max = -FLT_MAX;
|
||||
|
||||
// Iterate over the key blocks.
|
||||
@ -202,13 +208,23 @@ __device__ void paged_attention_kernel(
|
||||
|
||||
#pragma unroll
|
||||
for (int j = 0; j < NUM_VECS_PER_THREAD; j++) {
|
||||
const scalar_t* k_ptr = k_cache + physical_block_number * kv_block_stride
|
||||
+ kv_head_idx * kv_head_stride
|
||||
+ physical_block_offset * x;
|
||||
const cache_t* k_ptr = k_cache + physical_block_number * kv_block_stride
|
||||
+ kv_head_idx * kv_head_stride
|
||||
+ physical_block_offset * x;
|
||||
const int vec_idx = thread_group_offset + j * THREAD_GROUP_SIZE;
|
||||
const int offset1 = (vec_idx * VEC_SIZE) / x;
|
||||
const int offset2 = (vec_idx * VEC_SIZE) % x;
|
||||
k_vecs[j] = *reinterpret_cast<const K_vec*>(k_ptr + offset1 * BLOCK_SIZE * x + offset2);
|
||||
if constexpr (IS_FP8_E5M2_KV_CACHE) {
|
||||
#ifdef ENABLE_FP8_E5M2
|
||||
Quant_vec k_vec_quant = *reinterpret_cast<const Quant_vec*>(k_ptr + offset1 * BLOCK_SIZE * x + offset2);
|
||||
// Vector conversion from Quant_vec to K_vec.
|
||||
k_vecs[j] = fp8_e5m2_unscaled::vec_conversion<K_vec, Quant_vec>(k_vec_quant);
|
||||
#else
|
||||
assert(false);
|
||||
#endif
|
||||
} else {
|
||||
k_vecs[j] = *reinterpret_cast<const K_vec*>(k_ptr + offset1 * BLOCK_SIZE * x + offset2);
|
||||
}
|
||||
}
|
||||
|
||||
// Compute dot product.
|
||||
@ -282,6 +298,9 @@ __device__ void paged_attention_kernel(
|
||||
constexpr int V_VEC_SIZE = MIN(16 / sizeof(scalar_t), BLOCK_SIZE);
|
||||
using V_vec = typename Vec<scalar_t, V_VEC_SIZE>::Type;
|
||||
using L_vec = typename Vec<scalar_t, V_VEC_SIZE>::Type;
|
||||
#ifdef ENABLE_FP8_E5M2
|
||||
using V_quant_vec = typename Vec<cache_t, V_VEC_SIZE>::Type;
|
||||
#endif
|
||||
using Float_L_vec = typename FloatVec<L_vec>::Type;
|
||||
|
||||
constexpr int NUM_V_VECS_PER_ROW = BLOCK_SIZE / V_VEC_SIZE;
|
||||
@ -307,14 +326,25 @@ __device__ void paged_attention_kernel(
|
||||
L_vec logits_vec;
|
||||
from_float(logits_vec, *reinterpret_cast<Float_L_vec*>(logits + token_idx - start_token_idx));
|
||||
|
||||
const scalar_t* v_ptr = v_cache + physical_block_number * kv_block_stride
|
||||
+ kv_head_idx * kv_head_stride;
|
||||
const cache_t* v_ptr = v_cache + physical_block_number * kv_block_stride
|
||||
+ kv_head_idx * kv_head_stride;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
|
||||
const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
|
||||
if (row_idx < HEAD_SIZE) {
|
||||
const int offset = row_idx * BLOCK_SIZE + physical_block_offset;
|
||||
V_vec v_vec = *reinterpret_cast<const V_vec*>(v_ptr + offset);
|
||||
V_vec v_vec;
|
||||
if constexpr (IS_FP8_E5M2_KV_CACHE) {
|
||||
#ifdef ENABLE_FP8_E5M2
|
||||
V_quant_vec v_quant_vec = *reinterpret_cast<const V_quant_vec*>(v_ptr + offset);
|
||||
// Vector conversion from V_quant_vec to V_vec.
|
||||
v_vec = fp8_e5m2_unscaled::vec_conversion<V_vec, V_quant_vec>(v_quant_vec);
|
||||
#else
|
||||
assert(false);
|
||||
#endif
|
||||
} else {
|
||||
v_vec = *reinterpret_cast<const V_vec*>(v_ptr + offset);
|
||||
}
|
||||
if (block_idx == num_context_blocks - 1) {
|
||||
// NOTE(woosuk): When v_vec contains the tokens that are out of the context,
|
||||
// we should explicitly zero out the values since they may contain NaNs.
|
||||
@ -395,14 +425,16 @@ __device__ void paged_attention_kernel(
|
||||
// Grid: (num_heads, num_seqs, 1).
|
||||
template<
|
||||
typename scalar_t,
|
||||
typename cache_t,
|
||||
int HEAD_SIZE,
|
||||
int BLOCK_SIZE,
|
||||
int NUM_THREADS>
|
||||
int NUM_THREADS,
|
||||
bool IS_FP8_E5M2_KV_CACHE>
|
||||
__global__ void paged_attention_v1_kernel(
|
||||
scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size]
|
||||
const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size]
|
||||
const scalar_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, head_size/x, block_size, x]
|
||||
const scalar_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, head_size, block_size]
|
||||
const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, head_size/x, block_size, x]
|
||||
const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, head_size, block_size]
|
||||
const int num_kv_heads, // [num_heads]
|
||||
const float scale,
|
||||
const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq]
|
||||
@ -412,7 +444,7 @@ __global__ void paged_attention_v1_kernel(
|
||||
const int q_stride,
|
||||
const int kv_block_stride,
|
||||
const int kv_head_stride) {
|
||||
paged_attention_kernel<scalar_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS>(
|
||||
paged_attention_kernel<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, IS_FP8_E5M2_KV_CACHE>(
|
||||
/* exp_sums */ nullptr, /* max_logits */ nullptr,
|
||||
out, q, k_cache, v_cache, num_kv_heads, scale, block_tables, context_lens,
|
||||
max_num_blocks_per_seq, alibi_slopes, q_stride, kv_block_stride, kv_head_stride);
|
||||
@ -421,17 +453,19 @@ __global__ void paged_attention_v1_kernel(
|
||||
// Grid: (num_heads, num_seqs, max_num_partitions).
|
||||
template<
|
||||
typename scalar_t,
|
||||
typename cache_t,
|
||||
int HEAD_SIZE,
|
||||
int BLOCK_SIZE,
|
||||
int NUM_THREADS,
|
||||
bool IS_FP8_E5M2_KV_CACHE,
|
||||
int PARTITION_SIZE>
|
||||
__global__ void paged_attention_v2_kernel(
|
||||
float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions]
|
||||
float* __restrict__ max_logits, // [num_seqs, num_heads, max_num_partitions]
|
||||
scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads, max_num_partitions, head_size]
|
||||
const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size]
|
||||
const scalar_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, head_size/x, block_size, x]
|
||||
const scalar_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, head_size, block_size]
|
||||
const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, head_size/x, block_size, x]
|
||||
const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, head_size, block_size]
|
||||
const int num_kv_heads, // [num_heads]
|
||||
const float scale,
|
||||
const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq]
|
||||
@ -441,7 +475,7 @@ __global__ void paged_attention_v2_kernel(
|
||||
const int q_stride,
|
||||
const int kv_block_stride,
|
||||
const int kv_head_stride) {
|
||||
paged_attention_kernel<scalar_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, PARTITION_SIZE>(
|
||||
paged_attention_kernel<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, IS_FP8_E5M2_KV_CACHE, PARTITION_SIZE>(
|
||||
exp_sums, max_logits, tmp_out, q, k_cache, v_cache, num_kv_heads, scale,
|
||||
block_tables, context_lens, max_num_blocks_per_seq, alibi_slopes,
|
||||
q_stride, kv_block_stride, kv_head_stride);
|
||||
@ -550,10 +584,10 @@ __global__ void paged_attention_v2_reduce_kernel(
|
||||
|
||||
#define LAUNCH_PAGED_ATTENTION_V1(HEAD_SIZE) \
|
||||
VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize( \
|
||||
((void*)vllm::paged_attention_v1_kernel<T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS>), \
|
||||
shared_mem_size); \
|
||||
vllm::paged_attention_v1_kernel<T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS> \
|
||||
<<<grid, block, shared_mem_size, stream>>>( \
|
||||
((void*)vllm::paged_attention_v1_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, \
|
||||
IS_FP8_E5M2_KV_CACHE>), shared_mem_size); \
|
||||
vllm::paged_attention_v1_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, \
|
||||
IS_FP8_E5M2_KV_CACHE><<<grid, block, shared_mem_size, stream>>>( \
|
||||
out_ptr, \
|
||||
query_ptr, \
|
||||
key_cache_ptr, \
|
||||
@ -571,7 +605,9 @@ __global__ void paged_attention_v2_reduce_kernel(
|
||||
// TODO(woosuk): Tune NUM_THREADS.
|
||||
template<
|
||||
typename T,
|
||||
typename CACHE_T,
|
||||
int BLOCK_SIZE,
|
||||
bool IS_FP8_E5M2_KV_CACHE,
|
||||
int NUM_THREADS = 128>
|
||||
void paged_attention_v1_launcher(
|
||||
torch::Tensor& out,
|
||||
@ -602,8 +638,8 @@ void paged_attention_v1_launcher(
|
||||
|
||||
T* out_ptr = reinterpret_cast<T*>(out.data_ptr());
|
||||
T* query_ptr = reinterpret_cast<T*>(query.data_ptr());
|
||||
T* key_cache_ptr = reinterpret_cast<T*>(key_cache.data_ptr());
|
||||
T* value_cache_ptr = reinterpret_cast<T*>(value_cache.data_ptr());
|
||||
CACHE_T* key_cache_ptr = reinterpret_cast<CACHE_T*>(key_cache.data_ptr());
|
||||
CACHE_T* value_cache_ptr = reinterpret_cast<CACHE_T*>(value_cache.data_ptr());
|
||||
int* block_tables_ptr = block_tables.data_ptr<int>();
|
||||
int* context_lens_ptr = context_lens.data_ptr<int>();
|
||||
|
||||
@ -647,35 +683,35 @@ void paged_attention_v1_launcher(
|
||||
}
|
||||
}
|
||||
|
||||
#define CALL_V1_LAUNCHER(T, BLOCK_SIZE) \
|
||||
paged_attention_v1_launcher<T, BLOCK_SIZE>( \
|
||||
out, \
|
||||
query, \
|
||||
key_cache, \
|
||||
value_cache, \
|
||||
num_kv_heads, \
|
||||
scale, \
|
||||
block_tables, \
|
||||
context_lens, \
|
||||
max_context_len, \
|
||||
#define CALL_V1_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_E5M2_KV_CACHE) \
|
||||
paged_attention_v1_launcher<T, CACHE_T, BLOCK_SIZE, IS_FP8_E5M2_KV_CACHE>( \
|
||||
out, \
|
||||
query, \
|
||||
key_cache, \
|
||||
value_cache, \
|
||||
num_kv_heads, \
|
||||
scale, \
|
||||
block_tables, \
|
||||
context_lens, \
|
||||
max_context_len, \
|
||||
alibi_slopes);
|
||||
|
||||
// NOTE(woosuk): To reduce the compilation time, we omitted block sizes
|
||||
// 1, 2, 4, 64, 128, 256.
|
||||
#define CALL_V1_LAUNCHER_BLOCK_SIZE(T) \
|
||||
switch (block_size) { \
|
||||
case 8: \
|
||||
CALL_V1_LAUNCHER(T, 8); \
|
||||
break; \
|
||||
case 16: \
|
||||
CALL_V1_LAUNCHER(T, 16); \
|
||||
break; \
|
||||
case 32: \
|
||||
CALL_V1_LAUNCHER(T, 32); \
|
||||
break; \
|
||||
default: \
|
||||
TORCH_CHECK(false, "Unsupported block size: ", block_size); \
|
||||
break; \
|
||||
#define CALL_V1_LAUNCHER_BLOCK_SIZE(T, CACHE_T, IS_FP8_E5M2_KV_CACHE) \
|
||||
switch (block_size) { \
|
||||
case 8: \
|
||||
CALL_V1_LAUNCHER(T, CACHE_T, 8, IS_FP8_E5M2_KV_CACHE); \
|
||||
break; \
|
||||
case 16: \
|
||||
CALL_V1_LAUNCHER(T, CACHE_T, 16, IS_FP8_E5M2_KV_CACHE); \
|
||||
break; \
|
||||
case 32: \
|
||||
CALL_V1_LAUNCHER(T, CACHE_T, 32, IS_FP8_E5M2_KV_CACHE); \
|
||||
break; \
|
||||
default: \
|
||||
TORCH_CHECK(false, "Unsupported block size: ", block_size); \
|
||||
break; \
|
||||
}
|
||||
|
||||
void paged_attention_v1(
|
||||
@ -689,20 +725,36 @@ void paged_attention_v1(
|
||||
torch::Tensor& context_lens, // [num_seqs]
|
||||
int block_size,
|
||||
int max_context_len,
|
||||
const c10::optional<torch::Tensor>& alibi_slopes) {
|
||||
if (query.dtype() == at::ScalarType::Float) {
|
||||
CALL_V1_LAUNCHER_BLOCK_SIZE(float);
|
||||
} else if (query.dtype() == at::ScalarType::Half) {
|
||||
CALL_V1_LAUNCHER_BLOCK_SIZE(uint16_t);
|
||||
} else if (query.dtype() == at::ScalarType::BFloat16) {
|
||||
CALL_V1_LAUNCHER_BLOCK_SIZE(__nv_bfloat16);
|
||||
const c10::optional<torch::Tensor>& alibi_slopes,
|
||||
const std::string& kv_cache_dtype) {
|
||||
if (kv_cache_dtype == "auto") {
|
||||
if (query.dtype() == at::ScalarType::Float) {
|
||||
CALL_V1_LAUNCHER_BLOCK_SIZE(float, float, false);
|
||||
} else if (query.dtype() == at::ScalarType::Half) {
|
||||
CALL_V1_LAUNCHER_BLOCK_SIZE(uint16_t, uint16_t, false);
|
||||
} else if (query.dtype() == at::ScalarType::BFloat16) {
|
||||
CALL_V1_LAUNCHER_BLOCK_SIZE(__nv_bfloat16, __nv_bfloat16, false);
|
||||
} else {
|
||||
TORCH_CHECK(false, "Unsupported data type: ", query.dtype());
|
||||
}
|
||||
} else if (kv_cache_dtype == "fp8_e5m2") {
|
||||
if (query.dtype() == at::ScalarType::Float) {
|
||||
CALL_V1_LAUNCHER_BLOCK_SIZE(float, uint8_t, true);
|
||||
} else if (query.dtype() == at::ScalarType::Half) {
|
||||
CALL_V1_LAUNCHER_BLOCK_SIZE(uint16_t, uint8_t, true);
|
||||
} else if (query.dtype() == at::ScalarType::BFloat16) {
|
||||
CALL_V1_LAUNCHER_BLOCK_SIZE(__nv_bfloat16, uint8_t, true);
|
||||
} else {
|
||||
TORCH_CHECK(false, "Unsupported data type: ", query.dtype());
|
||||
}
|
||||
} else {
|
||||
TORCH_CHECK(false, "Unsupported data type: ", query.dtype());
|
||||
TORCH_CHECK(false, "Unsupported data type of kv cache: ", kv_cache_dtype);
|
||||
}
|
||||
}
|
||||
|
||||
#define LAUNCH_PAGED_ATTENTION_V2(HEAD_SIZE) \
|
||||
vllm::paged_attention_v2_kernel<T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, PARTITION_SIZE> \
|
||||
vllm::paged_attention_v2_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, \
|
||||
IS_FP8_E5M2_KV_CACHE, PARTITION_SIZE> \
|
||||
<<<grid, block, shared_mem_size, stream>>>( \
|
||||
exp_sums_ptr, \
|
||||
max_logits_ptr, \
|
||||
@ -730,7 +782,9 @@ void paged_attention_v1(
|
||||
|
||||
template<
|
||||
typename T,
|
||||
typename CACHE_T,
|
||||
int BLOCK_SIZE,
|
||||
bool IS_FP8_E5M2_KV_CACHE,
|
||||
int NUM_THREADS = 128,
|
||||
int PARTITION_SIZE = 512>
|
||||
void paged_attention_v2_launcher(
|
||||
@ -768,8 +822,8 @@ void paged_attention_v2_launcher(
|
||||
float* max_logits_ptr = reinterpret_cast<float*>(max_logits.data_ptr());
|
||||
T* tmp_out_ptr = reinterpret_cast<T*>(tmp_out.data_ptr());
|
||||
T* query_ptr = reinterpret_cast<T*>(query.data_ptr());
|
||||
T* key_cache_ptr = reinterpret_cast<T*>(key_cache.data_ptr());
|
||||
T* value_cache_ptr = reinterpret_cast<T*>(value_cache.data_ptr());
|
||||
CACHE_T* key_cache_ptr = reinterpret_cast<CACHE_T*>(key_cache.data_ptr());
|
||||
CACHE_T* value_cache_ptr = reinterpret_cast<CACHE_T*>(value_cache.data_ptr());
|
||||
int* block_tables_ptr = block_tables.data_ptr<int>();
|
||||
int* context_lens_ptr = context_lens.data_ptr<int>();
|
||||
|
||||
@ -816,38 +870,38 @@ void paged_attention_v2_launcher(
|
||||
}
|
||||
}
|
||||
|
||||
#define CALL_V2_LAUNCHER(T, BLOCK_SIZE) \
|
||||
paged_attention_v2_launcher<T, BLOCK_SIZE>( \
|
||||
out, \
|
||||
exp_sums, \
|
||||
max_logits, \
|
||||
tmp_out, \
|
||||
query, \
|
||||
key_cache, \
|
||||
value_cache, \
|
||||
num_kv_heads, \
|
||||
scale, \
|
||||
block_tables, \
|
||||
context_lens, \
|
||||
max_context_len, \
|
||||
#define CALL_V2_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_E5M2_KV_CACHE) \
|
||||
paged_attention_v2_launcher<T, CACHE_T, BLOCK_SIZE, IS_FP8_E5M2_KV_CACHE>( \
|
||||
out, \
|
||||
exp_sums, \
|
||||
max_logits, \
|
||||
tmp_out, \
|
||||
query, \
|
||||
key_cache, \
|
||||
value_cache, \
|
||||
num_kv_heads, \
|
||||
scale, \
|
||||
block_tables, \
|
||||
context_lens, \
|
||||
max_context_len, \
|
||||
alibi_slopes);
|
||||
|
||||
// NOTE(woosuk): To reduce the compilation time, we omitted block sizes
|
||||
// 1, 2, 4, 64, 128, 256.
|
||||
#define CALL_V2_LAUNCHER_BLOCK_SIZE(T) \
|
||||
switch (block_size) { \
|
||||
case 8: \
|
||||
CALL_V2_LAUNCHER(T, 8); \
|
||||
break; \
|
||||
case 16: \
|
||||
CALL_V2_LAUNCHER(T, 16); \
|
||||
break; \
|
||||
case 32: \
|
||||
CALL_V2_LAUNCHER(T, 32); \
|
||||
break; \
|
||||
default: \
|
||||
TORCH_CHECK(false, "Unsupported block size: ", block_size); \
|
||||
break; \
|
||||
#define CALL_V2_LAUNCHER_BLOCK_SIZE(T, CACHE_T, IS_FP8_E5M2_KV_CACHE) \
|
||||
switch (block_size) { \
|
||||
case 8: \
|
||||
CALL_V2_LAUNCHER(T, CACHE_T, 8, IS_FP8_E5M2_KV_CACHE); \
|
||||
break; \
|
||||
case 16: \
|
||||
CALL_V2_LAUNCHER(T, CACHE_T, 16, IS_FP8_E5M2_KV_CACHE); \
|
||||
break; \
|
||||
case 32: \
|
||||
CALL_V2_LAUNCHER(T, CACHE_T, 32, IS_FP8_E5M2_KV_CACHE); \
|
||||
break; \
|
||||
default: \
|
||||
TORCH_CHECK(false, "Unsupported block size: ", block_size); \
|
||||
break; \
|
||||
}
|
||||
|
||||
void paged_attention_v2(
|
||||
@ -864,15 +918,30 @@ void paged_attention_v2(
|
||||
torch::Tensor& context_lens, // [num_seqs]
|
||||
int block_size,
|
||||
int max_context_len,
|
||||
const c10::optional<torch::Tensor>& alibi_slopes) {
|
||||
if (query.dtype() == at::ScalarType::Float) {
|
||||
CALL_V2_LAUNCHER_BLOCK_SIZE(float);
|
||||
} else if (query.dtype() == at::ScalarType::Half) {
|
||||
CALL_V2_LAUNCHER_BLOCK_SIZE(uint16_t);
|
||||
} else if (query.dtype() == at::ScalarType::BFloat16) {
|
||||
CALL_V2_LAUNCHER_BLOCK_SIZE(__nv_bfloat16);
|
||||
const c10::optional<torch::Tensor>& alibi_slopes,
|
||||
const std::string& kv_cache_dtype) {
|
||||
if (kv_cache_dtype == "auto") {
|
||||
if (query.dtype() == at::ScalarType::Float) {
|
||||
CALL_V2_LAUNCHER_BLOCK_SIZE(float, float, false);
|
||||
} else if (query.dtype() == at::ScalarType::Half) {
|
||||
CALL_V2_LAUNCHER_BLOCK_SIZE(uint16_t, uint16_t, false);
|
||||
} else if (query.dtype() == at::ScalarType::BFloat16) {
|
||||
CALL_V2_LAUNCHER_BLOCK_SIZE(__nv_bfloat16, __nv_bfloat16, false);
|
||||
} else {
|
||||
TORCH_CHECK(false, "Unsupported data type: ", query.dtype());
|
||||
}
|
||||
} else if (kv_cache_dtype == "fp8_e5m2") {
|
||||
if (query.dtype() == at::ScalarType::Float) {
|
||||
CALL_V2_LAUNCHER_BLOCK_SIZE(float, uint8_t, true);
|
||||
} else if (query.dtype() == at::ScalarType::Half) {
|
||||
CALL_V2_LAUNCHER_BLOCK_SIZE(uint16_t, uint8_t, true);
|
||||
} else if (query.dtype() == at::ScalarType::BFloat16) {
|
||||
CALL_V2_LAUNCHER_BLOCK_SIZE(__nv_bfloat16, uint8_t, true);
|
||||
} else {
|
||||
TORCH_CHECK(false, "Unsupported data type: ", query.dtype());
|
||||
}
|
||||
} else {
|
||||
TORCH_CHECK(false, "Unsupported data type: ", query.dtype());
|
||||
TORCH_CHECK(false, "Unsupported data type of kv cache: ", kv_cache_dtype);
|
||||
}
|
||||
}
|
||||
|
||||
|
35
csrc/attention/dtype_fp8_e5m2.cuh
Normal file
35
csrc/attention/dtype_fp8_e5m2.cuh
Normal file
@ -0,0 +1,35 @@
|
||||
#pragma once
|
||||
|
||||
#include "attention_generic.cuh"
|
||||
|
||||
#include <stdint.h>
|
||||
#ifdef ENABLE_FP8_E5M2
|
||||
#include <cuda_fp8.h>
|
||||
#endif
|
||||
|
||||
namespace vllm {
|
||||
#ifdef ENABLE_FP8_E5M2
|
||||
// fp8 vector types for quantization of kv cache
|
||||
|
||||
template<>
|
||||
struct Vec<uint8_t, 1> {
|
||||
using Type = uint8_t;
|
||||
};
|
||||
|
||||
template<>
|
||||
struct Vec<uint8_t, 2> {
|
||||
using Type = uint16_t;
|
||||
};
|
||||
|
||||
template<>
|
||||
struct Vec<uint8_t, 4> {
|
||||
using Type = uint32_t;
|
||||
};
|
||||
|
||||
template<>
|
||||
struct Vec<uint8_t, 8> {
|
||||
using Type = uint2;
|
||||
};
|
||||
#endif // ENABLE_FP8_E5M2
|
||||
|
||||
} // namespace vllm
|
@ -20,7 +20,8 @@ void reshape_and_cache(
|
||||
torch::Tensor& value,
|
||||
torch::Tensor& key_cache,
|
||||
torch::Tensor& value_cache,
|
||||
torch::Tensor& slot_mapping);
|
||||
torch::Tensor& slot_mapping,
|
||||
const std::string& kv_cache_dtype);
|
||||
|
||||
void gather_cached_kv(
|
||||
torch::Tensor& key,
|
||||
@ -28,3 +29,8 @@ void gather_cached_kv(
|
||||
torch::Tensor& key_cache,
|
||||
torch::Tensor& value_cache,
|
||||
torch::Tensor& slot_mapping);
|
||||
|
||||
// Just for unittest
|
||||
void convert_fp8_e5m2(
|
||||
torch::Tensor& src_cache,
|
||||
torch::Tensor& dst_cache);
|
||||
|
@ -4,6 +4,7 @@
|
||||
|
||||
#include "cuda_compat.h"
|
||||
#include "dispatch_utils.h"
|
||||
#include "quantization/fp8_e5m2_kvcache/quant_utils.cuh"
|
||||
|
||||
#include <algorithm>
|
||||
#include <cassert>
|
||||
@ -131,7 +132,7 @@ void copy_blocks(
|
||||
dim3 block(std::min(1024, numel_per_block));
|
||||
const at::cuda::OptionalCUDAGuard device_guard(cache_device);
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
VLLM_DISPATCH_FLOATING_TYPES(
|
||||
VLLM_DISPATCH_FLOATING_AND_BYTE_TYPES(
|
||||
key_caches[0].scalar_type(), "copy_blocks_kernel", ([&] {
|
||||
vllm::copy_blocks_kernel<scalar_t><<<grid, block, 0, stream>>>(
|
||||
key_cache_ptrs_tensor.data_ptr<int64_t>(),
|
||||
@ -143,12 +144,12 @@ void copy_blocks(
|
||||
|
||||
namespace vllm {
|
||||
|
||||
template<typename scalar_t>
|
||||
template<typename scalar_t, typename cache_t, bool is_fp8_e5m2_kv_cache>
|
||||
__global__ void reshape_and_cache_kernel(
|
||||
const scalar_t* __restrict__ key, // [num_tokens, num_heads, head_size]
|
||||
const scalar_t* __restrict__ value, // [num_tokens, num_heads, head_size]
|
||||
scalar_t* __restrict__ key_cache, // [num_blocks, num_heads, head_size/x, block_size, x]
|
||||
scalar_t* __restrict__ value_cache, // [num_blocks, num_heads, head_size, block_size]
|
||||
cache_t* __restrict__ key_cache, // [num_blocks, num_heads, head_size/x, block_size, x]
|
||||
cache_t* __restrict__ value_cache, // [num_blocks, num_heads, head_size, block_size]
|
||||
const int64_t* __restrict__ slot_mapping, // [num_tokens]
|
||||
const int key_stride,
|
||||
const int value_stride,
|
||||
@ -185,19 +186,45 @@ __global__ void reshape_and_cache_kernel(
|
||||
+ head_idx * head_size * block_size
|
||||
+ head_offset * block_size
|
||||
+ block_offset;
|
||||
key_cache[tgt_key_idx] = key[src_key_idx];
|
||||
value_cache[tgt_value_idx] = value[src_value_idx];
|
||||
scalar_t tgt_key = key[src_key_idx];
|
||||
scalar_t tgt_value = value[src_value_idx];
|
||||
if constexpr (is_fp8_e5m2_kv_cache) {
|
||||
#ifdef ENABLE_FP8_E5M2
|
||||
key_cache[tgt_key_idx] = fp8_e5m2_unscaled::vec_conversion<uint8_t, scalar_t>(tgt_key);
|
||||
value_cache[tgt_value_idx] = fp8_e5m2_unscaled::vec_conversion<uint8_t, scalar_t>(tgt_value);
|
||||
#else
|
||||
assert(false);
|
||||
#endif
|
||||
} else {
|
||||
key_cache[tgt_key_idx] = tgt_key;
|
||||
value_cache[tgt_value_idx] = tgt_value;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace vllm
|
||||
|
||||
#define CALL_RESHAPE_AND_CACHE(KV_T, CACHE_T, IS_FP8_E5M2_KV_CACHE) \
|
||||
vllm::reshape_and_cache_kernel<KV_T, CACHE_T, IS_FP8_E5M2_KV_CACHE><<<grid, block, 0, stream>>>( \
|
||||
reinterpret_cast<KV_T*>(key.data_ptr()), \
|
||||
reinterpret_cast<KV_T*>(value.data_ptr()), \
|
||||
reinterpret_cast<CACHE_T*>(key_cache.data_ptr()), \
|
||||
reinterpret_cast<CACHE_T*>(value_cache.data_ptr()), \
|
||||
slot_mapping.data_ptr<int64_t>(), \
|
||||
key_stride, \
|
||||
value_stride, \
|
||||
num_heads, \
|
||||
head_size, \
|
||||
block_size, \
|
||||
x);
|
||||
|
||||
void reshape_and_cache(
|
||||
torch::Tensor& key, // [num_tokens, num_heads, head_size]
|
||||
torch::Tensor& value, // [num_tokens, num_heads, head_size]
|
||||
torch::Tensor& key_cache, // [num_blocks, num_heads, head_size/x, block_size, x]
|
||||
torch::Tensor& value_cache, // [num_blocks, num_heads, head_size, block_size]
|
||||
torch::Tensor& slot_mapping) // [num_tokens]
|
||||
torch::Tensor& slot_mapping, // [num_tokens]
|
||||
const std::string& kv_cache_dtype)
|
||||
{
|
||||
int num_tokens = key.size(0);
|
||||
int num_heads = key.size(1);
|
||||
@ -212,23 +239,25 @@ void reshape_and_cache(
|
||||
dim3 block(std::min(num_heads * head_size, 512));
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(key));
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
VLLM_DISPATCH_FLOATING_TYPES(
|
||||
key.scalar_type(),
|
||||
"reshape_and_cache_kernel",
|
||||
[&] {
|
||||
vllm::reshape_and_cache_kernel<scalar_t><<<grid, block, 0, stream>>>(
|
||||
key.data_ptr<scalar_t>(),
|
||||
value.data_ptr<scalar_t>(),
|
||||
key_cache.data_ptr<scalar_t>(),
|
||||
value_cache.data_ptr<scalar_t>(),
|
||||
slot_mapping.data_ptr<int64_t>(),
|
||||
key_stride,
|
||||
value_stride,
|
||||
num_heads,
|
||||
head_size,
|
||||
block_size,
|
||||
x);
|
||||
});
|
||||
if (kv_cache_dtype == "auto") {
|
||||
if (key.dtype() == at::ScalarType::Float) {
|
||||
CALL_RESHAPE_AND_CACHE(float, float, false);
|
||||
} else if (key.dtype() == at::ScalarType::Half) {
|
||||
CALL_RESHAPE_AND_CACHE(uint16_t, uint16_t, false);
|
||||
} else if (key.dtype() == at::ScalarType::BFloat16) {
|
||||
CALL_RESHAPE_AND_CACHE(__nv_bfloat16, __nv_bfloat16, false);
|
||||
}
|
||||
} else if (kv_cache_dtype == "fp8_e5m2") {
|
||||
if (key.dtype() == at::ScalarType::Float) {
|
||||
CALL_RESHAPE_AND_CACHE(float, uint8_t, true);
|
||||
} else if (key.dtype() == at::ScalarType::Half) {
|
||||
CALL_RESHAPE_AND_CACHE(uint16_t, uint8_t, true);
|
||||
} else if (key.dtype() == at::ScalarType::BFloat16) {
|
||||
CALL_RESHAPE_AND_CACHE(__nv_bfloat16, uint8_t, true);
|
||||
}
|
||||
} else {
|
||||
TORCH_CHECK(false, "Unsupported data type of kv cache: ", kv_cache_dtype);
|
||||
}
|
||||
}
|
||||
|
||||
namespace vllm {
|
||||
@ -256,12 +285,12 @@ __global__ void gather_cached_kv_kernel(
|
||||
for (int i = threadIdx.x; i < num_tokens; i += blockDim.x) {
|
||||
const int tgt_key_idx = token_idx * key_stride + i;
|
||||
const int tgt_value_idx = token_idx * value_stride + i;
|
||||
|
||||
|
||||
const int head_idx = i / head_size;
|
||||
const int head_offset = i % head_size;
|
||||
const int x_idx = head_offset / x; // the offset of the [head_size/x] dimension
|
||||
const int x_offset = head_offset % x;
|
||||
|
||||
|
||||
const int src_key_idx = block_idx * num_heads * (head_size / x) * block_size * x
|
||||
+ head_idx * (head_size / x) * block_size * x
|
||||
+ x_idx * block_size * x
|
||||
@ -373,7 +402,7 @@ void gather_cached_kv(
|
||||
dim3 block(std::min(num_heads * head_size, 512));
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(key));
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
VLLM_DISPATCH_FLOATING_TYPES(
|
||||
VLLM_DISPATCH_FLOATING_AND_BYTE_TYPES(
|
||||
key.scalar_type(),
|
||||
"gather_cached_kv_kernel_optimized",
|
||||
[&] {
|
||||
@ -391,3 +420,55 @@ void gather_cached_kv(
|
||||
x);
|
||||
});
|
||||
}
|
||||
|
||||
namespace vllm {
|
||||
|
||||
template<typename Tout, typename Tin>
|
||||
__global__ void convert_fp8_e5m2_kernel(
|
||||
const Tin* __restrict__ src_cache,
|
||||
Tout* __restrict__ dst_cache,
|
||||
const int64_t block_stride) {
|
||||
const int64_t block_idx = blockIdx.x;
|
||||
for (int i = threadIdx.x; i < block_stride; i += blockDim.x) {
|
||||
int64_t idx = block_idx * block_stride + i;
|
||||
#ifdef ENABLE_FP8_E5M2
|
||||
dst_cache[idx] = fp8_e5m2_unscaled::vec_conversion<Tout, Tin>(src_cache[idx]);
|
||||
#else
|
||||
assert(false);
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace vllm
|
||||
|
||||
#define CALL_CONVERT_FP8_E5M2(Tout, Tin) \
|
||||
vllm::convert_fp8_e5m2_kernel<Tout, Tin><<<grid, block, 0, stream>>>( \
|
||||
reinterpret_cast<Tin*>(src_cache.data_ptr()), \
|
||||
reinterpret_cast<Tout*>(dst_cache.data_ptr()), \
|
||||
block_stride);
|
||||
|
||||
void convert_fp8_e5m2(
|
||||
torch::Tensor& src_cache,
|
||||
torch::Tensor& dst_cache)
|
||||
{
|
||||
int64_t num_blocks = src_cache.size(0);
|
||||
int64_t block_stride = src_cache.stride(0);
|
||||
|
||||
dim3 grid(num_blocks);
|
||||
dim3 block(std::min(block_stride, int64_t(512)));
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
|
||||
if (src_cache.dtype() == at::ScalarType::Float) {
|
||||
CALL_CONVERT_FP8_E5M2(uint8_t, float);
|
||||
} else if (src_cache.dtype() == at::ScalarType::Half) {
|
||||
CALL_CONVERT_FP8_E5M2(uint8_t, uint16_t);
|
||||
} else if (src_cache.dtype() == at::ScalarType::BFloat16) {
|
||||
CALL_CONVERT_FP8_E5M2(uint8_t, __nv_bfloat16);
|
||||
} else if (dst_cache.dtype() == at::ScalarType::Float) {
|
||||
CALL_CONVERT_FP8_E5M2(float, uint8_t);
|
||||
} else if (dst_cache.dtype() == at::ScalarType::Half) {
|
||||
CALL_CONVERT_FP8_E5M2(uint16_t, uint8_t);
|
||||
} else if (dst_cache.dtype() == at::ScalarType::BFloat16) {
|
||||
CALL_CONVERT_FP8_E5M2(__nv_bfloat16, uint8_t);
|
||||
}
|
||||
}
|
||||
|
@ -14,3 +14,13 @@
|
||||
#define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \
|
||||
AT_DISPATCH_SWITCH( \
|
||||
TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__))
|
||||
|
||||
#define VLLM_DISPATCH_CASE_FLOATING_AND_BYTE_TYPES(...) \
|
||||
AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \
|
||||
AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \
|
||||
AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) \
|
||||
AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__)
|
||||
|
||||
#define VLLM_DISPATCH_FLOATING_AND_BYTE_TYPES(TYPE, NAME, ...) \
|
||||
AT_DISPATCH_SWITCH( \
|
||||
TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_AND_BYTE_TYPES(__VA_ARGS__))
|
||||
|
@ -13,7 +13,8 @@ void paged_attention_v1(
|
||||
torch::Tensor& context_lens,
|
||||
int block_size,
|
||||
int max_context_len,
|
||||
const c10::optional<torch::Tensor>& alibi_slopes);
|
||||
const c10::optional<torch::Tensor>& alibi_slopes,
|
||||
const std::string& kv_cache_dtype);
|
||||
|
||||
void paged_attention_v2(
|
||||
torch::Tensor& out,
|
||||
@ -29,7 +30,8 @@ void paged_attention_v2(
|
||||
torch::Tensor& context_lens,
|
||||
int block_size,
|
||||
int max_context_len,
|
||||
const c10::optional<torch::Tensor>& alibi_slopes);
|
||||
const c10::optional<torch::Tensor>& alibi_slopes,
|
||||
const std::string& kv_cache_dtype);
|
||||
|
||||
void rms_norm(
|
||||
torch::Tensor& out,
|
||||
|
@ -75,6 +75,10 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
||||
"gather_cached_kv",
|
||||
&gather_cached_kv,
|
||||
"Gather key and value from the cache into contiguous QKV tensors");
|
||||
cache_ops.def(
|
||||
"convert_fp8_e5m2",
|
||||
&convert_fp8_e5m2,
|
||||
"Convert the key and value cache to fp8_e5m2 data type");
|
||||
|
||||
// Cuda utils
|
||||
pybind11::module cuda_utils = m.def_submodule("cuda_utils", "vLLM cuda utils");
|
||||
|
278
csrc/quantization/fp8_e5m2_kvcache/quant_utils.cuh
Normal file
278
csrc/quantization/fp8_e5m2_kvcache/quant_utils.cuh
Normal file
@ -0,0 +1,278 @@
|
||||
#pragma once
|
||||
|
||||
#include <assert.h>
|
||||
#include <stdint.h>
|
||||
#include <float.h>
|
||||
#include <type_traits>
|
||||
#include "../../attention/attention_dtypes.h"
|
||||
#include "../../attention/dtype_float32.cuh"
|
||||
#include "../../attention/dtype_float16.cuh"
|
||||
#include "../../attention/dtype_bfloat16.cuh"
|
||||
|
||||
#pragma once
|
||||
|
||||
namespace vllm {
|
||||
#ifdef ENABLE_FP8_E5M2
|
||||
namespace fp8_e5m2_unscaled {
|
||||
|
||||
template<typename Tout, typename Tin>
|
||||
__inline__ __device__ Tout vec_conversion(const Tin& x)
|
||||
{
|
||||
return x;
|
||||
}
|
||||
|
||||
// fp8 -> half
|
||||
template<>
|
||||
__inline__ __device__ uint16_t vec_conversion<uint16_t, uint8_t>(const uint8_t& a)
|
||||
{
|
||||
__half_raw res = __nv_cvt_fp8_to_halfraw(a, __NV_E5M2);
|
||||
return res.x;
|
||||
}
|
||||
|
||||
// fp8x2 -> half2
|
||||
template<>
|
||||
__inline__ __device__ uint32_t vec_conversion<uint32_t, uint16_t>(const uint16_t& a)
|
||||
{
|
||||
union {
|
||||
uint16_t u16[2];
|
||||
uint32_t u32;
|
||||
} tmp;
|
||||
__half2_raw res = __nv_cvt_fp8x2_to_halfraw2(a, __NV_E5M2);
|
||||
tmp.u16[0] = res.x;
|
||||
tmp.u16[1] = res.y;
|
||||
return tmp.u32;
|
||||
}
|
||||
|
||||
// fp8x4 -> half2x2
|
||||
template<>
|
||||
__inline__ __device__ uint2 vec_conversion<uint2, uint32_t>(const uint32_t& a)
|
||||
{
|
||||
union {
|
||||
uint2 u32x2;
|
||||
uint32_t u32[2];
|
||||
} tmp;
|
||||
tmp.u32[0] = vec_conversion<uint32_t, uint16_t>((uint16_t)a);
|
||||
tmp.u32[1] = vec_conversion<uint32_t, uint16_t>((uint16_t)(a >> 16U));
|
||||
return tmp.u32x2;
|
||||
}
|
||||
|
||||
// fp8x8 -> half2x4
|
||||
template<>
|
||||
__inline__ __device__ uint4 vec_conversion<uint4, uint2>(const uint2& a)
|
||||
{
|
||||
union {
|
||||
uint4 u64x2;
|
||||
uint2 u64[2];
|
||||
} tmp;
|
||||
tmp.u64[0] = vec_conversion<uint2, uint32_t>(a.x);
|
||||
tmp.u64[1] = vec_conversion<uint2, uint32_t>(a.y);
|
||||
return tmp.u64x2;
|
||||
}
|
||||
|
||||
// fp8 -> __nv_bfloat16
|
||||
template<>
|
||||
__inline__ __device__ __nv_bfloat16 vec_conversion<__nv_bfloat16, uint8_t>(const uint8_t& a)
|
||||
{
|
||||
// Note there is no direct convert function from fp8 to bf16.
|
||||
// fp8 -> half
|
||||
__half_raw res = __nv_cvt_fp8_to_halfraw(a, __NV_E5M2);
|
||||
// half -> float -> bf16
|
||||
float tmp = half_to_float(res.x);
|
||||
return __float2bfloat16(tmp);
|
||||
}
|
||||
|
||||
// fp8x2 -> __nv_bfloat162
|
||||
template<>
|
||||
__inline__ __device__ __nv_bfloat162 vec_conversion<__nv_bfloat162, uint16_t>(const uint16_t& a)
|
||||
{
|
||||
__nv_bfloat162 res;
|
||||
res.x = vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)a);
|
||||
res.y = vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)(a >> 8U));
|
||||
return res;
|
||||
}
|
||||
|
||||
// fp8x4 -> bf16_4_t
|
||||
template<>
|
||||
__inline__ __device__ bf16_4_t vec_conversion<bf16_4_t, uint32_t>(const uint32_t& a)
|
||||
{
|
||||
bf16_4_t res;
|
||||
res.x = vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)a);
|
||||
res.y = vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)(a >> 16U));
|
||||
return res;
|
||||
}
|
||||
|
||||
// fp8x8 -> bf16_8_t
|
||||
template<>
|
||||
__inline__ __device__ bf16_8_t vec_conversion<bf16_8_t, uint2>(const uint2& a)
|
||||
{
|
||||
bf16_4_t tmp1, tmp2;
|
||||
tmp1 = vec_conversion<bf16_4_t, uint32_t>(a.x);
|
||||
tmp2 = vec_conversion<bf16_4_t, uint32_t>(a.y);
|
||||
bf16_8_t res;
|
||||
res.x = tmp1.x;
|
||||
res.y = tmp1.y;
|
||||
res.z = tmp2.x;
|
||||
res.w = tmp2.y;
|
||||
return res;
|
||||
}
|
||||
|
||||
// fp8 -> float
|
||||
template<>
|
||||
__inline__ __device__ float vec_conversion<float, uint8_t>(const uint8_t& a)
|
||||
{
|
||||
// fp8 -> half
|
||||
uint16_t tmp = vec_conversion<uint16_t, uint8_t>(a);
|
||||
// half -> float
|
||||
return half_to_float(tmp);
|
||||
}
|
||||
|
||||
// fp8x2 -> float2
|
||||
template<>
|
||||
__inline__ __device__ float2 vec_conversion<float2, uint16_t>(const uint16_t& a)
|
||||
{
|
||||
// fp8x2 -> half2
|
||||
uint32_t tmp = vec_conversion<uint32_t, uint16_t>(a);
|
||||
// half2 -> float2
|
||||
return half2_to_float2(tmp);
|
||||
}
|
||||
|
||||
// fp8x4 -> float4
|
||||
template<>
|
||||
__inline__ __device__ Float4_ vec_conversion<Float4_, uint32_t>(const uint32_t& a)
|
||||
{
|
||||
Float4_ res;
|
||||
res.x = vec_conversion<float2, uint16_t>((uint16_t)a);
|
||||
res.y = vec_conversion<float2, uint16_t>((uint16_t)(a >> 16U));
|
||||
return res;
|
||||
}
|
||||
|
||||
// fp8x8 -> float8
|
||||
template<>
|
||||
__inline__ __device__ Float8_ vec_conversion<Float8_, uint2>(const uint2& a)
|
||||
{
|
||||
Float4_ tmp1, tmp2;
|
||||
tmp1 = vec_conversion<Float4_, uint32_t>(a.x);
|
||||
tmp2 = vec_conversion<Float4_, uint32_t>(a.y);
|
||||
Float8_ res;
|
||||
res.x = tmp1.x;
|
||||
res.y = tmp1.y;
|
||||
res.z = tmp2.x;
|
||||
res.w = tmp2.y;
|
||||
return res;
|
||||
}
|
||||
|
||||
|
||||
// half -> fp8
|
||||
template<>
|
||||
__inline__ __device__ uint8_t vec_conversion<uint8_t, uint16_t>(const uint16_t& a)
|
||||
{
|
||||
__half_raw tmp;
|
||||
tmp.x = a;
|
||||
__nv_fp8_storage_t res = __nv_cvt_halfraw_to_fp8(tmp, __NV_SATFINITE, __NV_E5M2);
|
||||
return (uint8_t)res;
|
||||
}
|
||||
|
||||
// bf16 -> fp8
|
||||
template<>
|
||||
__inline__ __device__ uint8_t vec_conversion<uint8_t, __nv_bfloat16>(const __nv_bfloat16& a)
|
||||
{
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
|
||||
assert(false);
|
||||
#else
|
||||
__nv_fp8_storage_t res = __nv_cvt_bfloat16raw_to_fp8(__nv_bfloat16_raw(a), __NV_SATFINITE, __NV_E5M2);
|
||||
return (uint8_t)res;
|
||||
#endif
|
||||
}
|
||||
|
||||
// float -> fp8
|
||||
template<>
|
||||
__inline__ __device__ uint8_t vec_conversion<uint8_t, float>(const float& a)
|
||||
{
|
||||
__nv_fp8_storage_t res = __nv_cvt_float_to_fp8(a, __NV_SATFINITE, __NV_E5M2);
|
||||
return (uint8_t)res;
|
||||
}
|
||||
|
||||
// fp8x4 -> float4
|
||||
template<>
|
||||
__inline__ __device__ float4 vec_conversion<float4, uint32_t>(const uint32_t& a)
|
||||
{
|
||||
Float4_ tmp = vec_conversion<Float4_, uint32_t>(a);
|
||||
float4 res = make_float4(tmp.x.x, tmp.x.y, tmp.y.x, tmp.y.y);
|
||||
return res;
|
||||
}
|
||||
|
||||
|
||||
template<>
|
||||
__inline__ __device__ uint32_t vec_conversion<uint32_t, float2>(const float2& a)
|
||||
{
|
||||
union {
|
||||
half2 float16;
|
||||
uint32_t uint32;
|
||||
};
|
||||
|
||||
float16 = __float22half2_rn(a);
|
||||
return uint32;
|
||||
}
|
||||
|
||||
template<>
|
||||
__inline__ __device__ uint2 vec_conversion<uint2, Float4_>(const Float4_& a)
|
||||
{
|
||||
uint2 b;
|
||||
float2 val;
|
||||
val.x = a.x.x;
|
||||
val.y = a.x.y;
|
||||
b.x = vec_conversion<uint32_t, float2>(val);
|
||||
|
||||
val.x = a.y.x;
|
||||
val.y = a.y.y;
|
||||
b.y = vec_conversion<uint32_t, float2>(val);
|
||||
|
||||
return b;
|
||||
}
|
||||
|
||||
template<>
|
||||
__inline__ __device__ float4 vec_conversion<float4, Float4_>(const Float4_& a)
|
||||
{
|
||||
float4 b;
|
||||
b.x = a.x.x;
|
||||
b.y = a.x.y;
|
||||
b.z = a.y.x;
|
||||
b.w = a.y.y;
|
||||
return b;
|
||||
}
|
||||
|
||||
template<>
|
||||
__inline__ __device__ uint4 vec_conversion<uint4, Float8_>(const Float8_& a)
|
||||
{
|
||||
uint4 b;
|
||||
b.x = vec_conversion<uint32_t, float2>(a.x);
|
||||
b.y = vec_conversion<uint32_t, float2>(a.y);
|
||||
b.z = vec_conversion<uint32_t, float2>(a.z);
|
||||
b.w = vec_conversion<uint32_t, float2>(a.w);
|
||||
return b;
|
||||
}
|
||||
|
||||
template<>
|
||||
__inline__ __device__ __nv_bfloat162 vec_conversion<__nv_bfloat162, float2>(const float2 &a) {
|
||||
__nv_bfloat162 b;
|
||||
from_float(b, a);
|
||||
return b;
|
||||
}
|
||||
|
||||
template<>
|
||||
__inline__ __device__ bf16_4_t vec_conversion<bf16_4_t, Float4_>(const Float4_ &a) {
|
||||
bf16_4_t b;
|
||||
from_float(b, a);
|
||||
return b;
|
||||
}
|
||||
|
||||
template<>
|
||||
__inline__ __device__ bf16_8_t vec_conversion<bf16_8_t, Float8_>(const Float8_ &a) {
|
||||
bf16_8_t b;
|
||||
from_float(b, a);
|
||||
return b;
|
||||
}
|
||||
|
||||
} // namespace fp8_e5m2_unscaled
|
||||
#endif // ENABLE_FP8_E5M2
|
||||
} // namespace vllm
|
32
docs/source/quantization/fp8_e5m2_kv_cache.rst
Normal file
32
docs/source/quantization/fp8_e5m2_kv_cache.rst
Normal file
@ -0,0 +1,32 @@
|
||||
.. _fp8_e5m2_kv_cache:
|
||||
|
||||
FP8 E5M2 KV Cache
|
||||
==================
|
||||
|
||||
The int8/int4 quantization scheme requires additional scale GPU memory storage, which reduces the expected GPU memory benefits.
|
||||
The FP8 data format retains 2~3 mantissa bits and can convert float/fp16/bflaot16 and fp8 to each other.
|
||||
|
||||
Here is an example of how to enable this feature:
|
||||
|
||||
.. code-block:: python
|
||||
from vllm import LLM, SamplingParams
|
||||
# Sample prompts.
|
||||
prompts = [
|
||||
"Hello, my name is",
|
||||
"The president of the United States is",
|
||||
"The capital of France is",
|
||||
"The future of AI is",
|
||||
]
|
||||
# Create a sampling params object.
|
||||
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
|
||||
# Create an LLM.
|
||||
llm = LLM(model="facebook/opt-125m", kv_cache_dtype="fp8_e5m2")
|
||||
# Generate texts from the prompts. The output is a list of RequestOutput objects
|
||||
# that contain the prompt, generated text, and other information.
|
||||
outputs = llm.generate(prompts, sampling_params)
|
||||
# Print the outputs.
|
||||
for output in outputs:
|
||||
prompt = output.prompt
|
||||
generated_text = output.outputs[0].text
|
||||
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
|
||||
|
3
setup.py
3
setup.py
@ -253,6 +253,9 @@ if _is_cuda():
|
||||
num_threads = min(os.cpu_count(), nvcc_threads)
|
||||
NVCC_FLAGS += ["--threads", str(num_threads)]
|
||||
|
||||
if nvcc_cuda_version >= Version("11.8"):
|
||||
NVCC_FLAGS += ["-DENABLE_FP8_E5M2"]
|
||||
|
||||
# changes for punica kernels
|
||||
NVCC_FLAGS += torch_cpp_ext.COMMON_NVCC_FLAGS
|
||||
REMOVE_NVCC_FLAGS = [
|
||||
|
@ -1,44 +1,7 @@
|
||||
from typing import List, Tuple
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
|
||||
def create_kv_caches(
|
||||
num_blocks: int,
|
||||
block_size: int,
|
||||
num_layers: int,
|
||||
num_heads: int,
|
||||
head_size: int,
|
||||
dtype: torch.dtype,
|
||||
seed: int,
|
||||
device: str,
|
||||
) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
|
||||
torch.random.manual_seed(seed)
|
||||
torch.cuda.manual_seed(seed)
|
||||
|
||||
scale = head_size**-0.5
|
||||
x = 16 // torch.tensor([], dtype=dtype).element_size()
|
||||
key_cache_shape = (num_blocks, num_heads, head_size // x, block_size, x)
|
||||
key_caches = []
|
||||
for _ in range(num_layers):
|
||||
key_cache = torch.empty(size=key_cache_shape,
|
||||
dtype=dtype,
|
||||
device=device)
|
||||
key_cache.uniform_(-scale, scale)
|
||||
key_caches.append(key_cache)
|
||||
|
||||
value_cache_shape = (num_blocks, num_heads, head_size, block_size)
|
||||
value_caches = []
|
||||
for _ in range(num_layers):
|
||||
value_cache = torch.empty(size=value_cache_shape,
|
||||
dtype=dtype,
|
||||
device=device)
|
||||
value_cache.uniform_(-scale, scale)
|
||||
value_caches.append(value_cache)
|
||||
return key_caches, value_caches
|
||||
from vllm.utils import create_kv_caches_with_random
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def kv_cache_factory():
|
||||
return create_kv_caches
|
||||
return create_kv_caches_with_random
|
||||
|
@ -6,14 +6,16 @@ import torch
|
||||
from xformers import ops as xops
|
||||
from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask
|
||||
|
||||
from vllm._C import ops
|
||||
from vllm._C import ops, cache_ops
|
||||
from vllm.utils import get_max_shared_memory_bytes
|
||||
|
||||
FLOAT32_BYTES = torch.finfo(torch.float).bits // 8
|
||||
# This will change depending on the compute capability.
|
||||
# - 512 as a buffer
|
||||
MAX_SEQ_LEN = get_max_shared_memory_bytes() // FLOAT32_BYTES - 512
|
||||
NUM_BLOCKS = 12000 # Arbitrary values for testing
|
||||
# There may not be enough gpu memory due to large NUM_BLOCKS.
|
||||
# Reduce NUM_BLOCKS when it happens.
|
||||
NUM_BLOCKS = 4321 # Arbitrary values for testing
|
||||
PARTITION_SIZE = 512
|
||||
|
||||
DTYPES = [torch.half, torch.bfloat16, torch.float]
|
||||
@ -23,6 +25,7 @@ NUM_HEADS = [(40, 40), (64, 8)] # Arbitrary values for testing
|
||||
HEAD_SIZES = [64, 80, 96, 112, 128, 256]
|
||||
BLOCK_SIZES = [16, 32]
|
||||
USE_ALIBI = [False, True]
|
||||
KV_CACHE_DTYPE = ["auto", "fp8_e5m2"]
|
||||
SEEDS = [0]
|
||||
DEVICES = [i for i in range(1 if torch.cuda.device_count() == 1 else 2)]
|
||||
|
||||
@ -105,6 +108,7 @@ def ref_single_query_cached_kv_attention(
|
||||
@pytest.mark.parametrize("use_alibi", USE_ALIBI)
|
||||
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE)
|
||||
@pytest.mark.parametrize("seed", SEEDS)
|
||||
@pytest.mark.parametrize("device", DEVICES)
|
||||
def test_paged_attention(
|
||||
@ -116,6 +120,7 @@ def test_paged_attention(
|
||||
use_alibi: bool,
|
||||
block_size: int,
|
||||
dtype: torch.dtype,
|
||||
kv_cache_dtype: str,
|
||||
seed: int,
|
||||
device: int,
|
||||
) -> None:
|
||||
@ -158,8 +163,9 @@ def test_paged_attention(
|
||||
|
||||
# Create the KV caches.
|
||||
key_caches, value_caches = kv_cache_factory(NUM_BLOCKS, block_size, 1,
|
||||
num_kv_heads, head_size, dtype,
|
||||
seed, gpu_id)
|
||||
num_kv_heads, head_size,
|
||||
kv_cache_dtype, dtype, seed,
|
||||
gpu_id)
|
||||
key_cache, value_cache = key_caches[0], value_caches[0]
|
||||
|
||||
# Call the paged attention kernel.
|
||||
@ -177,6 +183,7 @@ def test_paged_attention(
|
||||
block_size,
|
||||
max_context_len,
|
||||
alibi_slopes,
|
||||
kv_cache_dtype,
|
||||
)
|
||||
elif version == "v2":
|
||||
num_partitions = ((max_context_len + PARTITION_SIZE - 1) //
|
||||
@ -209,11 +216,30 @@ def test_paged_attention(
|
||||
block_size,
|
||||
max_context_len,
|
||||
alibi_slopes,
|
||||
kv_cache_dtype,
|
||||
)
|
||||
else:
|
||||
raise AssertionError(f"Unknown version: {version}")
|
||||
|
||||
# Run the reference implementation.
|
||||
if kv_cache_dtype == "fp8_e5m2":
|
||||
# Convert cache data back to dtype.
|
||||
x = 16 // torch.tensor([], dtype=dtype).element_size()
|
||||
key_cache_shape = (NUM_BLOCKS, num_kv_heads, head_size // x,
|
||||
block_size, x)
|
||||
dequantized_key_cache = torch.empty(size=key_cache_shape,
|
||||
dtype=dtype,
|
||||
device=gpu_id)
|
||||
cache_ops.convert_fp8_e5m2(key_cache, dequantized_key_cache)
|
||||
key_cache = dequantized_key_cache
|
||||
|
||||
value_cache_shape = value_cache.shape
|
||||
dequantized_value_cache = torch.empty(size=value_cache_shape,
|
||||
dtype=dtype,
|
||||
device=gpu_id)
|
||||
cache_ops.convert_fp8_e5m2(value_cache, dequantized_value_cache)
|
||||
value_cache = dequantized_value_cache
|
||||
|
||||
ref_output = torch.empty_like(query)
|
||||
ref_single_query_cached_kv_attention(
|
||||
ref_output,
|
||||
@ -230,7 +256,12 @@ def test_paged_attention(
|
||||
# NOTE(woosuk): Due to the kernel-level differences in the two
|
||||
# implementations, there is a small numerical difference in the two
|
||||
# outputs. Thus, we use a relaxed tolerance for the test.
|
||||
assert torch.allclose(output, ref_output, atol=1e-3, rtol=1e-5)
|
||||
# NOTE(zhaoyang): FP8 KV Cache will introduce quantization error,
|
||||
# so we use a relaxed tolerance for the test.
|
||||
atol, rtol = 1e-3, 1e-5
|
||||
if kv_cache_dtype == "fp8_e5m2":
|
||||
atol, rtol = 1e-2, 1e-5
|
||||
assert torch.allclose(output, ref_output, atol=atol, rtol=rtol)
|
||||
|
||||
|
||||
def ref_multi_query_kv_attention(
|
||||
|
@ -15,6 +15,7 @@ NUM_BLOCKS = [1024, 3600] # Arbitrary values for testing
|
||||
NUM_MAPPINGS = [256] # Arbitrary values for testing
|
||||
SEEDS = [0]
|
||||
DEVICES = [i for i in range(1 if torch.cuda.device_count() == 1 else 2)]
|
||||
KV_CACHE_DTYPE = ["auto", "fp8_e5m2"]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_mappings", NUM_MAPPINGS)
|
||||
@ -26,6 +27,7 @@ DEVICES = [i for i in range(1 if torch.cuda.device_count() == 1 else 2)]
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
@pytest.mark.parametrize("seed", SEEDS)
|
||||
@pytest.mark.parametrize("device", DEVICES)
|
||||
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE)
|
||||
@torch.inference_mode()
|
||||
def test_copy_blocks(
|
||||
kv_cache_factory,
|
||||
@ -38,6 +40,7 @@ def test_copy_blocks(
|
||||
dtype: torch.dtype,
|
||||
seed: int,
|
||||
device: int,
|
||||
kv_cache_dtype: str,
|
||||
) -> None:
|
||||
random.seed(seed)
|
||||
torch.random.manual_seed(seed)
|
||||
@ -59,7 +62,8 @@ def test_copy_blocks(
|
||||
# Create the KV caches.
|
||||
key_caches, value_caches = kv_cache_factory(num_blocks, block_size,
|
||||
num_layers, num_heads,
|
||||
head_size, dtype, seed, gpu_id)
|
||||
head_size, kv_cache_dtype,
|
||||
dtype, seed, gpu_id)
|
||||
|
||||
# Clone the KV caches.
|
||||
cloned_key_caches = [key_cache.clone() for key_cache in key_caches]
|
||||
@ -124,7 +128,7 @@ def test_reshape_and_cache(
|
||||
# Create the KV caches.
|
||||
key_caches, value_caches = kv_cache_factory(num_blocks, block_size, 1,
|
||||
num_heads, head_size, dtype,
|
||||
seed, gpu_id)
|
||||
None, seed, gpu_id)
|
||||
key_cache, value_cache = key_caches[0], value_caches[0]
|
||||
|
||||
# Clone the KV caches.
|
||||
@ -133,7 +137,7 @@ def test_reshape_and_cache(
|
||||
|
||||
# Call the reshape_and_cache kernel.
|
||||
cache_ops.reshape_and_cache(key, value, key_cache, value_cache,
|
||||
slot_mapping)
|
||||
slot_mapping, "auto")
|
||||
|
||||
# Run the reference implementation.
|
||||
reshaped_key = key.reshape(num_tokens, *key_cache[0, :, :, 0, :].shape)
|
||||
|
@ -1,13 +1,14 @@
|
||||
from typing import Optional, Union, ClassVar
|
||||
from dataclasses import dataclass
|
||||
import os
|
||||
from packaging.version import Version
|
||||
|
||||
import torch
|
||||
from transformers import PretrainedConfig
|
||||
|
||||
from vllm.logger import init_logger
|
||||
from vllm.transformers_utils.config import get_config
|
||||
from vllm.utils import get_cpu_memory, is_hip
|
||||
from vllm.utils import get_cpu_memory, is_hip, get_nvcc_cuda_version
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@ -275,6 +276,7 @@ class CacheConfig:
|
||||
gpu_memory_utilization: Fraction of GPU memory to use for the
|
||||
vLLM execution.
|
||||
swap_space: Size of the CPU swap space per GPU (in GiB).
|
||||
cache_dtype: Data type for kv cache storage.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@ -282,13 +284,16 @@ class CacheConfig:
|
||||
block_size: int,
|
||||
gpu_memory_utilization: float,
|
||||
swap_space: int,
|
||||
cache_dtype: str,
|
||||
sliding_window: Optional[int] = None,
|
||||
) -> None:
|
||||
self.block_size = block_size
|
||||
self.gpu_memory_utilization = gpu_memory_utilization
|
||||
self.swap_space_bytes = swap_space * _GB
|
||||
self.cache_dtype = cache_dtype
|
||||
self.sliding_window = sliding_window
|
||||
self._verify_args()
|
||||
self._verify_cache_dtype()
|
||||
|
||||
# Will be set after profiling.
|
||||
self.num_gpu_blocks = None
|
||||
@ -300,6 +305,28 @@ class CacheConfig:
|
||||
"GPU memory utilization must be less than 1.0. Got "
|
||||
f"{self.gpu_memory_utilization}.")
|
||||
|
||||
def _verify_cache_dtype(self) -> None:
|
||||
if self.cache_dtype == "auto":
|
||||
pass
|
||||
elif self.cache_dtype == "fp8_e5m2":
|
||||
nvcc_cuda_version = get_nvcc_cuda_version()
|
||||
if nvcc_cuda_version < Version("11.8"):
|
||||
raise ValueError(
|
||||
"FP8 is not supported when cuda version is lower than 11.8."
|
||||
)
|
||||
device_name = torch.cuda.get_device_name()
|
||||
if "AMD" in device_name:
|
||||
raise NotImplementedError(
|
||||
"FP8_E5M2 KV Cache on AMD GPU has not been supported yet.")
|
||||
logger.info(
|
||||
"Using fp8_e5m2 data type to store kv cache. It reduces "
|
||||
"the GPU memory footprint and boosts the performance. "
|
||||
"But it may cause slight accuracy drop. "
|
||||
"Currently we only support fp8 without scaling factors and "
|
||||
"make e5m2 as a default format.")
|
||||
else:
|
||||
raise ValueError(f"Unknown kv cache dtype: {self.cache_dtype}")
|
||||
|
||||
def verify_with_parallel_config(
|
||||
self,
|
||||
parallel_config: "ParallelConfig",
|
||||
|
@ -17,6 +17,7 @@ class EngineArgs:
|
||||
download_dir: Optional[str] = None
|
||||
load_format: str = 'auto'
|
||||
dtype: str = 'auto'
|
||||
kv_cache_dtype: str = 'auto'
|
||||
seed: int = 0
|
||||
max_model_len: Optional[int] = None
|
||||
worker_use_ray: bool = False
|
||||
@ -122,6 +123,14 @@ class EngineArgs:
|
||||
'The "auto" option will use FP16 precision '
|
||||
'for FP32 and FP16 models, and BF16 precision '
|
||||
'for BF16 models.')
|
||||
parser.add_argument(
|
||||
'--kv-cache-dtype',
|
||||
type=str,
|
||||
choices=['auto', 'fp8_e5m2'],
|
||||
default='auto',
|
||||
help='Data type for kv cache storage. If "auto", will use model '
|
||||
'data type. Note FP8 is not supported when cuda version is '
|
||||
'lower than 11.8.')
|
||||
parser.add_argument('--max-model-len',
|
||||
type=int,
|
||||
default=None,
|
||||
@ -269,7 +278,7 @@ class EngineArgs:
|
||||
self.max_context_len_to_capture)
|
||||
cache_config = CacheConfig(self.block_size,
|
||||
self.gpu_memory_utilization,
|
||||
self.swap_space,
|
||||
self.swap_space, self.kv_cache_dtype,
|
||||
model_config.get_sliding_window())
|
||||
parallel_config = ParallelConfig(self.pipeline_parallel_size,
|
||||
self.tensor_parallel_size,
|
||||
|
@ -85,6 +85,7 @@ class LLMEngine:
|
||||
f"disable_custom_all_reduce={parallel_config.disable_custom_all_reduce}, "
|
||||
f"quantization={model_config.quantization}, "
|
||||
f"enforce_eager={model_config.enforce_eager}, "
|
||||
f"kv_cache_dtype={cache_config.cache_dtype}, "
|
||||
f"seed={model_config.seed})")
|
||||
# TODO(woosuk): Print more configs in debug mode.
|
||||
|
||||
@ -144,6 +145,7 @@ class LLMEngine:
|
||||
rank=0,
|
||||
distributed_init_method=distributed_init_method,
|
||||
lora_config=self.lora_config,
|
||||
kv_cache_dtype=self.cache_config.cache_dtype,
|
||||
is_driver_worker=True,
|
||||
)
|
||||
self._run_workers("init_model")
|
||||
@ -234,6 +236,7 @@ class LLMEngine:
|
||||
model_config = copy.deepcopy(self.model_config)
|
||||
parallel_config = copy.deepcopy(self.parallel_config)
|
||||
scheduler_config = copy.deepcopy(self.scheduler_config)
|
||||
cache_config = copy.deepcopy(self.cache_config)
|
||||
|
||||
for rank, (worker, (node_id,
|
||||
_)) in enumerate(zip(self.workers,
|
||||
@ -249,6 +252,7 @@ class LLMEngine:
|
||||
rank,
|
||||
distributed_init_method,
|
||||
lora_config=self.lora_config,
|
||||
cache_config=cache_config,
|
||||
))
|
||||
|
||||
driver_rank = 0
|
||||
@ -261,6 +265,7 @@ class LLMEngine:
|
||||
driver_rank,
|
||||
distributed_init_method,
|
||||
lora_config=self.lora_config,
|
||||
cache_config=cache_config,
|
||||
is_driver_worker=True,
|
||||
)
|
||||
|
||||
@ -306,6 +311,7 @@ class LLMEngine:
|
||||
block_size=self.cache_config.block_size,
|
||||
gpu_memory_utilization=self.cache_config.gpu_memory_utilization,
|
||||
cpu_swap_space=self.cache_config.swap_space_bytes,
|
||||
cache_dtype=self.cache_config.cache_dtype,
|
||||
)
|
||||
|
||||
# Since we use a shared centralized controller, we take the minimum
|
||||
|
@ -12,6 +12,7 @@ class InputMetadata:
|
||||
max_context_len: The maximum context length.
|
||||
context_lens: the length of attention context for each sequence.
|
||||
block_tables: The block tables. (Seq id -> list of physical block)
|
||||
kv_cache_dtype: Data type to store kv cache.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@ -25,6 +26,7 @@ class InputMetadata:
|
||||
context_lens: Optional[torch.Tensor],
|
||||
block_tables: Optional[torch.Tensor],
|
||||
use_cuda_graph: bool,
|
||||
kv_cache_dtype: str,
|
||||
) -> None:
|
||||
self.is_prompt = is_prompt
|
||||
self.prompt_lens = prompt_lens
|
||||
@ -35,6 +37,7 @@ class InputMetadata:
|
||||
self.context_lens = context_lens
|
||||
self.block_tables = block_tables
|
||||
self.use_cuda_graph = use_cuda_graph
|
||||
self.kv_cache_dtype = kv_cache_dtype
|
||||
|
||||
# Set during the execution of the first attention op.
|
||||
# FIXME(woosuk): This is a hack.
|
||||
@ -47,4 +50,5 @@ class InputMetadata:
|
||||
f"slot_mapping={self.slot_mapping}, "
|
||||
f"context_lens={self.context_lens}, "
|
||||
f"block_tables={self.block_tables}, "
|
||||
f"use_cuda_graph={self.use_cuda_graph})")
|
||||
f"use_cuda_graph={self.use_cuda_graph}, "
|
||||
f"kv_cache_dtype={self.kv_cache_dtype})")
|
||||
|
@ -98,6 +98,7 @@ class PagedAttention(nn.Module):
|
||||
key_cache,
|
||||
value_cache,
|
||||
input_metadata.slot_mapping.flatten(),
|
||||
input_metadata.kv_cache_dtype,
|
||||
)
|
||||
|
||||
if input_metadata.is_prompt:
|
||||
@ -265,6 +266,7 @@ def _paged_attention(
|
||||
block_size,
|
||||
input_metadata.max_context_len,
|
||||
alibi_slopes,
|
||||
input_metadata.kv_cache_dtype,
|
||||
)
|
||||
else:
|
||||
# Run PagedAttention V2.
|
||||
@ -295,5 +297,6 @@ def _paged_attention(
|
||||
block_size,
|
||||
input_metadata.max_context_len,
|
||||
alibi_slopes,
|
||||
input_metadata.kv_cache_dtype,
|
||||
)
|
||||
return output
|
||||
|
110
vllm/utils.py
110
vllm/utils.py
@ -1,9 +1,11 @@
|
||||
import enum
|
||||
import os
|
||||
import socket
|
||||
import subprocess
|
||||
import uuid
|
||||
from platform import uname
|
||||
from typing import List
|
||||
from typing import List, Tuple, Union
|
||||
from packaging.version import parse, Version
|
||||
|
||||
import psutil
|
||||
import torch
|
||||
@ -17,7 +19,17 @@ from typing import (
|
||||
from collections import OrderedDict
|
||||
from typing import Any, Hashable, Optional
|
||||
|
||||
from vllm.logger import init_logger
|
||||
|
||||
T = TypeVar("T")
|
||||
logger = init_logger(__name__)
|
||||
|
||||
STR_DTYPE_TO_TORCH_DTYPE = {
|
||||
"half": torch.half,
|
||||
"bfloat16": torch.bfloat16,
|
||||
"float": torch.float,
|
||||
"fp8_e5m2": torch.uint8,
|
||||
}
|
||||
|
||||
|
||||
class Device(enum.Enum):
|
||||
@ -167,3 +179,99 @@ def get_open_port() -> int:
|
||||
|
||||
def set_cuda_visible_devices(device_ids: List[int]) -> None:
|
||||
os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(map(str, device_ids))
|
||||
|
||||
|
||||
def get_nvcc_cuda_version() -> Version:
|
||||
cuda_home = os.environ.get('CUDA_HOME')
|
||||
if not cuda_home:
|
||||
cuda_home = '/usr/local/cuda'
|
||||
logger.info(
|
||||
f'CUDA_HOME is not found in the environment. Using {cuda_home} as CUDA_HOME.'
|
||||
)
|
||||
nvcc_output = subprocess.check_output([cuda_home + "/bin/nvcc", "-V"],
|
||||
universal_newlines=True)
|
||||
output = nvcc_output.split()
|
||||
release_idx = output.index("release") + 1
|
||||
nvcc_cuda_version = parse(output[release_idx].split(",")[0])
|
||||
return nvcc_cuda_version
|
||||
|
||||
|
||||
def _generate_random_fp8_e5m2(
|
||||
tensor: torch.tensor,
|
||||
low: float,
|
||||
high: float,
|
||||
) -> None:
|
||||
# NOTE(zhaoyang): Due to NaN and Inf representation for fp8 data type,
|
||||
# it may occur Inf or NaN if we directly use torch.randint
|
||||
# to generate random data for fp8 data.
|
||||
# For example, s.11111.00 in fp8e5m2 format repesents Inf.
|
||||
# | E4M3 | E5M2
|
||||
#-----|-------------|-------------------
|
||||
# Inf | N/A | s.11111.00
|
||||
# NaN | s.1111.111 | s.11111.{01,10,11}
|
||||
from vllm._C import cache_ops
|
||||
tensor_tmp = torch.empty_like(tensor, dtype=torch.float16)
|
||||
tensor_tmp.uniform_(low, high)
|
||||
cache_ops.convert_fp8_e5m2(tensor_tmp, tensor)
|
||||
del tensor_tmp
|
||||
|
||||
|
||||
def create_kv_caches_with_random(
|
||||
num_blocks: int,
|
||||
block_size: int,
|
||||
num_layers: int,
|
||||
num_heads: int,
|
||||
head_size: int,
|
||||
cache_dtype: Optional[Union[str, torch.dtype]],
|
||||
model_dtype: Optional[Union[str, torch.dtype]] = None,
|
||||
seed: Optional[int] = 0,
|
||||
device: Optional[str] = "cuda",
|
||||
) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
|
||||
torch.random.manual_seed(seed)
|
||||
torch.cuda.manual_seed(seed)
|
||||
|
||||
if isinstance(cache_dtype, str):
|
||||
if cache_dtype == "auto":
|
||||
if isinstance(model_dtype, str):
|
||||
torch_dtype = STR_DTYPE_TO_TORCH_DTYPE[model_dtype]
|
||||
elif isinstance(model_dtype, torch.dtype):
|
||||
torch_dtype = model_dtype
|
||||
else:
|
||||
raise ValueError(f"Invalid model dtype: {model_dtype}")
|
||||
elif cache_dtype in ["half", "bfloat16", "float"]:
|
||||
torch_dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_dtype]
|
||||
elif cache_dtype == "fp8_e5m2":
|
||||
torch_dtype = torch.uint8
|
||||
else:
|
||||
raise ValueError(f"Invalid kv cache dtype: {cache_dtype}")
|
||||
elif isinstance(cache_dtype, torch.dtype):
|
||||
torch_dtype = cache_dtype
|
||||
else:
|
||||
raise ValueError(f"Invalid kv cache dtype: {cache_dtype}")
|
||||
|
||||
scale = head_size**-0.5
|
||||
x = 16 // torch.tensor([], dtype=torch_dtype).element_size()
|
||||
key_cache_shape = (num_blocks, num_heads, head_size // x, block_size, x)
|
||||
key_caches = []
|
||||
for _ in range(num_layers):
|
||||
key_cache = torch.empty(size=key_cache_shape,
|
||||
dtype=torch_dtype,
|
||||
device=device)
|
||||
if cache_dtype in ["auto", "half", "bfloat16", "float"]:
|
||||
key_cache.uniform_(-scale, scale)
|
||||
elif cache_dtype == 'fp8_e5m2':
|
||||
_generate_random_fp8_e5m2(key_cache, -scale, scale)
|
||||
key_caches.append(key_cache)
|
||||
|
||||
value_cache_shape = (num_blocks, num_heads, head_size, block_size)
|
||||
value_caches = []
|
||||
for _ in range(num_layers):
|
||||
value_cache = torch.empty(size=value_cache_shape,
|
||||
dtype=torch_dtype,
|
||||
device=device)
|
||||
if cache_dtype in ["auto", "half", "bfloat16", "float"]:
|
||||
value_cache.uniform_(-scale, scale)
|
||||
elif cache_dtype == 'fp8_e5m2':
|
||||
_generate_random_fp8_e5m2(value_cache, -scale, scale)
|
||||
value_caches.append(value_cache)
|
||||
return key_caches, value_caches
|
||||
|
@ -6,7 +6,7 @@ import torch
|
||||
from vllm._C import cache_ops
|
||||
from vllm.config import CacheConfig, ModelConfig, ParallelConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils import in_wsl
|
||||
from vllm.utils import in_wsl, STR_DTYPE_TO_TORCH_DTYPE
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@ -34,12 +34,16 @@ class CacheEngine:
|
||||
self.head_size = model_config.get_head_size()
|
||||
self.num_layers = model_config.get_num_layers(parallel_config)
|
||||
self.num_heads = model_config.get_num_kv_heads(parallel_config)
|
||||
self.dtype = model_config.dtype
|
||||
|
||||
self.block_size = cache_config.block_size
|
||||
self.num_gpu_blocks = cache_config.num_gpu_blocks
|
||||
self.num_cpu_blocks = cache_config.num_cpu_blocks
|
||||
|
||||
if cache_config.cache_dtype == "auto":
|
||||
self.dtype = model_config.dtype
|
||||
else:
|
||||
self.dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_config.cache_dtype]
|
||||
|
||||
# Initialize the cache.
|
||||
self.gpu_cache = self.allocate_gpu_cache()
|
||||
self.cpu_cache = self.allocate_cpu_cache()
|
||||
@ -142,6 +146,7 @@ class CacheEngine:
|
||||
@staticmethod
|
||||
def get_cache_block_size(
|
||||
block_size: int,
|
||||
cache_dtype: str,
|
||||
model_config: ModelConfig,
|
||||
parallel_config: ParallelConfig,
|
||||
) -> int:
|
||||
@ -152,7 +157,11 @@ class CacheEngine:
|
||||
key_cache_block = block_size * num_heads * head_size
|
||||
value_cache_block = key_cache_block
|
||||
total = num_layers * (key_cache_block + value_cache_block)
|
||||
dtype_size = _get_dtype_size(model_config.dtype)
|
||||
if cache_dtype == "auto":
|
||||
dtype = model_config.dtype
|
||||
else:
|
||||
dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_dtype]
|
||||
dtype_size = _get_dtype_size(dtype)
|
||||
return dtype_size * total
|
||||
|
||||
|
||||
|
@ -36,6 +36,7 @@ class ModelRunner:
|
||||
parallel_config: ParallelConfig,
|
||||
scheduler_config: SchedulerConfig,
|
||||
lora_config: Optional[LoRAConfig],
|
||||
kv_cache_dtype: Optional[str] = "auto",
|
||||
is_driver_worker: bool = False,
|
||||
):
|
||||
self.model_config = model_config
|
||||
@ -68,6 +69,7 @@ class ModelRunner:
|
||||
self.graph_block_tables = None # Set after initial profiling.
|
||||
# cache in_wsl result
|
||||
self.in_wsl = in_wsl()
|
||||
self.kv_cache_dtype = kv_cache_dtype
|
||||
|
||||
def load_model(self) -> None:
|
||||
self.model = get_model(self.model_config, self.lora_config)
|
||||
@ -223,6 +225,7 @@ class ModelRunner:
|
||||
context_lens=context_lens_tensor,
|
||||
block_tables=block_tables,
|
||||
use_cuda_graph=False,
|
||||
kv_cache_dtype=self.kv_cache_dtype,
|
||||
)
|
||||
return (input_tokens, input_positions, input_metadata, prompt_lens,
|
||||
subquery_lens, lora_index_mapping, lora_prompt_mapping,
|
||||
@ -350,6 +353,7 @@ class ModelRunner:
|
||||
context_lens=context_lens,
|
||||
block_tables=block_tables,
|
||||
use_cuda_graph=use_captured_graph,
|
||||
kv_cache_dtype=self.kv_cache_dtype,
|
||||
)
|
||||
return input_tokens, input_positions, input_metadata, lora_index_mapping, lora_prompt_mapping, lora_requests
|
||||
|
||||
@ -473,6 +477,7 @@ class ModelRunner:
|
||||
"context_lens": input_metadata.context_lens,
|
||||
"block_tables": input_metadata.block_tables,
|
||||
"use_cuda_graph": input_metadata.use_cuda_graph,
|
||||
"kv_cache_dtype": input_metadata.kv_cache_dtype,
|
||||
"selected_token_indices":
|
||||
sampling_metadata.selected_token_indices,
|
||||
"lora_requests": lora_requests,
|
||||
@ -495,6 +500,7 @@ class ModelRunner:
|
||||
context_lens=metadata_dict["context_lens"],
|
||||
block_tables=metadata_dict["block_tables"],
|
||||
use_cuda_graph=metadata_dict["use_cuda_graph"],
|
||||
kv_cache_dtype=metadata_dict["kv_cache_dtype"],
|
||||
)
|
||||
sampling_metadata = SamplingMetadata(
|
||||
seq_groups=None,
|
||||
@ -665,6 +671,7 @@ class ModelRunner:
|
||||
context_lens=context_lens[:batch_size],
|
||||
block_tables=block_tables[:batch_size],
|
||||
use_cuda_graph=True,
|
||||
kv_cache_dtype=self.kv_cache_dtype,
|
||||
)
|
||||
|
||||
if self.lora_config:
|
||||
|
@ -37,6 +37,7 @@ class Worker:
|
||||
rank: int,
|
||||
distributed_init_method: str,
|
||||
lora_config: Optional[LoRAConfig] = None,
|
||||
kv_cache_dtype: Optional[str] = "auto",
|
||||
is_driver_worker: bool = False,
|
||||
) -> None:
|
||||
self.model_config = model_config
|
||||
@ -54,6 +55,7 @@ class Worker:
|
||||
parallel_config,
|
||||
scheduler_config,
|
||||
lora_config=self.lora_config,
|
||||
kv_cache_dtype=kv_cache_dtype,
|
||||
is_driver_worker=is_driver_worker)
|
||||
# Uninitialized cache engine. Will be initialized by
|
||||
# self.init_cache_engine().
|
||||
@ -95,6 +97,7 @@ class Worker:
|
||||
block_size: int,
|
||||
gpu_memory_utilization: float,
|
||||
cpu_swap_space: int,
|
||||
cache_dtype: str,
|
||||
) -> Tuple[int, int]:
|
||||
"""Profiles the peak memory usage of the model and returns the maximum
|
||||
number of GPU and CPU cache blocks that can be allocated.
|
||||
@ -119,7 +122,7 @@ class Worker:
|
||||
peak_memory = total_gpu_memory - free_gpu_memory
|
||||
|
||||
cache_block_size = CacheEngine.get_cache_block_size(
|
||||
block_size, self.model_config, self.parallel_config)
|
||||
block_size, cache_dtype, self.model_config, self.parallel_config)
|
||||
num_gpu_blocks = int(
|
||||
(total_gpu_memory * gpu_memory_utilization - peak_memory) //
|
||||
cache_block_size)
|
||||
|
Loading…
x
Reference in New Issue
Block a user