Support FP32 (#141)
This commit is contained in:
parent
376725ce74
commit
e38074b1e6
@ -164,7 +164,7 @@ def _get_and_verify_dtype(
|
||||
config_dtype = torch.float32
|
||||
|
||||
dtype = dtype.lower()
|
||||
if dtype == "default":
|
||||
if dtype == "auto":
|
||||
if config_dtype == torch.float32:
|
||||
# Following the common practice, we use float16 for float32 models.
|
||||
torch_dtype = torch.float16
|
||||
@ -184,9 +184,8 @@ def _get_and_verify_dtype(
|
||||
# Downcasting from float32 to float16 or bfloat16 is allowed.
|
||||
pass
|
||||
else:
|
||||
# Casting between float16 and bfloat16 is not allowed.
|
||||
raise ValueError(
|
||||
f"Cannot use {torch_dtype} for {config_dtype} model.")
|
||||
# Casting between float16 and bfloat16 is allowed with a warning.
|
||||
logger.warn(f"Casting {config_dtype} to {torch_dtype}.")
|
||||
|
||||
# Check if the GPU supports the dtype.
|
||||
if torch_dtype == torch.bfloat16:
|
||||
|
@ -28,9 +28,10 @@ class LLM:
|
||||
tensor_parallel_size: The number of GPUs to use for distributed
|
||||
execution with tensor parallelism.
|
||||
dtype: The data type for the model weights and activations. Currently,
|
||||
we support `float16` and `bfloat16`. If `default`, we use the
|
||||
`torch_dtype` attribute of the model config. If the `torch_dtype`
|
||||
is `float32`, we use `float16` instead.
|
||||
we support `float32`, `float16`, and `bfloat16`. If `auto`, we use
|
||||
the `torch_dtype` attribute specified in the model config file.
|
||||
However, if the `torch_dtype` in the config is `float32`, we will
|
||||
use `float16` instead.
|
||||
seed: The seed to initialize the random number generator for sampling.
|
||||
"""
|
||||
|
||||
@ -38,7 +39,7 @@ class LLM:
|
||||
self,
|
||||
model: str,
|
||||
tensor_parallel_size: int = 1,
|
||||
dtype: str = "default",
|
||||
dtype: str = "auto",
|
||||
seed: int = 0,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
|
@ -10,7 +10,7 @@ from cacheflow import cache_ops
|
||||
from cacheflow import pos_encoding_ops
|
||||
from cacheflow.model_executor.input_metadata import InputMetadata
|
||||
|
||||
_SUPPORTED_HEAD_SIZES = [32, 64, 80, 96, 128, 160, 192, 256]
|
||||
_SUPPORTED_HEAD_SIZES = [64, 80, 96, 128]
|
||||
|
||||
|
||||
class GPTCacheFlowAttention(nn.Module):
|
||||
@ -49,10 +49,8 @@ class GPTCacheFlowAttention(nn.Module):
|
||||
self.attn_op = xops.fmha.cutlass.FwOp()
|
||||
|
||||
if self.head_size not in _SUPPORTED_HEAD_SIZES:
|
||||
raise ValueError(f'head_size ({self.head_size}) is not supported by '
|
||||
'the single_query_cached_kv_attention kernel. '
|
||||
'Use one of the following head sizes: '
|
||||
f'{_SUPPORTED_HEAD_SIZES}.')
|
||||
raise ValueError(f"head_size ({self.head_size}) is not supported. "
|
||||
f"Supported head sizes: {_SUPPORTED_HEAD_SIZES}.")
|
||||
|
||||
def multi_query_kv_attention(
|
||||
self,
|
||||
|
@ -13,7 +13,7 @@ class ServerArgs:
|
||||
download_dir: Optional[str] = None
|
||||
use_np_weights: bool = False
|
||||
use_dummy_weights: bool = False
|
||||
dtype: str = "default"
|
||||
dtype: str = "auto"
|
||||
seed: int = 0
|
||||
worker_use_ray: bool = False
|
||||
pipeline_parallel_size: int = 1
|
||||
@ -49,9 +49,9 @@ class ServerArgs:
|
||||
help='use dummy values for model weights')
|
||||
# TODO(woosuk): Support FP32.
|
||||
parser.add_argument('--dtype', type=str, default=ServerArgs.dtype,
|
||||
choices=['default', 'half', 'bfloat16'],
|
||||
choices=['auto', 'half', 'bfloat16', 'float'],
|
||||
help='data type for model weights and activations. '
|
||||
'The "default" option will use FP16 precision '
|
||||
'The "auto" option will use FP16 precision '
|
||||
'for FP32 and FP16 models, and BF16 precision '
|
||||
'for BF16 models.')
|
||||
# Parallel arguments
|
||||
@ -67,7 +67,7 @@ class ServerArgs:
|
||||
# KV cache arguments
|
||||
parser.add_argument('--block-size', type=int,
|
||||
default=ServerArgs.block_size,
|
||||
choices=[1, 2, 4, 8, 16, 32, 64, 128, 256],
|
||||
choices=[8, 16, 32],
|
||||
help='token block size')
|
||||
# TODO(woosuk): Support fine-grained seeds (e.g., seed per request).
|
||||
parser.add_argument('--seed', type=int, default=ServerArgs.seed,
|
||||
|
@ -370,9 +370,11 @@ void single_query_cached_kv_attention_launcher(
|
||||
dim3 block(NUM_THREADS);
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
switch (head_size) {
|
||||
case 32:
|
||||
LAUNCH_ATTENTION_KERNEL(T, 32, BLOCK_SIZE, NUM_THREADS);
|
||||
break;
|
||||
// NOTE(woosuk): To reduce the compilation time, we omitted head sizes
|
||||
// 32, 160, 192, 256.
|
||||
// case 32:
|
||||
// LAUNCH_ATTENTION_KERNEL(T, 32, BLOCK_SIZE, NUM_THREADS);
|
||||
// break;
|
||||
case 64:
|
||||
LAUNCH_ATTENTION_KERNEL(T, 64, BLOCK_SIZE, NUM_THREADS);
|
||||
break;
|
||||
@ -385,15 +387,15 @@ void single_query_cached_kv_attention_launcher(
|
||||
case 128:
|
||||
LAUNCH_ATTENTION_KERNEL(T, 128, BLOCK_SIZE, NUM_THREADS);
|
||||
break;
|
||||
case 160:
|
||||
LAUNCH_ATTENTION_KERNEL(T, 160, BLOCK_SIZE, NUM_THREADS);
|
||||
break;
|
||||
case 192:
|
||||
LAUNCH_ATTENTION_KERNEL(T, 192, BLOCK_SIZE, NUM_THREADS);
|
||||
break;
|
||||
case 256:
|
||||
LAUNCH_ATTENTION_KERNEL(T, 256, BLOCK_SIZE, NUM_THREADS);
|
||||
break;
|
||||
// case 160:
|
||||
// LAUNCH_ATTENTION_KERNEL(T, 160, BLOCK_SIZE, NUM_THREADS);
|
||||
// break;
|
||||
// case 192:
|
||||
// LAUNCH_ATTENTION_KERNEL(T, 192, BLOCK_SIZE, NUM_THREADS);
|
||||
// break;
|
||||
// case 256:
|
||||
// LAUNCH_ATTENTION_KERNEL(T, 256, BLOCK_SIZE, NUM_THREADS);
|
||||
// break;
|
||||
default:
|
||||
TORCH_CHECK(false, "Unsupported head size: ", head_size);
|
||||
break;
|
||||
@ -411,17 +413,19 @@ void single_query_cached_kv_attention_launcher(
|
||||
context_lens, \
|
||||
max_context_len);
|
||||
|
||||
// NOTE(woosuk): To reduce the compilation time, we omitted block sizes
|
||||
// 1, 2, 4, 64, 128, 256.
|
||||
#define CALL_KERNEL_LAUNCHER_BLOCK_SIZE(T) \
|
||||
switch (block_size) { \
|
||||
case 1: \
|
||||
CALL_KERNEL_LAUNCHER(T, 1); \
|
||||
break; \
|
||||
case 2: \
|
||||
CALL_KERNEL_LAUNCHER(T, 2); \
|
||||
break; \
|
||||
case 4: \
|
||||
CALL_KERNEL_LAUNCHER(T, 4); \
|
||||
break; \
|
||||
/* case 1: */ \
|
||||
/* CALL_KERNEL_LAUNCHER(T, 1); */ \
|
||||
/* break; */ \
|
||||
/* case 2: */ \
|
||||
/* CALL_KERNEL_LAUNCHER(T, 2); */ \
|
||||
/* break; */ \
|
||||
/* case 4: */ \
|
||||
/* CALL_KERNEL_LAUNCHER(T, 4); */ \
|
||||
/* break; */ \
|
||||
case 8: \
|
||||
CALL_KERNEL_LAUNCHER(T, 8); \
|
||||
break; \
|
||||
@ -431,15 +435,15 @@ void single_query_cached_kv_attention_launcher(
|
||||
case 32: \
|
||||
CALL_KERNEL_LAUNCHER(T, 32); \
|
||||
break; \
|
||||
case 64: \
|
||||
CALL_KERNEL_LAUNCHER(T, 64); \
|
||||
break; \
|
||||
case 128: \
|
||||
CALL_KERNEL_LAUNCHER(T, 128); \
|
||||
break; \
|
||||
case 256: \
|
||||
CALL_KERNEL_LAUNCHER(T, 256); \
|
||||
break; \
|
||||
/* case 64: */ \
|
||||
/* CALL_KERNEL_LAUNCHER(T, 64); */ \
|
||||
/* break; */ \
|
||||
/* case 128: */ \
|
||||
/* CALL_KERNEL_LAUNCHER(T, 128); */ \
|
||||
/* break; */ \
|
||||
/* case 256: */ \
|
||||
/* CALL_KERNEL_LAUNCHER(T, 256); */ \
|
||||
/* break; */ \
|
||||
default: \
|
||||
TORCH_CHECK(false, "Unsupported block size: ", block_size); \
|
||||
break; \
|
||||
@ -455,8 +459,9 @@ void single_query_cached_kv_attention(
|
||||
torch::Tensor& context_lens, // [num_seqs]
|
||||
int block_size,
|
||||
int max_context_len) {
|
||||
// TODO(woosuk): Support FP32.
|
||||
if (query.dtype() == at::ScalarType::Half) {
|
||||
if (query.dtype() == at::ScalarType::Float) {
|
||||
CALL_KERNEL_LAUNCHER_BLOCK_SIZE(float);
|
||||
} else if (query.dtype() == at::ScalarType::Half) {
|
||||
CALL_KERNEL_LAUNCHER_BLOCK_SIZE(uint16_t);
|
||||
} else if (query.dtype() == at::ScalarType::BFloat16) {
|
||||
CALL_KERNEL_LAUNCHER_BLOCK_SIZE(__nv_bfloat16);
|
||||
|
@ -18,8 +18,11 @@ CacheFlow can run on systems that meet the following requirements:
|
||||
|
||||
.. code-block:: console
|
||||
|
||||
$ # Pull the Docker image with CUDA 11.8.
|
||||
$ docker run --gpus all -it --rm --shm-size=8g nvcr.io/nvidia/pytorch:22.12-py3
|
||||
|
||||
Inside the Docker container, please execute :code:`pip uninstall torch` before installing CacheFlow.
|
||||
|
||||
Install with pip
|
||||
----------------
|
||||
|
||||
|
5
setup.py
5
setup.py
@ -66,6 +66,11 @@ if 90 in compute_capabilities and nvcc_cuda_version < Version("11.8"):
|
||||
raise RuntimeError(
|
||||
"CUDA 11.8 or higher is required for GPUs with compute capability 9.0.")
|
||||
|
||||
# Use NVCC threads to parallelize the build.
|
||||
if nvcc_cuda_version >= Version("11.2"):
|
||||
num_threads = min(os.cpu_count(), 8)
|
||||
NVCC_FLAGS += ["--threads", str(num_threads)]
|
||||
|
||||
ext_modules = []
|
||||
|
||||
# Cache operations.
|
||||
|
@ -270,9 +270,9 @@ def run_multi_query_kv_attention(
|
||||
def test_single_query_cached_kv_attention() -> None:
|
||||
torch.random.manual_seed(TEST_SEED)
|
||||
torch.cuda.manual_seed(TEST_SEED)
|
||||
for dtype in [torch.half, torch.bfloat16]:
|
||||
for block_size in [8, 16, 32, 64]:
|
||||
for head_size in [32, 64, 80, 96, 128, 160, 192, 256]:
|
||||
for dtype in [torch.half, torch.bfloat16, torch.float]:
|
||||
for block_size in [8, 16, 32]:
|
||||
for head_size in [64, 80, 96, 128]:
|
||||
print(f'Testing single_query_cached_kv_attention with '
|
||||
f'dtype={dtype}, block_size={block_size}, '
|
||||
f'head_size={head_size}')
|
||||
@ -289,8 +289,8 @@ def test_single_query_cached_kv_attention() -> None:
|
||||
def test_multi_query_kv_attention() -> None:
|
||||
torch.random.manual_seed(TEST_SEED)
|
||||
torch.cuda.manual_seed(TEST_SEED)
|
||||
for dtype in [torch.half, torch.bfloat16]:
|
||||
for head_size in [32, 64, 80, 96, 128, 160, 192, 256]:
|
||||
for dtype in [torch.half, torch.bfloat16, torch.float]:
|
||||
for head_size in [64, 80, 96, 128]:
|
||||
print(f'Testing multi_query_kv_attention with dtype={dtype}, '
|
||||
f'head_size={head_size}')
|
||||
run_multi_query_kv_attention(
|
||||
|
Loading…
x
Reference in New Issue
Block a user