2024-08-20 09:09:33 -04:00
|
|
|
import itertools
import math
import os
import shutil
from collections.abc import Iterable
from copy import deepcopy
from dataclasses import dataclass, fields
from functools import reduce
from typing import Dict, List, Optional, Tuple, Union

import jinja2

# yapf conflicts with isort for this block
# yapf: disable
from vllm_cutlass_library_extension import (DataType, EpilogueScheduleTag,
                                            EpilogueScheduleType,
                                            MixedInputKernelScheduleType,
                                            TileSchedulerTag,
                                            TileSchedulerType, VLLMDataType,
                                            VLLMDataTypeNames,
                                            VLLMDataTypeSize, VLLMDataTypeTag,
                                            VLLMDataTypeTorchDataTypeTag,
                                            VLLMDataTypeVLLMScalarTypeTag,
                                            VLLMKernelScheduleTag)
# yapf: enable
|
|
|
|
|
|
|
|
#
# Generator templating
#

# Jinja template for the generated machete_mm dispatch .cu file: declares the
# extern impl_* entry points and emits runtime dispatch over types/schedules.
# Rendered with `impl_configs` (List[ImplConfig]) by create_sources().
DISPATCH_TEMPLATE = """
#include "../machete_mm_launcher.cuh"

namespace machete {

{% for impl_config in impl_configs %}
{% set type_sig = gen_type_sig(impl_config.types) -%}
{% for s in impl_config.schedules %}
extern torch::Tensor impl_{{type_sig}}_sch_{{gen_sch_sig(s)}}(MMArgs);
{%- endfor %}

torch::Tensor mm_dispatch_{{type_sig}}(MMArgs args) {
  [[maybe_unused]] auto M = args.A.size(0);
  [[maybe_unused]] auto N = args.B.size(1);
  [[maybe_unused]] auto K = args.A.size(1);

  if (!args.maybe_schedule) {
    {%- for cond, s in impl_config.heuristic %}
    {%if cond is not none%}if ({{cond}})
    {%- else %}else
    {%- endif %}
        return impl_{{type_sig}}_sch_{{ gen_sch_sig(s) }}(args);{% endfor %}
  }

  {%- for s in impl_config.schedules %}
  if (*args.maybe_schedule == "{{ gen_sch_sig(s) }}")
    return impl_{{type_sig}}_sch_{{ gen_sch_sig(s) }}(args);
  {%- endfor %}
  TORCH_CHECK_NOT_IMPLEMENTED(false, "machete_gemm(..) is not implemented for "
                              "schedule = ", *args.maybe_schedule);
}
{%- endfor %}


static inline std::optional<at::ScalarType> maybe_scalartype(
    std::optional<at::Tensor> const& t) {
  if (!t) {
    return std::nullopt;
  } else {
    return t->scalar_type();
  };
}

torch::Tensor mm_dispatch(MMArgs args) {
  auto out_type = args.maybe_out_type.value_or(args.A.scalar_type());
  auto a_type = args.A.scalar_type();
  auto maybe_g_scales_type = maybe_scalartype(args.maybe_group_scales);
  auto maybe_g_zeros_type = maybe_scalartype(args.maybe_group_zeros);
  auto maybe_ch_scales_type = maybe_scalartype(args.maybe_channel_scales);
  auto maybe_tok_scales_type = maybe_scalartype(args.maybe_token_scales);

  {% for impl_config in impl_configs %}
  {% set t = impl_config.types -%}
  {% set type_sig = gen_type_sig(t) -%}
  if (args.b_type == {{VLLMScalarTypeTag[t.b]}}
      && a_type == {{TorchTypeTag[t.a]}}
      && out_type == {{TorchTypeTag[t.out]}}
      && {%if t.b_group_scale != void -%}
      maybe_g_scales_type == {{TorchTypeTag[t.b_group_scale]}}
      {%- else %}!maybe_g_scales_type{%endif%}
      && {%if t.b_group_zeropoint != void -%}
      maybe_g_zeros_type == {{TorchTypeTag[t.b_group_zeropoint]}}
      {%- else %}!maybe_g_zeros_type{%endif%}
      && {%if t.b_channel_scale != void -%}
      maybe_ch_scales_type == {{TorchTypeTag[t.b_channel_scale]}}
      {%- else %}!maybe_ch_scales_type{%endif%}
      && {%if t.a_token_scale != void -%}
      maybe_tok_scales_type == {{TorchTypeTag[t.a_token_scale]}}
      {%- else %}!maybe_tok_scales_type{%endif%}
  ) {
      return mm_dispatch_{{type_sig}}(args);
  }
  {%- endfor %}

  TORCH_CHECK_NOT_IMPLEMENTED(
    false, "machete_mm(..) is not implemented for "
    "a_type=", args.A.scalar_type(),
    ", b_type=", args.b_type.str(),
    ", out_type=", out_type,
    ", with_group_scale_type=", maybe_g_scales_type
      ? toString(*maybe_g_scales_type) : "None",
    ", with_group_zeropoint_type=", maybe_g_zeros_type
      ? toString(*maybe_g_zeros_type) : "None",
    ", with_channel_scale_type=", maybe_ch_scales_type
      ? toString(*maybe_ch_scales_type) : "None",
    ", with_token_scale_type=", maybe_tok_scales_type
      ? toString(*maybe_tok_scales_type) : "None",
    "; implemented types are: \\n",
    {%- for impl_config in impl_configs %}
    {% set t = impl_config.types -%}
    "\\t{{gen_type_option_name(t)}}\\n",
    {%- endfor %}
    "");
}

std::vector<std::string> supported_schedules_dispatch(
    SupportedSchedulesArgs args) {
  auto out_type = args.maybe_out_type.value_or(args.a_type);

  {% for impl_config in impl_configs %}
  {% set t = impl_config.types -%}
  {% set schs = impl_config.schedules -%}
  if (args.b_type == {{VLLMScalarTypeTag[t.b]}}
      && args.a_type == {{TorchTypeTag[t.a]}}
      && out_type == {{TorchTypeTag[t.out]}}
      && {%if t.b_group_scale != void -%}
      args.maybe_group_scales_type == {{TorchTypeTag[t.b_group_scale]}}
      {%- else %}!args.maybe_group_scales_type{%endif%}
      && {%if t.b_group_zeropoint != void-%}
      args.maybe_group_zeros_type == {{TorchTypeTag[t.b_group_zeropoint]}}
      {%- else %}!args.maybe_group_zeros_type{%endif%}
  ) {
    return {
      {%- for s in impl_config.schedules %}
      "{{gen_sch_sig(s)}}"{% if not loop.last %},{% endif %}
      {%- endfor %}
    };
  }
  {%- endfor %}

  return {};
};

}; // namespace machete
"""
|
|
|
|
|
|
|
|
IMPL_TEMPLATE = """
|
|
|
|
#include "../machete_mm_launcher.cuh"
|
|
|
|
|
|
|
|
namespace machete {
|
2024-11-18 14:59:29 -05:00
|
|
|
|
|
|
|
{% for sch in unique_schedules(impl_configs) %}
|
|
|
|
{% set sch_sig = gen_sch_sig(sch) -%}
|
|
|
|
struct sch_{{sch_sig}} {
|
2024-08-20 09:09:33 -04:00
|
|
|
using TileShapeNM = Shape<{{
|
|
|
|
to_cute_constant(sch.tile_shape_mn)|join(', ')}}>;
|
|
|
|
using ClusterShape = Shape<{{
|
|
|
|
to_cute_constant(sch.cluster_shape_mnk)|join(', ')}}>;
|
|
|
|
// TODO: Reimplement
|
|
|
|
// using KernelSchedule = {{KernelScheduleTag[sch.kernel_schedule]}};
|
|
|
|
using EpilogueSchedule = {{EpilogueScheduleTag[sch.epilogue_schedule]}};
|
|
|
|
using TileScheduler = {{TileSchedulerTag[sch.tile_scheduler]}};
|
|
|
|
using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto;
|
|
|
|
};
|
2024-11-18 14:59:29 -05:00
|
|
|
{% endfor %}
|
|
|
|
|
|
|
|
{% for impl_config in impl_configs %}
|
|
|
|
{% set t = impl_config.types -%}
|
|
|
|
{% set schs = impl_config.schedules -%}
|
|
|
|
{% set type_sig = gen_type_sig(t) -%}
|
|
|
|
|
|
|
|
template<typename Sch>
|
|
|
|
using Kernel_{{type_sig}} = MacheteKernelTemplate<
|
|
|
|
{{DataTypeTag[t.a]}}, // ElementA
|
|
|
|
{{DataTypeTag[t.b]}}, // ElementB
|
|
|
|
{{DataTypeTag[t.out]}}, // ElementD
|
|
|
|
{{DataTypeTag[t.accumulator]}}, // Accumulator
|
|
|
|
{{DataTypeTag[t.b_group_scale]}}, // GroupScaleT
|
|
|
|
{{DataTypeTag[t.b_group_zeropoint]}}, // GroupZeroT
|
|
|
|
{{DataTypeTag[t.b_channel_scale]}}, // ChannelScaleT
|
|
|
|
{{DataTypeTag[t.a_token_scale]}}, // TokenScaleT
|
2024-12-30 04:22:13 -05:00
|
|
|
cutlass::gemm::KernelTmaWarpSpecializedCooperative,
|
2024-11-18 14:59:29 -05:00
|
|
|
Sch>;
|
|
|
|
|
|
|
|
{% for sch in schs %}
|
|
|
|
{% set sch_sig = gen_sch_sig(sch) -%}
|
2024-08-20 09:09:33 -04:00
|
|
|
torch::Tensor
|
2024-11-18 14:59:29 -05:00
|
|
|
impl_{{type_sig}}_sch_{{sch_sig}}(MMArgs args) {
|
|
|
|
return run_impl<Kernel_{{type_sig}}<sch_{{sch_sig}}>>(args);
|
2024-08-20 09:09:33 -04:00
|
|
|
}
|
2024-11-18 14:59:29 -05:00
|
|
|
{%- endfor %}
|
|
|
|
{%- endfor %}
|
2024-08-20 09:09:33 -04:00
|
|
|
|
|
|
|
}; // namespace machete
|
|
|
|
"""
|
|
|
|
|
|
|
|
PREPACK_TEMPLATE = """
|
|
|
|
#include "../machete_prepack_launcher.cuh"
|
|
|
|
|
|
|
|
namespace machete {
|
2024-11-18 14:59:29 -05:00
|
|
|
|
|
|
|
torch::Tensor prepack_B_dispatch(PrepackBArgs args) {
|
|
|
|
auto convert_type = args.maybe_group_scales_type.value_or(args.a_type);
|
|
|
|
{%- for t in types %}
|
|
|
|
{% set b_type = unsigned_type_with_bitwidth(t.b_num_bits) %}
|
|
|
|
if (args.a_type == {{TorchTypeTag[t.a]}}
|
|
|
|
&& args.b_type.size_bits() == {{t.b_num_bits}}
|
|
|
|
&& convert_type == {{TorchTypeTag[t.convert]}}) {
|
|
|
|
return prepack_impl<
|
|
|
|
PrepackedLayoutBTemplate<
|
|
|
|
{{DataTypeTag[t.a]}}, // ElementA
|
|
|
|
{{DataTypeTag[b_type]}}, // ElementB
|
|
|
|
{{DataTypeTag[t.convert]}}, // ElementConvert
|
|
|
|
{{DataTypeTag[t.accumulator]}}, // Accumulator
|
|
|
|
cutlass::layout::ColumnMajor,
|
2024-12-30 04:22:13 -05:00
|
|
|
cutlass::gemm::KernelTmaWarpSpecializedCooperative>
|
2024-11-18 14:59:29 -05:00
|
|
|
>(args.B);
|
|
|
|
}
|
|
|
|
{%- endfor %}
|
|
|
|
|
|
|
|
TORCH_CHECK_NOT_IMPLEMENTED(false,
|
|
|
|
"prepack_B_dispatch(..) is not implemented for "
|
|
|
|
"atype = ", args.a_type,
|
|
|
|
", b_type = ", args.b_type.str(),
|
|
|
|
", with_group_scales_type= ", args.maybe_group_scales_type ?
|
|
|
|
toString(*args.maybe_group_scales_type) : "None");
|
2024-08-20 09:09:33 -04:00
|
|
|
}
|
2024-11-18 14:59:29 -05:00
|
|
|
|
2024-08-20 09:09:33 -04:00
|
|
|
}; // namespace machete
|
|
|
|
"""
|
|
|
|
|
2024-12-30 04:22:13 -05:00
|
|
|
# Shorthand aliases for the single kernel/epilogue schedule currently used by
# every machete kernel (TMA warp-specialized cooperative).
TmaMI = MixedInputKernelScheduleType.TmaWarpSpecializedCooperative
TmaCoop = EpilogueScheduleType.TmaWarpSpecializedCooperative
|
|
|
|
|
|
|
|
|
2024-09-23 13:46:26 -04:00
|
|
|
@dataclass(frozen=True)
class ScheduleConfig:
    """One CUTLASS kernel schedule: tile/cluster shapes plus schedule
    policies. Frozen so instances are hashable/comparable for dedup."""
    tile_shape_mn: Tuple[int, int]  # CTA tile shape (M, N)
    cluster_shape_mnk: Tuple[int, int, int]  # thread-block cluster shape
    kernel_schedule: MixedInputKernelScheduleType
    epilogue_schedule: EpilogueScheduleType
    tile_scheduler: TileSchedulerType
|
|
|
|
|
|
|
|
|
2024-11-18 14:59:29 -05:00
|
|
|
@dataclass(frozen=True)
class TypeConfig:
    """Element types for one machete GEMM instantiation.

    `DataType.void` in an optional slot (scales/zero-points) marks that
    input as absent for the instantiation.
    """
    a: DataType  # activation element type
    b: Union[DataType, VLLMDataType]  # quantized weight element type
    b_group_scale: DataType  # group-wise weight scale type (void if unused)
    b_group_zeropoint: DataType  # group-wise weight zero-point (void if unused)
    b_channel_scale: DataType  # per-channel weight scale type (void if unused)
    a_token_scale: DataType  # per-token activation scale type (void if unused)
    out: DataType  # output element type
    accumulator: DataType  # accumulation type
|
|
|
|
|
|
|
|
|
2024-11-18 14:59:29 -05:00
|
|
|
@dataclass(frozen=True)
class PrepackTypeConfig:
    """Types needed to instantiate a B-prepack kernel — a reduced view of
    TypeConfig, since prepacking depends only on these fields."""
    a: DataType  # activation element type
    b_num_bits: int  # bit-width of the quantized weight elements
    convert: DataType  # type B is converted to before the MMA
    accumulator: DataType  # accumulation type
|
2024-08-20 09:09:33 -04:00
|
|
|
|
|
|
|
|
|
|
|
@dataclass
class ImplConfig:
    """One group of kernels to generate: a type combination, the schedules
    to instantiate for it, and the runtime schedule-selection heuristic."""
    types: TypeConfig  # element types for this kernel group
    schedules: List[ScheduleConfig]  # schedules to instantiate
    # (condition, schedule) pairs; condition is a C++ expression over M/N/K
    # used in the generated dispatch, None acts as the trailing else branch.
    heuristic: List[Tuple[Optional[str], ScheduleConfig]]
|
|
|
|
|
|
|
|
|
2024-11-18 14:59:29 -05:00
|
|
|
def generate_sch_sig(schedule_config: ScheduleConfig) -> str:
    """Build the full (unabbreviated) schedule signature for kernel names.

    Encodes tile shape, cluster shape, kernel schedule, epilogue schedule
    and tile scheduler, joined with underscores.
    """
    tile_m, tile_n = schedule_config.tile_shape_mn
    cluster = "x".join(str(dim) for dim in schedule_config.cluster_shape_mnk)

    def _last_component(tag: str) -> str:
        # Tags are fully-qualified C++ names; keep only the final identifier.
        return tag.split("::")[-1]

    kernel_sched = _last_component(
        VLLMKernelScheduleTag[schedule_config.kernel_schedule])
    epilogue_sched = _last_component(
        EpilogueScheduleTag[schedule_config.epilogue_schedule])
    tile_sched = _last_component(
        TileSchedulerTag[schedule_config.tile_scheduler])

    return (f"{tile_m}x{tile_n}_{cluster}"
            f"_{kernel_sched}_{epilogue_sched}_{tile_sched}")
|
|
|
|
|
|
|
|
|
2024-11-18 14:59:29 -05:00
|
|
|
# mostly unique shorter sch_sig
def generate_terse_sch_sig(schedule_config: ScheduleConfig) -> str:
    """Return the schedule signature with long CUTLASS schedule names
    abbreviated (shorter, and mostly — not guaranteed — unique)."""
    # Replacement order matters: applied sequentially to the full signature.
    abbreviations = (
        ("KernelTmaWarpSpecializedCooperative", "TmaMI_"),
        ("TmaWarpSpecializedCooperative_", "TmaCoop_"),
        ("StreamKScheduler", "streamK"),
    )

    sig = generate_sch_sig(schedule_config)
    for long_name, short_name in abbreviations:
        sig = sig.replace(long_name, short_name)
    return sig
|
2024-08-20 09:09:33 -04:00
|
|
|
|
|
|
|
|
|
|
|
# unique type_name
def generate_type_signature(kernel_types: TypeConfig):
    """Concatenate the short name of every TypeConfig field into a unique
    type-signature string (used to name the generated kernels)."""
    short_names = [
        VLLMDataTypeNames[getattr(kernel_types, field.name)]
        for field in fields(TypeConfig)
    ]
    return "".join(short_names)
|
2024-08-20 09:09:33 -04:00
|
|
|
|
|
|
|
|
2024-11-18 14:59:29 -05:00
|
|
|
def generate_type_option_name(kernel_types: TypeConfig):
    """Human-readable "key=value" listing of a type config, used in the
    generated NOT_IMPLEMENTED error messages."""
    entries = []
    for field in fields(TypeConfig):
        option = field.name.replace('b_', 'with_') + '_type'
        short_name = VLLMDataTypeNames[getattr(kernel_types, field.name)]
        entries.append(f"{option}={short_name}")
    return ", ".join(entries)
|
2024-08-20 09:09:33 -04:00
|
|
|
|
|
|
|
|
|
|
|
def is_power_of_two(n):
    """Return True iff *n* is a (positive) power of two."""
    if n == 0:
        return False
    # A power of two has exactly one bit set, so n & (n - 1) clears it.
    return (n & (n - 1)) == 0
|
|
|
|
|
|
|
|
|
|
|
|
def to_cute_constant(value: List[int]):
    """Render integer(s) as cute compile-time constants.

    Powers of two use the predefined shorthand (`_128`), everything else the
    generic form (`Int<48>`). Accepts a single int or an iterable of ints and
    returns a string or list of strings accordingly.
    """

    def _render(n: int):
        return f"_{n}" if is_power_of_two(n) else f"Int<{n}>"

    if isinstance(value, Iterable):
        return [_render(n) for n in value]
    return _render(value)
|
|
|
|
|
|
|
|
|
2024-11-18 14:59:29 -05:00
|
|
|
def unique_schedules(impl_configs: List["ImplConfig"]):
    """Return all schedules used across *impl_configs*, deduplicated while
    preserving first-seen order.

    NOTE: must NOT be implemented with `list(set(...))` — set iteration
    order is not deterministic across interpreter runs, which would make
    the rendered kernel files non-deterministic and cause ccache misses
    (same reasoning as get_unique_schedules() inside generate()).
    """
    schedules = []
    for impl_config in impl_configs:
        for sch in impl_config.schedules:
            if sch not in schedules:
                schedules.append(sch)
    return schedules
|
|
|
|
|
|
|
|
|
|
|
|
def unsigned_type_with_bitwidth(num_bits):
    """Map a bit-width (4/8/16/32/64) to the matching unsigned DataType.

    Raises KeyError for unsupported widths.
    """
    bitwidth_to_unsigned = {
        4: DataType.u4,
        8: DataType.u8,
        16: DataType.u16,
        32: DataType.u32,
        64: DataType.u64,
    }
    return bitwidth_to_unsigned[num_bits]
|
|
|
|
|
|
|
|
|
2024-08-20 09:09:33 -04:00
|
|
|
# Helper functions and tag tables exposed to every jinja template
# (installed into each template's globals by create_template()).
template_globals = {
    "void": DataType.void,
    "DataTypeTag": VLLMDataTypeTag,
    "VLLMScalarTypeTag": VLLMDataTypeVLLMScalarTypeTag,
    "TorchTypeTag": VLLMDataTypeTorchDataTypeTag,
    "KernelScheduleTag": VLLMKernelScheduleTag,
    "EpilogueScheduleTag": EpilogueScheduleTag,
    "TileSchedulerTag": TileSchedulerTag,
    "to_cute_constant": to_cute_constant,
    # Templates use the terse signature for all generated symbol names.
    "gen_sch_sig": generate_terse_sch_sig,
    "gen_type_sig": generate_type_signature,
    "unique_schedules": unique_schedules,
    "unsigned_type_with_bitwidth": unsigned_type_with_bitwidth,
    "gen_type_option_name": generate_type_option_name
}
|
|
|
|
|
|
|
|
|
|
|
|
def create_template(template_str):
    """Compile *template_str* into a jinja2 Template with our helper
    functions/tag tables (template_globals) available inside it."""
    tmpl = jinja2.Template(template_str)
    tmpl.globals.update(template_globals)
    return tmpl
|
|
|
|
|
|
|
|
|
|
|
|
# Pre-compiled templates for the three kinds of generated source files.
mm_dispatch_template = create_template(DISPATCH_TEMPLATE)
mm_impl_template = create_template(IMPL_TEMPLATE)
prepack_dispatch_template = create_template(PREPACK_TEMPLATE)
|
|
|
|
|
|
|
|
|
2024-11-18 14:59:29 -05:00
|
|
|
def create_sources(impl_configs: List[ImplConfig], num_impl_files=8):
    """Render all generated sources for *impl_configs*.

    Returns a list of (logical_file_name, source_code) pairs: one dispatch
    file, one prepack file, and the kernel implementations split across
    roughly `num_impl_files` files so they can compile in parallel.
    """
    sources = []

    sources.append((
        "machete_mm_dispatch",
        mm_dispatch_template.render(impl_configs=impl_configs),
    ))

    # Derive the prepack type for each impl config; B is converted to the
    # group-scale type when present, otherwise to the activation type.
    prepack_types = []
    for impl_config in impl_configs:
        convert_type = impl_config.types.a \
            if impl_config.types.b_group_scale == DataType.void \
            else impl_config.types.b_group_scale
        prepack_types.append(
            PrepackTypeConfig(
                a=impl_config.types.a,
                b_num_bits=VLLMDataTypeSize[impl_config.types.b],
                convert=convert_type,
                accumulator=impl_config.types.accumulator,
            ))

    def prepacked_type_key(prepack_type: PrepackTypeConfig):
        # For now we can just use the first accumulator type seen since
        # the tensor core shapes/layouts don't vary based on accumulator
        # type so we can generate less code this way
        return (prepack_type.a, prepack_type.b_num_bits, prepack_type.convert)

    # Deduplicate prepack types by key, preserving first-seen order so the
    # generated file is deterministic.
    unique_prepack_types = []
    prepack_types_seen = set()
    for prepack_type in prepack_types:
        key = prepacked_type_key(prepack_type)
        if key not in prepack_types_seen:
            unique_prepack_types.append(prepack_type)
            prepack_types_seen.add(key)

    sources.append((
        "machete_prepack",
        prepack_dispatch_template.render(types=unique_prepack_types, ),
    ))

    # Split up impls across files
    num_impls = reduce(lambda x, y: x + len(y.schedules), impl_configs, 0)
    num_impls_per_file = math.ceil(num_impls / num_impl_files)

    files_impls: List[List[ImplConfig]] = [[]]

    curr_num_impls_assigned = 0
    curr_impl_in_file = 0
    # Deep copy: schedules lists are sliced in place below; reversed so we
    # can pop from the end while consuming configs in original order.
    curr_impl_configs = deepcopy(list(reversed(impl_configs)))

    while curr_num_impls_assigned < num_impls:
        room_left_in_file = num_impls_per_file - curr_impl_in_file
        if room_left_in_file == 0:
            files_impls.append([])
            room_left_in_file = num_impls_per_file
            curr_impl_in_file = 0

        curr_ic = curr_impl_configs[-1]
        if len(curr_ic.schedules) >= room_left_in_file:
            # Break apart the current impl config: the first
            # `room_left_in_file` schedules go into this file, the
            # remainder stays queued for the next file.
            tmp_ic = deepcopy(curr_ic)
            tmp_ic.schedules = curr_ic.schedules[:room_left_in_file]
            curr_ic.schedules = curr_ic.schedules[room_left_in_file:]
            files_impls[-1].append(tmp_ic)
        else:
            files_impls[-1].append(curr_ic)
            curr_impl_configs.pop()
        curr_num_impls_assigned += len(files_impls[-1][-1].schedules)
        curr_impl_in_file += len(files_impls[-1][-1].schedules)

    for part, file_impls in enumerate(files_impls):
        sources.append((
            f"machete_mm_impl_part{part+1}",
            mm_impl_template.render(impl_configs=file_impls),
        ))

    return sources
|
|
|
|
|
|
|
|
|
|
|
|
def generate():
    """Render all machete kernel sources into <script_dir>/generated.

    See csrc/quantization/machete/Readme.md, the Codegeneration section, for
    more info about how this works.
    """
    SCRIPT_DIR = os.path.dirname(__file__)

    # Schedule parameters shared by every generated schedule; only the
    # tile/cluster shapes vary per heuristic entry.
    sch_common_params = dict(
        kernel_schedule=TmaMI,
        epilogue_schedule=TmaCoop,
        tile_scheduler=TileSchedulerType.StreamK,
    )

    # Stored as "condition": ((tile_shape_mn), (cluster_shape_mnk))
    default_tile_heuristic_config = {
        #### M = 257+
        "M > 256 && K <= 16384 && N <= 4096": ((128, 128), (2, 1, 1)),
        "M > 256": ((128, 256), (2, 1, 1)),
        #### M = 129-256
        "M > 128 && K <= 4096 && N <= 4096": ((128, 64), (2, 1, 1)),
        "M > 128 && K <= 8192 && N <= 8192": ((128, 128), (2, 1, 1)),
        "M > 128": ((128, 256), (2, 1, 1)),
        #### M = 65-128
        "M > 64 && K <= 4069 && N <= 4069": ((128, 32), (2, 1, 1)),
        "M > 64 && K <= 4069 && N <= 8192": ((128, 64), (2, 1, 1)),
        "M > 64 && K >= 8192 && N >= 12288": ((256, 128), (2, 1, 1)),
        "M > 64": ((128, 128), (2, 1, 1)),
        #### M = 33-64
        "M > 32 && K <= 6144 && N <= 6144": ((128, 16), (1, 1, 1)),
        "M > 32 && K >= 16384 && N >= 12288": ((256, 64), (2, 1, 1)),
        "M > 32": ((128, 64), (2, 1, 1)),
        #### M = 17-32
        "M > 16 && K <= 12288 && N <= 8192": ((128, 32), (2, 1, 1)),
        "M > 16": ((256, 32), (2, 1, 1)),
        #### M = 1-16
        "N >= 26624": ((256, 16), (1, 1, 1)),
        None: ((128, 16), (1, 1, 1)),
    }

    # For now we use the same heuristic for all types
    # Heuristic is currently tuned for H100s
    default_heuristic = [
        (cond, ScheduleConfig(*tile_config,
                              **sch_common_params))  # type: ignore
        for cond, tile_config in default_tile_heuristic_config.items()
    ]

    def get_unique_schedules(heuristic: Dict[str, ScheduleConfig]):
        # Do not use schedules = list(set(...)) because we need to make sure
        # the output list is deterministic; otherwise the generated kernel file
        # will be non-deterministic and causes ccache miss.
        schedules = []
        for _, schedule_config in heuristic:
            if schedule_config not in schedules:
                schedules.append(schedule_config)
        return schedules

    impl_configs = []

    GPTQ_kernel_type_configs = list(
        TypeConfig(
            a=a,
            b=b,
            b_group_scale=a,
            b_group_zeropoint=DataType.void,
            b_channel_scale=DataType.void,
            a_token_scale=DataType.void,
            out=a,
            accumulator=DataType.f32,
        ) for b in (VLLMDataType.u4b8, VLLMDataType.u8b128)
        for a in (DataType.f16, DataType.bf16))

    impl_configs += [
        ImplConfig(x[0], x[1], x[2])
        for x in zip(GPTQ_kernel_type_configs,
                     itertools.repeat(get_unique_schedules(default_heuristic)),
                     itertools.repeat(default_heuristic))
    ]

    AWQ_kernel_type_configs = list(
        TypeConfig(
            a=a,
            b=b,
            b_group_scale=a,
            b_group_zeropoint=a,
            b_channel_scale=DataType.void,
            a_token_scale=DataType.void,
            out=a,
            accumulator=DataType.f32,
        ) for b in (DataType.u4, DataType.u8)
        for a in (DataType.f16, DataType.bf16))

    impl_configs += [
        ImplConfig(x[0], x[1], x[2])
        for x in zip(AWQ_kernel_type_configs,
                     itertools.repeat(get_unique_schedules(default_heuristic)),
                     itertools.repeat(default_heuristic))
    ]

    # Stored as "condition": ((tile_shape_mn), (cluster_shape_mnk))
    # TODO (LucasWilkinson): Further tuning required
    qqq_tile_heuristic_config = {
        #### M = 257+
        # ((128, 256), (2, 1, 1)) Broken for QQQ types
        # TODO (LucasWilkinson): Investigate further
        # "M > 256 && K <= 16384 && N <= 4096": ((128, 128), (2, 1, 1)),
        # "M > 256": ((128, 256), (2, 1, 1)),
        "M > 256": ((128, 128), (2, 1, 1)),
        #### M = 129-256
        "M > 128 && K <= 4096 && N <= 4096": ((128, 64), (2, 1, 1)),
        "M > 128 && K <= 8192 && N <= 8192": ((128, 128), (2, 1, 1)),
        # ((128, 256), (2, 1, 1)) Broken for QQQ types
        # TODO (LucasWilkinson): Investigate further
        # "M > 128": ((128, 256), (2, 1, 1)),
        "M > 128": ((128, 128), (2, 1, 1)),
        #### M = 65-128
        "M > 64 && K <= 4069 && N <= 4069": ((128, 32), (2, 1, 1)),
        "M > 64 && K <= 4069 && N <= 8192": ((128, 64), (2, 1, 1)),
        "M > 64 && K >= 8192 && N >= 12288": ((256, 128), (2, 1, 1)),
        "M > 64": ((128, 128), (2, 1, 1)),
        #### M = 33-64
        "M > 32 && K <= 6144 && N <= 6144": ((128, 16), (1, 1, 1)),
        # Broken for QQQ types
        # TODO (LucasWilkinson): Investigate further
        #"M > 32 && K >= 16384 && N >= 12288": ((256, 64), (2, 1, 1)),
        "M > 32": ((128, 64), (2, 1, 1)),
        #### M = 17-32
        "M > 16 && K <= 12288 && N <= 8192": ((128, 32), (2, 1, 1)),
        "M > 16": ((256, 32), (2, 1, 1)),
        #### M = 1-16
        "N >= 26624": ((256, 16), (1, 1, 1)),
        None: ((128, 16), (1, 1, 1)),
    }

    # For now we use the same heuristic for all types
    # Heuristic is currently tuned for H100s
    qqq_heuristic = [
        (cond, ScheduleConfig(*tile_config,
                              **sch_common_params))  # type: ignore
        for cond, tile_config in qqq_tile_heuristic_config.items()
    ]

    QQQ_kernel_types = [
        *(TypeConfig(
            a=DataType.s8,
            b=VLLMDataType.u4b8,
            b_group_scale=b_group_scale,
            b_group_zeropoint=DataType.void,
            b_channel_scale=DataType.f32,
            a_token_scale=DataType.f32,
            out=DataType.f16,
            accumulator=DataType.s32,
        ) for b_group_scale in (DataType.f16, DataType.void)),
        *(TypeConfig(
            a=DataType.e4m3,
            b=VLLMDataType.u4b8,
            b_group_scale=b_group_scale,
            b_group_zeropoint=DataType.void,
            b_channel_scale=DataType.f32,
            a_token_scale=DataType.f32,
            out=DataType.f16,
            accumulator=DataType.f32,
        ) for b_group_scale in (DataType.f16, DataType.void)),
    ]

    impl_configs += [
        ImplConfig(x[0], x[1], x[2])
        for x in zip(QQQ_kernel_types,
                     itertools.repeat(get_unique_schedules(qqq_heuristic)),
                     itertools.repeat(qqq_heuristic))
    ]

    output_dir = os.path.join(SCRIPT_DIR, "generated")

    # Delete the "generated" directory if it exists
    if os.path.exists(output_dir):
        shutil.rmtree(output_dir)

    # Create the "generated" directory
    os.makedirs(output_dir)

    # Render each group of configurations into separate files
    for filename, code in create_sources(impl_configs):
        # Use the logical name returned by create_sources() so each rendered
        # template gets its own .cu file. (Previously a hard-coded name was
        # used here, so every iteration overwrote the same file and only the
        # last rendered source survived.)
        filepath = os.path.join(output_dir, f"{filename}.cu")
        with open(filepath, "w") as output_file:
            output_file.write(code)
        print(f"Rendered template to {filepath}")
|
2024-08-20 09:09:33 -04:00
|
|
|
|
|
|
|
|
|
|
|
# Script entry point: (re)generate all machete kernel sources.
if __name__ == "__main__":
    generate()
|