[Build][Kernel] Update CUTLASS to v3.6.0 (#11607)

Signed-off-by: Tyler Michael Smith <tyler@neuralmagic.com>
This commit is contained in:
Tyler Michael Smith 2024-12-30 04:22:13 -05:00 committed by GitHub
parent 628ec6c17b
commit 970d6d0776
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 25 additions and 31 deletions

View File

@ -223,13 +223,13 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
FetchContent_Declare(
cutlass
GIT_REPOSITORY https://github.com/nvidia/cutlass.git
GIT_TAG 8aa95dbb888be6d81c6fbf7169718c5244b53227
GIT_TAG v3.6.0
GIT_PROGRESS TRUE
# Speed up CUTLASS download by retrieving only the specified GIT_TAG instead of the history.
# Important: If GIT_SHALLOW is enabled then GIT_TAG works only with branch names and tags.
# So if the GIT_TAG above is updated to a commit hash, GIT_SHALLOW must be set to FALSE
GIT_SHALLOW FALSE
GIT_SHALLOW TRUE
)
endif()
FetchContent_MakeAvailable(cutlass)

View File

@ -14,9 +14,9 @@ class VLLMDataType(enum.Enum):
class MixedInputKernelScheduleType(enum.Enum):
TmaWarpSpecializedMixedInput = enum_auto()
TmaWarpSpecializedPingpongMixedInput = enum_auto()
TmaWarpSpecializedCooperativeMixedInput = enum_auto()
TmaWarpSpecialized = enum_auto()
TmaWarpSpecializedPingpong = enum_auto()
TmaWarpSpecializedCooperative = enum_auto()
VLLMDataTypeNames: Dict[Union[VLLMDataType, DataType], str] = {
@ -68,11 +68,11 @@ VLLMKernelScheduleTag: Dict[Union[
MixedInputKernelScheduleType, KernelScheduleType], str] = {
**KernelScheduleTag, # type: ignore
**{
MixedInputKernelScheduleType.TmaWarpSpecializedMixedInput:
"cutlass::gemm::KernelTmaWarpSpecializedMixedInput",
MixedInputKernelScheduleType.TmaWarpSpecializedPingpongMixedInput:
"cutlass::gemm::KernelTmaWarpSpecializedPingpongMixedInput",
MixedInputKernelScheduleType.TmaWarpSpecializedCooperativeMixedInput:
"cutlass::gemm::KernelTmaWarpSpecializedCooperativeMixedInput",
MixedInputKernelScheduleType.TmaWarpSpecialized:
"cutlass::gemm::KernelTmaWarpSpecialized",
MixedInputKernelScheduleType.TmaWarpSpecializedPingpong:
"cutlass::gemm::KernelTmaWarpSpecializedPingpong",
MixedInputKernelScheduleType.TmaWarpSpecializedCooperative:
"cutlass::gemm::KernelTmaWarpSpecializedCooperative",
}
}

View File

@ -189,7 +189,7 @@ using Kernel_{{type_sig}} = MacheteKernelTemplate<
{{DataTypeTag[t.b_group_zeropoint]}}, // GroupZeroT
{{DataTypeTag[t.b_channel_scale]}}, // ChannelScaleT
{{DataTypeTag[t.a_token_scale]}}, // TokenScaleT
cutlass::gemm::KernelTmaWarpSpecializedCooperativeMixedInput,
cutlass::gemm::KernelTmaWarpSpecializedCooperative,
Sch>;
{% for sch in schs %}
@ -223,7 +223,7 @@ torch::Tensor prepack_B_dispatch(PrepackBArgs args) {
{{DataTypeTag[t.convert]}}, // ElementConvert
{{DataTypeTag[t.accumulator]}}, // Accumulator
cutlass::layout::ColumnMajor,
cutlass::gemm::KernelTmaWarpSpecializedCooperativeMixedInput>
cutlass::gemm::KernelTmaWarpSpecializedCooperative>
>(args.B);
}
{%- endfor %}
@ -239,7 +239,7 @@ torch::Tensor prepack_B_dispatch(PrepackBArgs args) {
}; // namespace machete
"""
TmaMI = MixedInputKernelScheduleType.TmaWarpSpecializedCooperativeMixedInput
TmaMI = MixedInputKernelScheduleType.TmaWarpSpecializedCooperative
TmaCoop = EpilogueScheduleType.TmaWarpSpecializedCooperative
@ -300,7 +300,7 @@ def generate_sch_sig(schedule_config: ScheduleConfig) -> str:
# mostly unique shorter sch_sig
def generate_terse_sch_sig(schedule_config: ScheduleConfig) -> str:
kernel_terse_names_replace = {
"KernelTmaWarpSpecializedCooperativeMixedInput_": "TmaMI_",
"KernelTmaWarpSpecializedCooperative": "TmaMI_",
"TmaWarpSpecializedCooperative_": "TmaCoop_",
"StreamKScheduler": "streamK",
}

View File

@ -18,16 +18,14 @@ struct VLLMCollectiveBuilder<
ElementAccumulator, TileShape_MNK, ClusterShape_MNK, StageCountType,
KernelScheduleType,
cute::enable_if_t<(
cute::is_same_v<KernelScheduleType, KernelTmaWarpSpecialized> ||
cute::is_same_v<KernelScheduleType, KernelTmaWarpSpecializedPingpong> ||
cute::is_same_v<KernelScheduleType,
KernelTmaWarpSpecializedMixedInput> ||
cute::is_same_v<KernelScheduleType,
KernelTmaWarpSpecializedPingpongMixedInput> ||
cute::is_same_v<KernelScheduleType,
KernelTmaWarpSpecializedCooperativeMixedInput>)>> {
KernelTmaWarpSpecializedCooperative>)>> {
using CollectiveOp = machete::MacheteCollectiveMma<
ElementPairA_, GmemLayoutA_, AlignmentA, ElementPairB_, GmemLayoutB_,
AlignmentB, ElementAccumulator, TileShape_MNK, ClusterShape_MNK,
StageCountType, KernelScheduleType>;
};
}; // namespace cutlass::gemm::collective
}; // namespace cutlass::gemm::collective

View File

@ -66,13 +66,11 @@ struct MacheteCollectiveMma {
using Schedule = KernelScheduleType;
static_assert(
cute::is_same_v<Schedule, KernelTmaWarpSpecialized> ||
cute::is_same_v<Schedule, KernelTmaWarpSpecializedMixedInput> ||
cute::is_same_v<Schedule, KernelTmaWarpSpecialized> ||
cute::is_same_v<Schedule, KernelTmaWarpSpecializedPingpong> ||
cute::is_same_v<Schedule, KernelTmaWarpSpecializedPingpong> ||
cute::is_same_v<Schedule,
KernelTmaWarpSpecializedPingpongMixedInput> ||
cute::is_same_v<Schedule, KernelTmaWarpSpecializedCooperative> ||
cute::is_same_v<Schedule,
KernelTmaWarpSpecializedCooperativeMixedInput>,
cute::is_same_v<Schedule, KernelTmaWarpSpecializedCooperative>,
"KernelSchedule must be one of the warp specialized policies");
public:
@ -113,8 +111,7 @@ struct MacheteCollectiveMma {
// For coop schedules we have two warp groups cooperatively issuing wgmma
// instructions so we use 2 atoms along the M dim (one for each warpgroup)
using AtomLayoutMNK = cute::conditional_t<
cute::is_same_v<KernelScheduleType,
KernelTmaWarpSpecializedCooperativeMixedInput>,
cute::is_same_v<KernelScheduleType, KernelTmaWarpSpecializedCooperative>,
Layout<Shape<_2, _1, _1>>, Layout<Shape<_1, _1, _1>>>;
using TiledMma = decltype(cute::make_tiled_mma(

View File

@ -98,8 +98,7 @@ struct PrepackedLayoutBTemplate {
// For coop schedules we have two warp groups cooperatively issuing wgmma
// instructions so we use 2 atoms along the M dim (one for each warpgroup)
using AtomLayoutMNK = cute::conditional_t<
cute::is_same_v<KernelSchedule,
KernelTmaWarpSpecializedCooperativeMixedInput>,
cute::is_same_v<KernelSchedule, KernelTmaWarpSpecializedCooperative>,
Layout<Shape<_2, _1, _1>>, Layout<Shape<_1, _1, _1>>>;
using TiledMma = decltype(cute::make_tiled_mma(
@ -247,4 +246,4 @@ struct PrepackedLayoutBTemplate {
}
};
}; // namespace machete
}; // namespace machete