[Model] Support MAP-NEO model (#5081)
Co-authored-by: Zhuohan Li <zhuohan123@gmail.com>
Parent: 533c217792
Commit: a22dea54d3

This change adds 192 to the set of supported attention head sizes across the CUDA and CPU paged attention kernel launchers, the kernel benchmarks, the kernel tests, and PagedAttention.get_supported_head_sizes(), which is what enables running MAP-NEO.
@@ -170,7 +170,7 @@ if __name__ == '__main__':
     parser.add_argument("--num-kv-heads", type=int, default=8)
     parser.add_argument("--head-size",
                         type=int,
-                        choices=[64, 80, 96, 112, 128, 256],
+                        choices=[64, 80, 96, 112, 128, 192, 256],
                         default=128)
     parser.add_argument("--block-size", type=int, choices=[16, 32], default=16)
     parser.add_argument("--use-alibi", action="store_true")
@@ -93,7 +93,7 @@ if __name__ == '__main__':
     parser.add_argument("--num-heads", type=int, default=8)
     parser.add_argument("--head-size",
                         type=int,
-                        choices=[64, 80, 96, 112, 128, 256],
+                        choices=[64, 80, 96, 112, 128, 192, 256],
                         default=128)
     parser.add_argument("--rotary-dim", type=int, choices=[16, 32], default=32)
    parser.add_argument("--dtype",
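For a quick check that the widened choice list behaves as intended, here is a self-contained sketch; only the --head-size argument mirrors the benchmark parsers above, and the rest of the script is illustrative:

import argparse

# Minimal stand-in for the benchmark argument parsers touched above;
# only --head-size is reproduced, the surrounding script is illustrative.
parser = argparse.ArgumentParser(description="head-size choices sketch")
parser.add_argument("--head-size",
                    type=int,
                    choices=[64, 80, 96, 112, 128, 192, 256],
                    default=128)

# 192 is now a valid choice; any value outside the list still exits with an error.
args = parser.parse_args(["--head-size", "192"])
print(args.head_size)  # -> 192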
@@ -754,6 +754,9 @@ void paged_attention_v1_launcher(
     case 128:
       LAUNCH_PAGED_ATTENTION_V1(128);
       break;
+    case 192:
+      LAUNCH_PAGED_ATTENTION_V1(192);
+      break;
     case 256:
       LAUNCH_PAGED_ATTENTION_V1(256);
       break;
@@ -911,6 +914,9 @@ void paged_attention_v2_launcher(
     case 128:
       LAUNCH_PAGED_ATTENTION_V2(128);
       break;
+    case 192:
+      LAUNCH_PAGED_ATTENTION_V2(192);
+      break;
     case 256:
       LAUNCH_PAGED_ATTENTION_V2(256);
       break;
@@ -390,6 +390,9 @@ void paged_attention_v1_impl_launcher(
     case 128:
       LAUNCH_V1_ATTENTION_KERNEL(T, 128, BLOCK_SIZE);
       break;
+    case 192:
+      LAUNCH_V1_ATTENTION_KERNEL(T, 192, BLOCK_SIZE);
+      break;
     case 256:
       LAUNCH_V1_ATTENTION_KERNEL(T, 256, BLOCK_SIZE);
       break;
@@ -703,6 +706,9 @@ void paged_attention_v2_impl_launcher(
     case 128:
       LAUNCH_V2_ATTENTION_KERNEL(T, 128, BLOCK_SIZE);
       break;
+    case 192:
+      LAUNCH_V2_ATTENTION_KERNEL(T, 192, BLOCK_SIZE);
+      break;
     case 256:
       LAUNCH_V2_ATTENTION_KERNEL(T, 256, BLOCK_SIZE);
       break;
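The four launcher hunks above follow the same pattern: a switch over the runtime head size expands a launch macro that instantiates a kernel templated on that size, so supporting 192 means adding one more case to each launcher. A rough Python analogue of that dispatch, with placeholder launch logic rather than the actual kernel code:

# Python analogue of the head-size switch in the paged attention launchers.
# The launch function below is a placeholder; the real code expands the
# LAUNCH_PAGED_ATTENTION_V1/V2 macros for each case.
SUPPORTED_HEAD_SIZES = (64, 80, 96, 112, 128, 192, 256)

def launch_paged_attention(head_size: int) -> None:
    if head_size not in SUPPORTED_HEAD_SIZES:
        raise ValueError(f"Unsupported head size: {head_size}")
    # In the CUDA/CPU sources, each case instantiates a kernel specialized
    # for the head size; adding 192 required a new case in every launcher.
    print(f"launching paged attention kernel specialized for head size {head_size}")

launch_paged_attention(192)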
@@ -28,7 +28,7 @@ NUM_HEADS = [(40, 40), (64, 8)]  # Arbitrary values for testing

 # FlashAttention forward only supports head dimension at most 128
 # https://github.com/ROCmSoftwarePlatform/flash-attention/blob/3d2b6f5d037782cc2c906909a46fb7e2e1b48b25/csrc/flash_attn_rocm/flash_api.cpp#L62
-HEAD_SIZES = [64, 80, 96, 112, 128, 256
+HEAD_SIZES = [64, 80, 96, 112, 128, 192, 256
               ] if not is_hip() else [64, 80, 96, 112, 128]

 BLOCK_SIZES = [16, 32]
@@ -11,7 +11,7 @@ DTYPES = [torch.half, torch.bfloat16, torch.float]
 NUM_TOKENS = [42]  # Arbitrary values for testing
 NUM_LAYERS = [1]  # Arbitrary values for testing
 NUM_HEADS = [8]  # Arbitrary values for testing
-HEAD_SIZES = [64, 80, 96, 112, 128, 256]
+HEAD_SIZES = [64, 80, 96, 112, 128, 192, 256]
 BLOCK_SIZES = [8, 16, 32]

 # Arbitrary values for testing
@@ -10,7 +10,7 @@ from .allclose_default import get_default_atol, get_default_rtol

 IS_NEOX_STYLE = [True, False]
 DTYPES = [torch.half, torch.bfloat16, torch.float]
-HEAD_SIZES = [64, 80, 96, 112, 128, 256]
+HEAD_SIZES = [64, 80, 96, 112, 128, 192, 256]
 ROTARY_DIMS = [None, 32]  # None means rotary dim == head size
 NUM_HEADS = [7, 17]  # Arbitrary values for testing
 BATCH_SIZES = [1, 5]  # Arbitrary values for testing
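The three test hunks above only touch module-level HEAD_SIZES lists; those lists typically drive pytest parametrization, so adding 192 extends the existing kernel tests to the new size. A minimal sketch of that pattern (the test body is illustrative, not the repository's actual test):

import pytest

HEAD_SIZES = [64, 80, 96, 112, 128, 192, 256]

@pytest.mark.parametrize("head_size", HEAD_SIZES)
def test_head_size_is_supported(head_size: int) -> None:
    # Placeholder assertion: the real tests run the attention, cache, and
    # rotary-embedding kernels for each head size and compare the outputs
    # against a reference implementation.
    assert head_size % 16 == 0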
@@ -31,7 +31,7 @@ class PagedAttention:

     @staticmethod
     def get_supported_head_sizes() -> List[int]:
-        return [64, 80, 96, 112, 128, 256]
+        return [64, 80, 96, 112, 128, 192, 256]

     @staticmethod
     def get_kv_cache_shape(
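get_supported_head_sizes() is the list a caller can consult before relying on paged attention for a given model. A hedged sketch of such a check, with illustrative config numbers chosen to yield the newly supported head size of 192 (the helper and the values below are not taken from the vLLM code base):

from typing import List

def get_supported_head_sizes() -> List[int]:
    # Mirrors PagedAttention.get_supported_head_sizes() after this commit.
    return [64, 80, 96, 112, 128, 192, 256]

# Illustrative config values (not taken from any real model config):
hidden_size = 3072
num_attention_heads = 16
head_size = hidden_size // num_attention_heads  # -> 192

if head_size not in get_supported_head_sizes():
    raise ValueError(
        f"head_size {head_size} is not supported by PagedAttention; "
        f"supported sizes: {get_supported_head_sizes()}")
print(f"head_size {head_size} is supported")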