From e38074b1e6ad0975acbfa15d858c4bd7cd005e99 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Wed, 7 Jun 2023 00:40:21 -0700 Subject: [PATCH] Support FP32 (#141) --- cacheflow/config.py | 7 +- cacheflow/entrypoints/llm.py | 9 +-- cacheflow/model_executor/layers/attention.py | 8 +-- cacheflow/server/arg_utils.py | 8 +-- csrc/attention/attention_kernels.cu | 69 +++++++++++--------- docs/source/getting_started/installation.rst | 3 + setup.py | 5 ++ tests/kernels/test_attention.py | 10 +-- 8 files changed, 65 insertions(+), 54 deletions(-) diff --git a/cacheflow/config.py b/cacheflow/config.py index cf779723..e182e01c 100644 --- a/cacheflow/config.py +++ b/cacheflow/config.py @@ -164,7 +164,7 @@ def _get_and_verify_dtype( config_dtype = torch.float32 dtype = dtype.lower() - if dtype == "default": + if dtype == "auto": if config_dtype == torch.float32: # Following the common practice, we use float16 for float32 models. torch_dtype = torch.float16 @@ -184,9 +184,8 @@ def _get_and_verify_dtype( # Downcasting from float32 to float16 or bfloat16 is allowed. pass else: - # Casting between float16 and bfloat16 is not allowed. - raise ValueError( - f"Cannot use {torch_dtype} for {config_dtype} model.") + # Casting between float16 and bfloat16 is allowed with a warning. + logger.warn(f"Casting {config_dtype} to {torch_dtype}.") # Check if the GPU supports the dtype. if torch_dtype == torch.bfloat16: diff --git a/cacheflow/entrypoints/llm.py b/cacheflow/entrypoints/llm.py index 75a92fa9..f61e16b2 100644 --- a/cacheflow/entrypoints/llm.py +++ b/cacheflow/entrypoints/llm.py @@ -28,9 +28,10 @@ class LLM: tensor_parallel_size: The number of GPUs to use for distributed execution with tensor parallelism. dtype: The data type for the model weights and activations. Currently, - we support `float16` and `bfloat16`. If `default`, we use the - `torch_dtype` attribute of the model config. If the `torch_dtype` - is `float32`, we use `float16` instead. + we support `float32`, `float16`, and `bfloat16`. If `auto`, we use + the `torch_dtype` attribute specified in the model config file. + However, if the `torch_dtype` in the config is `float32`, we will + use `float16` instead. seed: The seed to initialize the random number generator for sampling. """ @@ -38,7 +39,7 @@ class LLM: self, model: str, tensor_parallel_size: int = 1, - dtype: str = "default", + dtype: str = "auto", seed: int = 0, **kwargs, ) -> None: diff --git a/cacheflow/model_executor/layers/attention.py b/cacheflow/model_executor/layers/attention.py index 0231ee9a..67ea3987 100644 --- a/cacheflow/model_executor/layers/attention.py +++ b/cacheflow/model_executor/layers/attention.py @@ -10,7 +10,7 @@ from cacheflow import cache_ops from cacheflow import pos_encoding_ops from cacheflow.model_executor.input_metadata import InputMetadata -_SUPPORTED_HEAD_SIZES = [32, 64, 80, 96, 128, 160, 192, 256] +_SUPPORTED_HEAD_SIZES = [64, 80, 96, 128] class GPTCacheFlowAttention(nn.Module): @@ -49,10 +49,8 @@ class GPTCacheFlowAttention(nn.Module): self.attn_op = xops.fmha.cutlass.FwOp() if self.head_size not in _SUPPORTED_HEAD_SIZES: - raise ValueError(f'head_size ({self.head_size}) is not supported by ' - 'the single_query_cached_kv_attention kernel. ' - 'Use one of the following head sizes: ' - f'{_SUPPORTED_HEAD_SIZES}.') + raise ValueError(f"head_size ({self.head_size}) is not supported. 
" + f"Supported head sizes: {_SUPPORTED_HEAD_SIZES}.") def multi_query_kv_attention( self, diff --git a/cacheflow/server/arg_utils.py b/cacheflow/server/arg_utils.py index 63f32c80..e70e5979 100644 --- a/cacheflow/server/arg_utils.py +++ b/cacheflow/server/arg_utils.py @@ -13,7 +13,7 @@ class ServerArgs: download_dir: Optional[str] = None use_np_weights: bool = False use_dummy_weights: bool = False - dtype: str = "default" + dtype: str = "auto" seed: int = 0 worker_use_ray: bool = False pipeline_parallel_size: int = 1 @@ -49,9 +49,9 @@ class ServerArgs: help='use dummy values for model weights') # TODO(woosuk): Support FP32. parser.add_argument('--dtype', type=str, default=ServerArgs.dtype, - choices=['default', 'half', 'bfloat16'], + choices=['auto', 'half', 'bfloat16', 'float'], help='data type for model weights and activations. ' - 'The "default" option will use FP16 precision ' + 'The "auto" option will use FP16 precision ' 'for FP32 and FP16 models, and BF16 precision ' 'for BF16 models.') # Parallel arguments @@ -67,7 +67,7 @@ class ServerArgs: # KV cache arguments parser.add_argument('--block-size', type=int, default=ServerArgs.block_size, - choices=[1, 2, 4, 8, 16, 32, 64, 128, 256], + choices=[8, 16, 32], help='token block size') # TODO(woosuk): Support fine-grained seeds (e.g., seed per request). parser.add_argument('--seed', type=int, default=ServerArgs.seed, diff --git a/csrc/attention/attention_kernels.cu b/csrc/attention/attention_kernels.cu index 0854f343..8d86605a 100644 --- a/csrc/attention/attention_kernels.cu +++ b/csrc/attention/attention_kernels.cu @@ -370,9 +370,11 @@ void single_query_cached_kv_attention_launcher( dim3 block(NUM_THREADS); const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); switch (head_size) { - case 32: - LAUNCH_ATTENTION_KERNEL(T, 32, BLOCK_SIZE, NUM_THREADS); - break; + // NOTE(woosuk): To reduce the compilation time, we omitted head sizes + // 32, 160, 192, 256. + // case 32: + // LAUNCH_ATTENTION_KERNEL(T, 32, BLOCK_SIZE, NUM_THREADS); + // break; case 64: LAUNCH_ATTENTION_KERNEL(T, 64, BLOCK_SIZE, NUM_THREADS); break; @@ -385,15 +387,15 @@ void single_query_cached_kv_attention_launcher( case 128: LAUNCH_ATTENTION_KERNEL(T, 128, BLOCK_SIZE, NUM_THREADS); break; - case 160: - LAUNCH_ATTENTION_KERNEL(T, 160, BLOCK_SIZE, NUM_THREADS); - break; - case 192: - LAUNCH_ATTENTION_KERNEL(T, 192, BLOCK_SIZE, NUM_THREADS); - break; - case 256: - LAUNCH_ATTENTION_KERNEL(T, 256, BLOCK_SIZE, NUM_THREADS); - break; + // case 160: + // LAUNCH_ATTENTION_KERNEL(T, 160, BLOCK_SIZE, NUM_THREADS); + // break; + // case 192: + // LAUNCH_ATTENTION_KERNEL(T, 192, BLOCK_SIZE, NUM_THREADS); + // break; + // case 256: + // LAUNCH_ATTENTION_KERNEL(T, 256, BLOCK_SIZE, NUM_THREADS); + // break; default: TORCH_CHECK(false, "Unsupported head size: ", head_size); break; @@ -411,17 +413,19 @@ void single_query_cached_kv_attention_launcher( context_lens, \ max_context_len); +// NOTE(woosuk): To reduce the compilation time, we omitted block sizes +// 1, 2, 4, 64, 128, 256. 
#define CALL_KERNEL_LAUNCHER_BLOCK_SIZE(T) \ switch (block_size) { \ - case 1: \ - CALL_KERNEL_LAUNCHER(T, 1); \ - break; \ - case 2: \ - CALL_KERNEL_LAUNCHER(T, 2); \ - break; \ - case 4: \ - CALL_KERNEL_LAUNCHER(T, 4); \ - break; \ + /* case 1: */ \ + /* CALL_KERNEL_LAUNCHER(T, 1); */ \ + /* break; */ \ + /* case 2: */ \ + /* CALL_KERNEL_LAUNCHER(T, 2); */ \ + /* break; */ \ + /* case 4: */ \ + /* CALL_KERNEL_LAUNCHER(T, 4); */ \ + /* break; */ \ case 8: \ CALL_KERNEL_LAUNCHER(T, 8); \ break; \ @@ -431,15 +435,15 @@ void single_query_cached_kv_attention_launcher( case 32: \ CALL_KERNEL_LAUNCHER(T, 32); \ break; \ - case 64: \ - CALL_KERNEL_LAUNCHER(T, 64); \ - break; \ - case 128: \ - CALL_KERNEL_LAUNCHER(T, 128); \ - break; \ - case 256: \ - CALL_KERNEL_LAUNCHER(T, 256); \ - break; \ + /* case 64: */ \ + /* CALL_KERNEL_LAUNCHER(T, 64); */ \ + /* break; */ \ + /* case 128: */ \ + /* CALL_KERNEL_LAUNCHER(T, 128); */ \ + /* break; */ \ + /* case 256: */ \ + /* CALL_KERNEL_LAUNCHER(T, 256); */ \ + /* break; */ \ default: \ TORCH_CHECK(false, "Unsupported block size: ", block_size); \ break; \ @@ -455,8 +459,9 @@ void single_query_cached_kv_attention( torch::Tensor& context_lens, // [num_seqs] int block_size, int max_context_len) { - // TODO(woosuk): Support FP32. - if (query.dtype() == at::ScalarType::Half) { + if (query.dtype() == at::ScalarType::Float) { + CALL_KERNEL_LAUNCHER_BLOCK_SIZE(float); + } else if (query.dtype() == at::ScalarType::Half) { CALL_KERNEL_LAUNCHER_BLOCK_SIZE(uint16_t); } else if (query.dtype() == at::ScalarType::BFloat16) { CALL_KERNEL_LAUNCHER_BLOCK_SIZE(__nv_bfloat16); diff --git a/docs/source/getting_started/installation.rst b/docs/source/getting_started/installation.rst index 326e6960..b78d993e 100644 --- a/docs/source/getting_started/installation.rst +++ b/docs/source/getting_started/installation.rst @@ -18,8 +18,11 @@ CacheFlow can run on systems that meet the following requirements: .. code-block:: console + $ # Pull the Docker image with CUDA 11.8. $ docker run --gpus all -it --rm --shm-size=8g nvcr.io/nvidia/pytorch:22.12-py3 + Inside the Docker container, please execute :code:`pip uninstall torch` before installing CacheFlow. + Install with pip ---------------- diff --git a/setup.py b/setup.py index 80134e14..3ddc3641 100644 --- a/setup.py +++ b/setup.py @@ -66,6 +66,11 @@ if 90 in compute_capabilities and nvcc_cuda_version < Version("11.8"): raise RuntimeError( "CUDA 11.8 or higher is required for GPUs with compute capability 9.0.") +# Use NVCC threads to parallelize the build. +if nvcc_cuda_version >= Version("11.2"): + num_threads = min(os.cpu_count(), 8) + NVCC_FLAGS += ["--threads", str(num_threads)] + ext_modules = [] # Cache operations. 
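The dtype handling in this patch reduces to a small selection rule: "auto" keeps the dtype recorded in the model config, except that FP32 checkpoints are served in FP16; an explicit "half", "bfloat16", or "float" overrides the config; and a mismatch between FP16 and BF16 (or an upcast to FP32) now only logs a warning instead of raising. Below is a minimal Python sketch of that rule, not the exact _get_and_verify_dtype implementation in cacheflow/config.py; the helper name and the string-to-dtype map are illustrative.

import warnings

import torch

# Illustrative alias map; the real code may accept more spellings.
_STR_TO_DTYPE = {
    "half": torch.float16,
    "float16": torch.float16,
    "bfloat16": torch.bfloat16,
    "float": torch.float32,
    "float32": torch.float32,
}

def resolve_dtype(requested: str, config_dtype: torch.dtype) -> torch.dtype:
    requested = requested.lower()
    if requested == "auto":
        # Common practice: serve FP32 checkpoints in FP16; keep FP16/BF16 as-is.
        return torch.float16 if config_dtype == torch.float32 else config_dtype
    torch_dtype = _STR_TO_DTYPE[requested]
    if torch_dtype != config_dtype:
        if config_dtype == torch.float32:
            # Downcasting an FP32 checkpoint to FP16/BF16 is allowed silently.
            pass
        else:
            # FP16 <-> BF16 casts (and upcasts to FP32) warn instead of raising.
            warnings.warn(f"Casting {config_dtype} to {torch_dtype}.")
    return torch_dtype

For a BF16 checkpoint, resolve_dtype("auto", torch.bfloat16) keeps torch.bfloat16, while resolve_dtype("float", torch.bfloat16) returns torch.float32 and emits the cast warning.
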
diff --git a/tests/kernels/test_attention.py b/tests/kernels/test_attention.py index 49c745a5..c0db942d 100644 --- a/tests/kernels/test_attention.py +++ b/tests/kernels/test_attention.py @@ -270,9 +270,9 @@ def run_multi_query_kv_attention( def test_single_query_cached_kv_attention() -> None: torch.random.manual_seed(TEST_SEED) torch.cuda.manual_seed(TEST_SEED) - for dtype in [torch.half, torch.bfloat16]: - for block_size in [8, 16, 32, 64]: - for head_size in [32, 64, 80, 96, 128, 160, 192, 256]: + for dtype in [torch.half, torch.bfloat16, torch.float]: + for block_size in [8, 16, 32]: + for head_size in [64, 80, 96, 128]: print(f'Testing single_query_cached_kv_attention with ' f'dtype={dtype}, block_size={block_size}, ' f'head_size={head_size}') @@ -289,8 +289,8 @@ def test_single_query_cached_kv_attention() -> None: def test_multi_query_kv_attention() -> None: torch.random.manual_seed(TEST_SEED) torch.cuda.manual_seed(TEST_SEED) - for dtype in [torch.half, torch.bfloat16]: - for head_size in [32, 64, 80, 96, 128, 160, 192, 256]: + for dtype in [torch.half, torch.bfloat16, torch.float]: + for head_size in [64, 80, 96, 128]: print(f'Testing multi_query_kv_attention with dtype={dtype}, ' f'head_size={head_size}') run_multi_query_kv_attention(
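
End to end, the new precision is exposed in two places: the --dtype float choice on the server arguments and the float32 value documented for the LLM constructor. The sketch below shows the offline path; it assumes LLM is exported at the package top level and that generate() accepts a plain list of prompt strings (neither detail appears in this diff), and the model name is only an example.

from cacheflow import LLM  # assumed top-level export

# Request FP32 explicitly; the default "auto" would serve an FP32 checkpoint in FP16.
llm = LLM(model="facebook/opt-125m", dtype="float32")
outputs = llm.generate(["Hello, my name is"])
for output in outputs:
    print(output)

For the server, the equivalent is the new --dtype float flag. Note that FP32 weights and KV cache take roughly twice the GPU memory of FP16, so fewer cache blocks fit on the same device.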