[Kernel] Re-tune Mixtral MoE configurations for FP8 on H100 (#5238)
This commit is contained in:
parent
eb8fcd2666
commit
51a08e7d8f
@ -255,7 +255,8 @@ def main(args: argparse.Namespace):
|
|||||||
|
|
||||||
if args.batch_size is None:
|
if args.batch_size is None:
|
||||||
batch_sizes = [
|
batch_sizes = [
|
||||||
1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 1536, 2048, 3072, 4096
|
1, 2, 4, 8, 16, 24, 32, 48, 64, 96, 128, 256, 512, 1024, 1536,
|
||||||
|
2048, 3072, 4096
|
||||||
]
|
]
|
||||||
else:
|
else:
|
||||||
batch_sizes = [args.batch_size]
|
batch_sizes = [args.batch_size]
|
||||||
|
@ -1,113 +1,113 @@
|
|||||||
{
|
{
|
||||||
"1": {
|
"1": {
|
||||||
"BLOCK_SIZE_M": 16,
|
"BLOCK_SIZE_M": 64,
|
||||||
"BLOCK_SIZE_N": 32,
|
"BLOCK_SIZE_N": 128,
|
||||||
"BLOCK_SIZE_K": 64,
|
"BLOCK_SIZE_K": 128,
|
||||||
"GROUP_SIZE_M": 1,
|
"GROUP_SIZE_M": 32,
|
||||||
"num_warps": 8,
|
"num_warps": 4,
|
||||||
"num_stages": 4
|
"num_stages": 5
|
||||||
},
|
},
|
||||||
"2": {
|
"2": {
|
||||||
"BLOCK_SIZE_M": 16,
|
"BLOCK_SIZE_M": 64,
|
||||||
"BLOCK_SIZE_N": 256,
|
"BLOCK_SIZE_N": 128,
|
||||||
"BLOCK_SIZE_K": 64,
|
"BLOCK_SIZE_K": 128,
|
||||||
"GROUP_SIZE_M": 1,
|
"GROUP_SIZE_M": 1,
|
||||||
"num_warps": 8,
|
"num_warps": 4,
|
||||||
"num_stages": 3
|
"num_stages": 4
|
||||||
},
|
},
|
||||||
"4": {
|
"4": {
|
||||||
"BLOCK_SIZE_M": 256,
|
"BLOCK_SIZE_M": 64,
|
||||||
|
"BLOCK_SIZE_N": 128,
|
||||||
|
"BLOCK_SIZE_K": 256,
|
||||||
|
"GROUP_SIZE_M": 16,
|
||||||
|
"num_warps": 4,
|
||||||
|
"num_stages": 4
|
||||||
|
},
|
||||||
|
"8": {
|
||||||
|
"BLOCK_SIZE_M": 64,
|
||||||
|
"BLOCK_SIZE_N": 128,
|
||||||
|
"BLOCK_SIZE_K": 128,
|
||||||
|
"GROUP_SIZE_M": 32,
|
||||||
|
"num_warps": 4,
|
||||||
|
"num_stages": 4
|
||||||
|
},
|
||||||
|
"16": {
|
||||||
|
"BLOCK_SIZE_M": 64,
|
||||||
"BLOCK_SIZE_N": 64,
|
"BLOCK_SIZE_N": 64,
|
||||||
"BLOCK_SIZE_K": 128,
|
"BLOCK_SIZE_K": 128,
|
||||||
"GROUP_SIZE_M": 1,
|
"GROUP_SIZE_M": 1,
|
||||||
"num_warps": 4,
|
"num_warps": 4,
|
||||||
"num_stages": 2
|
"num_stages": 3
|
||||||
},
|
|
||||||
"8": {
|
|
||||||
"BLOCK_SIZE_M": 64,
|
|
||||||
"BLOCK_SIZE_N": 256,
|
|
||||||
"BLOCK_SIZE_K": 64,
|
|
||||||
"GROUP_SIZE_M": 32,
|
|
||||||
"num_warps": 8,
|
|
||||||
"num_stages": 2
|
|
||||||
},
|
|
||||||
"16": {
|
|
||||||
"BLOCK_SIZE_M": 256,
|
|
||||||
"BLOCK_SIZE_N": 128,
|
|
||||||
"BLOCK_SIZE_K": 64,
|
|
||||||
"GROUP_SIZE_M": 1,
|
|
||||||
"num_warps": 8,
|
|
||||||
"num_stages": 5
|
|
||||||
},
|
},
|
||||||
"24": {
|
"24": {
|
||||||
"BLOCK_SIZE_M": 64,
|
"BLOCK_SIZE_M": 64,
|
||||||
"BLOCK_SIZE_N": 256,
|
"BLOCK_SIZE_N": 64,
|
||||||
"BLOCK_SIZE_K": 128,
|
"BLOCK_SIZE_K": 128,
|
||||||
"GROUP_SIZE_M": 64,
|
"GROUP_SIZE_M": 64,
|
||||||
"num_warps": 4,
|
"num_warps": 4,
|
||||||
"num_stages": 4
|
"num_stages": 3
|
||||||
},
|
},
|
||||||
"32": {
|
"32": {
|
||||||
"BLOCK_SIZE_M": 64,
|
"BLOCK_SIZE_M": 64,
|
||||||
"BLOCK_SIZE_N": 128,
|
"BLOCK_SIZE_N": 64,
|
||||||
"BLOCK_SIZE_K": 256,
|
"BLOCK_SIZE_K": 128,
|
||||||
"GROUP_SIZE_M": 1,
|
"GROUP_SIZE_M": 64,
|
||||||
"num_warps": 4,
|
"num_warps": 4,
|
||||||
"num_stages": 4
|
"num_stages": 3
|
||||||
},
|
},
|
||||||
"48": {
|
"48": {
|
||||||
"BLOCK_SIZE_M": 64,
|
"BLOCK_SIZE_M": 64,
|
||||||
"BLOCK_SIZE_N": 256,
|
"BLOCK_SIZE_N": 64,
|
||||||
"BLOCK_SIZE_K": 64,
|
"BLOCK_SIZE_K": 128,
|
||||||
"GROUP_SIZE_M": 64,
|
"GROUP_SIZE_M": 16,
|
||||||
"num_warps": 4,
|
"num_warps": 4,
|
||||||
"num_stages": 4
|
"num_stages": 3
|
||||||
},
|
},
|
||||||
"64": {
|
"64": {
|
||||||
"BLOCK_SIZE_M": 64,
|
"BLOCK_SIZE_M": 64,
|
||||||
"BLOCK_SIZE_N": 128,
|
"BLOCK_SIZE_N": 64,
|
||||||
"BLOCK_SIZE_K": 256,
|
|
||||||
"GROUP_SIZE_M": 1,
|
|
||||||
"num_warps": 4,
|
|
||||||
"num_stages": 2
|
|
||||||
},
|
|
||||||
"96": {
|
|
||||||
"BLOCK_SIZE_M": 128,
|
|
||||||
"BLOCK_SIZE_N": 32,
|
|
||||||
"BLOCK_SIZE_K": 128,
|
"BLOCK_SIZE_K": 128,
|
||||||
"GROUP_SIZE_M": 1,
|
"GROUP_SIZE_M": 1,
|
||||||
"num_warps": 4,
|
"num_warps": 4,
|
||||||
"num_stages": 5
|
"num_stages": 3
|
||||||
|
},
|
||||||
|
"96": {
|
||||||
|
"BLOCK_SIZE_M": 64,
|
||||||
|
"BLOCK_SIZE_N": 64,
|
||||||
|
"BLOCK_SIZE_K": 128,
|
||||||
|
"GROUP_SIZE_M": 1,
|
||||||
|
"num_warps": 4,
|
||||||
|
"num_stages": 3
|
||||||
},
|
},
|
||||||
"128": {
|
"128": {
|
||||||
"BLOCK_SIZE_M": 64,
|
"BLOCK_SIZE_M": 64,
|
||||||
"BLOCK_SIZE_N": 64,
|
"BLOCK_SIZE_N": 64,
|
||||||
"BLOCK_SIZE_K": 256,
|
"BLOCK_SIZE_K": 128,
|
||||||
"GROUP_SIZE_M": 32,
|
"GROUP_SIZE_M": 1,
|
||||||
"num_warps": 8,
|
"num_warps": 4,
|
||||||
"num_stages": 5
|
"num_stages": 3
|
||||||
},
|
},
|
||||||
"256": {
|
"256": {
|
||||||
"BLOCK_SIZE_M": 64,
|
"BLOCK_SIZE_M": 64,
|
||||||
"BLOCK_SIZE_N": 128,
|
"BLOCK_SIZE_N": 128,
|
||||||
"BLOCK_SIZE_K": 256,
|
"BLOCK_SIZE_K": 128,
|
||||||
"GROUP_SIZE_M": 1,
|
"GROUP_SIZE_M": 16,
|
||||||
"num_warps": 8,
|
"num_warps": 4,
|
||||||
"num_stages": 3
|
"num_stages": 3
|
||||||
},
|
},
|
||||||
"512": {
|
"512": {
|
||||||
"BLOCK_SIZE_M": 128,
|
"BLOCK_SIZE_M": 128,
|
||||||
"BLOCK_SIZE_N": 64,
|
"BLOCK_SIZE_N": 256,
|
||||||
"BLOCK_SIZE_K": 128,
|
"BLOCK_SIZE_K": 128,
|
||||||
"GROUP_SIZE_M": 16,
|
"GROUP_SIZE_M": 16,
|
||||||
"num_warps": 8,
|
"num_warps": 8,
|
||||||
"num_stages": 3
|
"num_stages": 4
|
||||||
},
|
},
|
||||||
"1024": {
|
"1024": {
|
||||||
"BLOCK_SIZE_M": 128,
|
"BLOCK_SIZE_M": 128,
|
||||||
"BLOCK_SIZE_N": 256,
|
"BLOCK_SIZE_N": 256,
|
||||||
"BLOCK_SIZE_K": 128,
|
"BLOCK_SIZE_K": 128,
|
||||||
"GROUP_SIZE_M": 64,
|
"GROUP_SIZE_M": 32,
|
||||||
"num_warps": 8,
|
"num_warps": 8,
|
||||||
"num_stages": 4
|
"num_stages": 4
|
||||||
},
|
},
|
||||||
@ -115,7 +115,7 @@
|
|||||||
"BLOCK_SIZE_M": 128,
|
"BLOCK_SIZE_M": 128,
|
||||||
"BLOCK_SIZE_N": 256,
|
"BLOCK_SIZE_N": 256,
|
||||||
"BLOCK_SIZE_K": 128,
|
"BLOCK_SIZE_K": 128,
|
||||||
"GROUP_SIZE_M": 64,
|
"GROUP_SIZE_M": 16,
|
||||||
"num_warps": 8,
|
"num_warps": 8,
|
||||||
"num_stages": 4
|
"num_stages": 4
|
||||||
},
|
},
|
||||||
|
@ -1,53 +1,5 @@
|
|||||||
{
|
{
|
||||||
"1": {
|
"1": {
|
||||||
"BLOCK_SIZE_M": 16,
|
|
||||||
"BLOCK_SIZE_N": 32,
|
|
||||||
"BLOCK_SIZE_K": 64,
|
|
||||||
"GROUP_SIZE_M": 1,
|
|
||||||
"num_warps": 4,
|
|
||||||
"num_stages": 5
|
|
||||||
},
|
|
||||||
"2": {
|
|
||||||
"BLOCK_SIZE_M": 64,
|
|
||||||
"BLOCK_SIZE_N": 64,
|
|
||||||
"BLOCK_SIZE_K": 128,
|
|
||||||
"GROUP_SIZE_M": 32,
|
|
||||||
"num_warps": 8,
|
|
||||||
"num_stages": 4
|
|
||||||
},
|
|
||||||
"4": {
|
|
||||||
"BLOCK_SIZE_M": 64,
|
|
||||||
"BLOCK_SIZE_N": 128,
|
|
||||||
"BLOCK_SIZE_K": 256,
|
|
||||||
"GROUP_SIZE_M": 1,
|
|
||||||
"num_warps": 4,
|
|
||||||
"num_stages": 2
|
|
||||||
},
|
|
||||||
"8": {
|
|
||||||
"BLOCK_SIZE_M": 128,
|
|
||||||
"BLOCK_SIZE_N": 128,
|
|
||||||
"BLOCK_SIZE_K": 128,
|
|
||||||
"GROUP_SIZE_M": 1,
|
|
||||||
"num_warps": 4,
|
|
||||||
"num_stages": 3
|
|
||||||
},
|
|
||||||
"16": {
|
|
||||||
"BLOCK_SIZE_M": 64,
|
|
||||||
"BLOCK_SIZE_N": 128,
|
|
||||||
"BLOCK_SIZE_K": 256,
|
|
||||||
"GROUP_SIZE_M": 1,
|
|
||||||
"num_warps": 4,
|
|
||||||
"num_stages": 4
|
|
||||||
},
|
|
||||||
"24": {
|
|
||||||
"BLOCK_SIZE_M": 64,
|
|
||||||
"BLOCK_SIZE_N": 128,
|
|
||||||
"BLOCK_SIZE_K": 128,
|
|
||||||
"GROUP_SIZE_M": 1,
|
|
||||||
"num_warps": 4,
|
|
||||||
"num_stages": 5
|
|
||||||
},
|
|
||||||
"32": {
|
|
||||||
"BLOCK_SIZE_M": 64,
|
"BLOCK_SIZE_M": 64,
|
||||||
"BLOCK_SIZE_N": 128,
|
"BLOCK_SIZE_N": 128,
|
||||||
"BLOCK_SIZE_K": 256,
|
"BLOCK_SIZE_K": 256,
|
||||||
@ -55,7 +7,39 @@
|
|||||||
"num_warps": 4,
|
"num_warps": 4,
|
||||||
"num_stages": 4
|
"num_stages": 4
|
||||||
},
|
},
|
||||||
"48": {
|
"2": {
|
||||||
|
"BLOCK_SIZE_M": 64,
|
||||||
|
"BLOCK_SIZE_N": 64,
|
||||||
|
"BLOCK_SIZE_K": 256,
|
||||||
|
"GROUP_SIZE_M": 1,
|
||||||
|
"num_warps": 4,
|
||||||
|
"num_stages": 5
|
||||||
|
},
|
||||||
|
"4": {
|
||||||
|
"BLOCK_SIZE_M": 64,
|
||||||
|
"BLOCK_SIZE_N": 64,
|
||||||
|
"BLOCK_SIZE_K": 256,
|
||||||
|
"GROUP_SIZE_M": 1,
|
||||||
|
"num_warps": 4,
|
||||||
|
"num_stages": 5
|
||||||
|
},
|
||||||
|
"8": {
|
||||||
|
"BLOCK_SIZE_M": 64,
|
||||||
|
"BLOCK_SIZE_N": 256,
|
||||||
|
"BLOCK_SIZE_K": 128,
|
||||||
|
"GROUP_SIZE_M": 32,
|
||||||
|
"num_warps": 4,
|
||||||
|
"num_stages": 4
|
||||||
|
},
|
||||||
|
"16": {
|
||||||
|
"BLOCK_SIZE_M": 64,
|
||||||
|
"BLOCK_SIZE_N": 128,
|
||||||
|
"BLOCK_SIZE_K": 128,
|
||||||
|
"GROUP_SIZE_M": 32,
|
||||||
|
"num_warps": 4,
|
||||||
|
"num_stages": 5
|
||||||
|
},
|
||||||
|
"24": {
|
||||||
"BLOCK_SIZE_M": 64,
|
"BLOCK_SIZE_M": 64,
|
||||||
"BLOCK_SIZE_N": 64,
|
"BLOCK_SIZE_N": 64,
|
||||||
"BLOCK_SIZE_K": 256,
|
"BLOCK_SIZE_K": 256,
|
||||||
@ -63,6 +47,22 @@
|
|||||||
"num_warps": 4,
|
"num_warps": 4,
|
||||||
"num_stages": 3
|
"num_stages": 3
|
||||||
},
|
},
|
||||||
|
"32": {
|
||||||
|
"BLOCK_SIZE_M": 64,
|
||||||
|
"BLOCK_SIZE_N": 128,
|
||||||
|
"BLOCK_SIZE_K": 256,
|
||||||
|
"GROUP_SIZE_M": 1,
|
||||||
|
"num_warps": 4,
|
||||||
|
"num_stages": 4
|
||||||
|
},
|
||||||
|
"48": {
|
||||||
|
"BLOCK_SIZE_M": 64,
|
||||||
|
"BLOCK_SIZE_N": 128,
|
||||||
|
"BLOCK_SIZE_K": 256,
|
||||||
|
"GROUP_SIZE_M": 1,
|
||||||
|
"num_warps": 4,
|
||||||
|
"num_stages": 4
|
||||||
|
},
|
||||||
"64": {
|
"64": {
|
||||||
"BLOCK_SIZE_M": 64,
|
"BLOCK_SIZE_M": 64,
|
||||||
"BLOCK_SIZE_N": 128,
|
"BLOCK_SIZE_N": 128,
|
||||||
@ -81,19 +81,19 @@
|
|||||||
},
|
},
|
||||||
"128": {
|
"128": {
|
||||||
"BLOCK_SIZE_M": 64,
|
"BLOCK_SIZE_M": 64,
|
||||||
"BLOCK_SIZE_N": 64,
|
"BLOCK_SIZE_N": 128,
|
||||||
"BLOCK_SIZE_K": 256,
|
"BLOCK_SIZE_K": 256,
|
||||||
"GROUP_SIZE_M": 1,
|
"GROUP_SIZE_M": 1,
|
||||||
"num_warps": 4,
|
"num_warps": 4,
|
||||||
"num_stages": 3
|
"num_stages": 4
|
||||||
},
|
},
|
||||||
"256": {
|
"256": {
|
||||||
"BLOCK_SIZE_M": 128,
|
"BLOCK_SIZE_M": 64,
|
||||||
"BLOCK_SIZE_N": 128,
|
"BLOCK_SIZE_N": 128,
|
||||||
"BLOCK_SIZE_K": 128,
|
"BLOCK_SIZE_K": 128,
|
||||||
"GROUP_SIZE_M": 1,
|
"GROUP_SIZE_M": 64,
|
||||||
"num_warps": 8,
|
"num_warps": 4,
|
||||||
"num_stages": 5
|
"num_stages": 3
|
||||||
},
|
},
|
||||||
"512": {
|
"512": {
|
||||||
"BLOCK_SIZE_M": 128,
|
"BLOCK_SIZE_M": 128,
|
||||||
@ -107,7 +107,7 @@
|
|||||||
"BLOCK_SIZE_M": 128,
|
"BLOCK_SIZE_M": 128,
|
||||||
"BLOCK_SIZE_N": 256,
|
"BLOCK_SIZE_N": 256,
|
||||||
"BLOCK_SIZE_K": 128,
|
"BLOCK_SIZE_K": 128,
|
||||||
"GROUP_SIZE_M": 64,
|
"GROUP_SIZE_M": 32,
|
||||||
"num_warps": 8,
|
"num_warps": 8,
|
||||||
"num_stages": 4
|
"num_stages": 4
|
||||||
},
|
},
|
||||||
|
Loading…
x
Reference in New Issue
Block a user