Support FP32 (#141)

commit e38074b1e6 (parent 376725ce74)
Author: Woosuk Kwon
Date:   2023-06-07 00:40:21 -07:00 (committed by GitHub)
8 changed files with 65 additions and 54 deletions


@@ -164,7 +164,7 @@ def _get_and_verify_dtype(
         config_dtype = torch.float32

     dtype = dtype.lower()
-    if dtype == "default":
+    if dtype == "auto":
         if config_dtype == torch.float32:
             # Following the common practice, we use float16 for float32 models.
             torch_dtype = torch.float16
@@ -184,9 +184,8 @@ def _get_and_verify_dtype(
             # Downcasting from float32 to float16 or bfloat16 is allowed.
             pass
         else:
-            # Casting between float16 and bfloat16 is not allowed.
-            raise ValueError(
-                f"Cannot use {torch_dtype} for {config_dtype} model.")
+            # Casting between float16 and bfloat16 is allowed with a warning.
+            logger.warn(f"Casting {config_dtype} to {torch_dtype}.")

     # Check if the GPU supports the dtype.
     if torch_dtype == torch.bfloat16:
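
For readers who want the resolution rule in one place, here is a minimal, self-contained Python sketch of the dtype logic implied by the two hunks above. The function and dictionary names are illustrative, not the project's actual helpers, and the behavior of "auto" for non-FP32 configs is an assumption since that branch is outside the hunk:

    import warnings
    import torch

    # Illustrative name mapping; the project's accepted strings may differ slightly.
    _STR_TO_DTYPE = {"half": torch.float16, "float16": torch.float16,
                     "bfloat16": torch.bfloat16, "float": torch.float32,
                     "float32": torch.float32}

    def resolve_dtype(dtype: str, config_dtype: torch.dtype) -> torch.dtype:
        if dtype.lower() == "auto":
            # Assumption: "auto" follows the model config, except FP32 configs
            # are downcast to FP16 (as shown in the first hunk).
            return torch.float16 if config_dtype == torch.float32 else config_dtype
        torch_dtype = _STR_TO_DTYPE[dtype.lower()]
        if torch_dtype != config_dtype and config_dtype != torch.float32:
            # Downcasting from FP32 is allowed silently; casting between
            # float16 and bfloat16 is allowed with a warning (second hunk).
            warnings.warn(f"Casting {config_dtype} to {torch_dtype}.")
        return torch_dtype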


@@ -28,9 +28,10 @@ class LLM:
         tensor_parallel_size: The number of GPUs to use for distributed
             execution with tensor parallelism.
         dtype: The data type for the model weights and activations. Currently,
-            we support `float16` and `bfloat16`. If `default`, we use the
-            `torch_dtype` attribute of the model config. If the `torch_dtype`
-            is `float32`, we use `float16` instead.
+            we support `float32`, `float16`, and `bfloat16`. If `auto`, we use
+            the `torch_dtype` attribute specified in the model config file.
+            However, if the `torch_dtype` in the config is `float32`, we will
+            use `float16` instead.
         seed: The seed to initialize the random number generator for sampling.
     """

@@ -38,7 +39,7 @@ class LLM:
         self,
         model: str,
         tensor_parallel_size: int = 1,
-        dtype: str = "default",
+        dtype: str = "auto",
         seed: int = 0,
         **kwargs,
     ) -> None:
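
A hedged usage sketch of the renamed option follows. The import path and model name are assumptions for illustration; only the dtype values come from this diff, and the `float32` string is taken from the docstring above:

    from cacheflow import LLM  # assumed import path for the LLM entrypoint

    # "auto" (the new default) follows the model config, downcasting FP32 configs to FP16.
    llm = LLM(model="facebook/opt-125m", dtype="auto")

    # After this commit, full-precision weights can be requested explicitly.
    llm_fp32 = LLM(model="facebook/opt-125m", dtype="float32")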


@@ -10,7 +10,7 @@ from cacheflow import cache_ops
 from cacheflow import pos_encoding_ops
 from cacheflow.model_executor.input_metadata import InputMetadata

-_SUPPORTED_HEAD_SIZES = [32, 64, 80, 96, 128, 160, 192, 256]
+_SUPPORTED_HEAD_SIZES = [64, 80, 96, 128]


 class GPTCacheFlowAttention(nn.Module):
@@ -49,10 +49,8 @@ class GPTCacheFlowAttention(nn.Module):
         self.attn_op = xops.fmha.cutlass.FwOp()

         if self.head_size not in _SUPPORTED_HEAD_SIZES:
-            raise ValueError(f'head_size ({self.head_size}) is not supported by '
-                             'the single_query_cached_kv_attention kernel. '
-                             'Use one of the following head sizes: '
-                             f'{_SUPPORTED_HEAD_SIZES}.')
+            raise ValueError(f"head_size ({self.head_size}) is not supported. "
+                             f"Supported head sizes: {_SUPPORTED_HEAD_SIZES}.")

     def multi_query_kv_attention(
         self,


@@ -13,7 +13,7 @@ class ServerArgs:
     download_dir: Optional[str] = None
     use_np_weights: bool = False
     use_dummy_weights: bool = False
-    dtype: str = "default"
+    dtype: str = "auto"
     seed: int = 0
     worker_use_ray: bool = False
     pipeline_parallel_size: int = 1
@@ -49,9 +49,9 @@ class ServerArgs:
                             help='use dummy values for model weights')
         # TODO(woosuk): Support FP32.
         parser.add_argument('--dtype', type=str, default=ServerArgs.dtype,
-                            choices=['default', 'half', 'bfloat16'],
+                            choices=['auto', 'half', 'bfloat16', 'float'],
                             help='data type for model weights and activations. '
-                                 'The "default" option will use FP16 precision '
+                                 'The "auto" option will use FP16 precision '
                                  'for FP32 and FP16 models, and BF16 precision '
                                  'for BF16 models.')
         # Parallel arguments
@@ -67,7 +67,7 @@ class ServerArgs:
         # KV cache arguments
         parser.add_argument('--block-size', type=int,
                             default=ServerArgs.block_size,
-                            choices=[1, 2, 4, 8, 16, 32, 64, 128, 256],
+                            choices=[8, 16, 32],
                             help='token block size')
         # TODO(woosuk): Support fine-grained seeds (e.g., seed per request).
         parser.add_argument('--seed', type=int, default=ServerArgs.seed,
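
For illustration, the updated flag in isolation; this is a stand-alone argparse sketch mirroring the hunk above, not the project's full parser:

    import argparse

    # Stand-alone sketch of the updated --dtype choices (not the project's parser).
    parser = argparse.ArgumentParser()
    parser.add_argument('--dtype', type=str, default='auto',
                        choices=['auto', 'half', 'bfloat16', 'float'],
                        help='data type for model weights and activations')

    print(parser.parse_args(['--dtype', 'float']).dtype)  # -> float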


@@ -370,9 +370,11 @@ void single_query_cached_kv_attention_launcher(
   dim3 block(NUM_THREADS);
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
   switch (head_size) {
-    case 32:
-      LAUNCH_ATTENTION_KERNEL(T, 32, BLOCK_SIZE, NUM_THREADS);
-      break;
+    // NOTE(woosuk): To reduce the compilation time, we omitted head sizes
+    // 32, 160, 192, 256.
+    // case 32:
+    //   LAUNCH_ATTENTION_KERNEL(T, 32, BLOCK_SIZE, NUM_THREADS);
+    //   break;
     case 64:
       LAUNCH_ATTENTION_KERNEL(T, 64, BLOCK_SIZE, NUM_THREADS);
       break;
@@ -385,15 +387,15 @@ void single_query_cached_kv_attention_launcher(
     case 128:
       LAUNCH_ATTENTION_KERNEL(T, 128, BLOCK_SIZE, NUM_THREADS);
       break;
-    case 160:
-      LAUNCH_ATTENTION_KERNEL(T, 160, BLOCK_SIZE, NUM_THREADS);
-      break;
-    case 192:
-      LAUNCH_ATTENTION_KERNEL(T, 192, BLOCK_SIZE, NUM_THREADS);
-      break;
-    case 256:
-      LAUNCH_ATTENTION_KERNEL(T, 256, BLOCK_SIZE, NUM_THREADS);
-      break;
+    // case 160:
+    //   LAUNCH_ATTENTION_KERNEL(T, 160, BLOCK_SIZE, NUM_THREADS);
+    //   break;
+    // case 192:
+    //   LAUNCH_ATTENTION_KERNEL(T, 192, BLOCK_SIZE, NUM_THREADS);
+    //   break;
+    // case 256:
+    //   LAUNCH_ATTENTION_KERNEL(T, 256, BLOCK_SIZE, NUM_THREADS);
+    //   break;
     default:
       TORCH_CHECK(false, "Unsupported head size: ", head_size);
       break;
@@ -411,17 +413,19 @@ void single_query_cached_kv_attention_launcher(
     context_lens, \
     max_context_len);

+// NOTE(woosuk): To reduce the compilation time, we omitted block sizes
+// 1, 2, 4, 64, 128, 256.
 #define CALL_KERNEL_LAUNCHER_BLOCK_SIZE(T) \
   switch (block_size) { \
-    case 1: \
-      CALL_KERNEL_LAUNCHER(T, 1); \
-      break; \
-    case 2: \
-      CALL_KERNEL_LAUNCHER(T, 2); \
-      break; \
-    case 4: \
-      CALL_KERNEL_LAUNCHER(T, 4); \
-      break; \
+    /* case 1: */ \
+    /* CALL_KERNEL_LAUNCHER(T, 1); */ \
+    /* break; */ \
+    /* case 2: */ \
+    /* CALL_KERNEL_LAUNCHER(T, 2); */ \
+    /* break; */ \
+    /* case 4: */ \
+    /* CALL_KERNEL_LAUNCHER(T, 4); */ \
+    /* break; */ \
     case 8: \
       CALL_KERNEL_LAUNCHER(T, 8); \
       break; \
@@ -431,15 +435,15 @@ void single_query_cached_kv_attention_launcher(
     case 32: \
       CALL_KERNEL_LAUNCHER(T, 32); \
       break; \
-    case 64: \
-      CALL_KERNEL_LAUNCHER(T, 64); \
-      break; \
-    case 128: \
-      CALL_KERNEL_LAUNCHER(T, 128); \
-      break; \
-    case 256: \
-      CALL_KERNEL_LAUNCHER(T, 256); \
-      break; \
+    /* case 64: */ \
+    /* CALL_KERNEL_LAUNCHER(T, 64); */ \
+    /* break; */ \
+    /* case 128: */ \
+    /* CALL_KERNEL_LAUNCHER(T, 128); */ \
+    /* break; */ \
+    /* case 256: */ \
+    /* CALL_KERNEL_LAUNCHER(T, 256); */ \
+    /* break; */ \
     default: \
       TORCH_CHECK(false, "Unsupported block size: ", block_size); \
       break; \
@@ -455,8 +459,9 @@ void single_query_cached_kv_attention(
   torch::Tensor& context_lens,    // [num_seqs]
   int block_size,
   int max_context_len) {
-  // TODO(woosuk): Support FP32.
-  if (query.dtype() == at::ScalarType::Half) {
+  if (query.dtype() == at::ScalarType::Float) {
+    CALL_KERNEL_LAUNCHER_BLOCK_SIZE(float);
+  } else if (query.dtype() == at::ScalarType::Half) {
     CALL_KERNEL_LAUNCHER_BLOCK_SIZE(uint16_t);
   } else if (query.dtype() == at::ScalarType::BFloat16) {
     CALL_KERNEL_LAUNCHER_BLOCK_SIZE(__nv_bfloat16);
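
As a reading aid for the C++ dispatch above, here is a small Python-side sketch of the mapping it now performs. It is illustrative only; the real dispatch lives in the C++ entry point, and the mapping name and helper are not project code:

    import torch

    # torch dtype -> template type used by the CUDA kernel launcher.
    _DTYPE_TO_KERNEL_TYPE = {
        torch.float32: "float",          # newly supported by this commit
        torch.float16: "uint16_t",       # FP16 values are passed as raw 16-bit words
        torch.bfloat16: "__nv_bfloat16",
    }

    def kernel_type_for(dtype: torch.dtype) -> str:
        try:
            return _DTYPE_TO_KERNEL_TYPE[dtype]
        except KeyError:
            raise ValueError(f"Unsupported dtype: {dtype}")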


@@ -18,8 +18,11 @@ CacheFlow can run on systems that meet the following requirements:

 .. code-block:: console

+    $ # Pull the Docker image with CUDA 11.8.
     $ docker run --gpus all -it --rm --shm-size=8g nvcr.io/nvidia/pytorch:22.12-py3

+Inside the Docker container, please execute :code:`pip uninstall torch` before installing CacheFlow.
+
 Install with pip
 ----------------


@@ -66,6 +66,11 @@ if 90 in compute_capabilities and nvcc_cuda_version < Version("11.8"):
     raise RuntimeError(
         "CUDA 11.8 or higher is required for GPUs with compute capability 9.0.")

+# Use NVCC threads to parallelize the build.
+if nvcc_cuda_version >= Version("11.2"):
+    num_threads = min(os.cpu_count(), 8)
+    NVCC_FLAGS += ["--threads", str(num_threads)]
+
 ext_modules = []

 # Cache operations.


@@ -270,9 +270,9 @@ def run_multi_query_kv_attention(
 def test_single_query_cached_kv_attention() -> None:
     torch.random.manual_seed(TEST_SEED)
     torch.cuda.manual_seed(TEST_SEED)
-    for dtype in [torch.half, torch.bfloat16]:
-        for block_size in [8, 16, 32, 64]:
-            for head_size in [32, 64, 80, 96, 128, 160, 192, 256]:
+    for dtype in [torch.half, torch.bfloat16, torch.float]:
+        for block_size in [8, 16, 32]:
+            for head_size in [64, 80, 96, 128]:
                 print(f'Testing single_query_cached_kv_attention with '
                       f'dtype={dtype}, block_size={block_size}, '
                       f'head_size={head_size}')
@@ -289,8 +289,8 @@ def test_single_query_cached_kv_attention() -> None:
 def test_multi_query_kv_attention() -> None:
     torch.random.manual_seed(TEST_SEED)
     torch.cuda.manual_seed(TEST_SEED)
-    for dtype in [torch.half, torch.bfloat16]:
-        for head_size in [32, 64, 80, 96, 128, 160, 192, 256]:
+    for dtype in [torch.half, torch.bfloat16, torch.float]:
+        for head_size in [64, 80, 96, 128]:
         print(f'Testing multi_query_kv_attention with dtype={dtype}, '
               f'head_size={head_size}')
         run_multi_query_kv_attention(