[Build][Kernel] Update CUTLASS to v3.6.0 (#11607)

Signed-off-by: Tyler Michael Smith <tyler@neuralmagic.com>
Tyler Michael Smith 2024-12-30 04:22:13 -05:00 committed by GitHub
parent 628ec6c17b
commit 970d6d0776
6 changed files with 25 additions and 31 deletions

CMakeLists.txt

@@ -223,13 +223,13 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
   FetchContent_Declare(
         cutlass
         GIT_REPOSITORY https://github.com/nvidia/cutlass.git
-        GIT_TAG 8aa95dbb888be6d81c6fbf7169718c5244b53227
+        GIT_TAG v3.6.0
         GIT_PROGRESS TRUE
         # Speed up CUTLASS download by retrieving only the specified GIT_TAG instead of the history.
         # Important: If GIT_SHALLOW is enabled then GIT_TAG works only with branch names and tags.
         # So if the GIT_TAG above is updated to a commit hash, GIT_SHALLOW must be set to FALSE
-        GIT_SHALLOW FALSE
+        GIT_SHALLOW TRUE
   )
 endif()
 FetchContent_MakeAvailable(cutlass)

csrc/cutlass_extensions/vllm_cutlass_library_extension.py

@@ -14,9 +14,9 @@ class VLLMDataType(enum.Enum):
 class MixedInputKernelScheduleType(enum.Enum):
-    TmaWarpSpecializedMixedInput = enum_auto()
-    TmaWarpSpecializedPingpongMixedInput = enum_auto()
-    TmaWarpSpecializedCooperativeMixedInput = enum_auto()
+    TmaWarpSpecialized = enum_auto()
+    TmaWarpSpecializedPingpong = enum_auto()
+    TmaWarpSpecializedCooperative = enum_auto()
 VLLMDataTypeNames: Dict[Union[VLLMDataType, DataType], str] = {
@@ -68,11 +68,11 @@ VLLMKernelScheduleTag: Dict[Union[
     MixedInputKernelScheduleType, KernelScheduleType], str] = {
         **KernelScheduleTag,  # type: ignore
         **{
-            MixedInputKernelScheduleType.TmaWarpSpecializedMixedInput:
-            "cutlass::gemm::KernelTmaWarpSpecializedMixedInput",
-            MixedInputKernelScheduleType.TmaWarpSpecializedPingpongMixedInput:
-            "cutlass::gemm::KernelTmaWarpSpecializedPingpongMixedInput",
-            MixedInputKernelScheduleType.TmaWarpSpecializedCooperativeMixedInput:
-            "cutlass::gemm::KernelTmaWarpSpecializedCooperativeMixedInput",
+            MixedInputKernelScheduleType.TmaWarpSpecialized:
+            "cutlass::gemm::KernelTmaWarpSpecialized",
+            MixedInputKernelScheduleType.TmaWarpSpecializedPingpong:
+            "cutlass::gemm::KernelTmaWarpSpecializedPingpong",
+            MixedInputKernelScheduleType.TmaWarpSpecializedCooperative:
+            "cutlass::gemm::KernelTmaWarpSpecializedCooperative",
         }
     }
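CUTLASS v3.6.0 drops the dedicated *MixedInput kernel schedule types, so the mixed-input GEMMs now use the plain warp-specialized schedule names, as reflected in the rename above. A minimal standalone sketch of the resulting enum-to-tag mapping, for illustration only: it uses stdlib enum.auto() in place of cutlass_library's enum_auto() and omits the merged KernelScheduleTag entries from the real file.

import enum
from typing import Dict


class MixedInputKernelScheduleType(enum.Enum):
    # Same member names as the updated extension above, no "MixedInput" suffix.
    TmaWarpSpecialized = enum.auto()
    TmaWarpSpecializedPingpong = enum.auto()
    TmaWarpSpecializedCooperative = enum.auto()


# Each schedule maps to the corresponding CUTLASS 3.6 C++ tag.
VLLMKernelScheduleTag: Dict[MixedInputKernelScheduleType, str] = {
    MixedInputKernelScheduleType.TmaWarpSpecialized:
        "cutlass::gemm::KernelTmaWarpSpecialized",
    MixedInputKernelScheduleType.TmaWarpSpecializedPingpong:
        "cutlass::gemm::KernelTmaWarpSpecializedPingpong",
    MixedInputKernelScheduleType.TmaWarpSpecializedCooperative:
        "cutlass::gemm::KernelTmaWarpSpecializedCooperative",
}

if __name__ == "__main__":
    for sched, tag in VLLMKernelScheduleTag.items():
        print(f"{sched.name} -> {tag}")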

csrc/quantization/machete/generate.py

@@ -189,7 +189,7 @@ using Kernel_{{type_sig}} = MacheteKernelTemplate<
   {{DataTypeTag[t.b_group_zeropoint]}},  // GroupZeroT
   {{DataTypeTag[t.b_channel_scale]}},  // ChannelScaleT
   {{DataTypeTag[t.a_token_scale]}},  // TokenScaleT
-  cutlass::gemm::KernelTmaWarpSpecializedCooperativeMixedInput,
+  cutlass::gemm::KernelTmaWarpSpecializedCooperative,
   Sch>;
 {% for sch in schs %}
@@ -223,7 +223,7 @@ torch::Tensor prepack_B_dispatch(PrepackBArgs args) {
       {{DataTypeTag[t.convert]}},  // ElementConvert
       {{DataTypeTag[t.accumulator]}},  // Accumulator
       cutlass::layout::ColumnMajor,
-      cutlass::gemm::KernelTmaWarpSpecializedCooperativeMixedInput>
+      cutlass::gemm::KernelTmaWarpSpecializedCooperative>
     >(args.B);
   }
 {%- endfor %}
@@ -239,7 +239,7 @@ torch::Tensor prepack_B_dispatch(PrepackBArgs args) {
 };  // namespace machete
 """
-TmaMI = MixedInputKernelScheduleType.TmaWarpSpecializedCooperativeMixedInput
+TmaMI = MixedInputKernelScheduleType.TmaWarpSpecializedCooperative
 TmaCoop = EpilogueScheduleType.TmaWarpSpecializedCooperative
@@ -300,7 +300,7 @@ def generate_sch_sig(schedule_config: ScheduleConfig) -> str:
 # mostly unique shorter sch_sig
 def generate_terse_sch_sig(schedule_config: ScheduleConfig) -> str:
     kernel_terse_names_replace = {
-        "KernelTmaWarpSpecializedCooperativeMixedInput_": "TmaMI_",
+        "KernelTmaWarpSpecializedCooperative": "TmaMI_",
         "TmaWarpSpecializedCooperative_": "TmaCoop_",
         "StreamKScheduler": "streamK",
     }
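The terse-signature table above now keys on the renamed cooperative schedule. A hedged sketch of just the shortening step: the real generate_terse_sch_sig builds the full signature from a ScheduleConfig first, and both the helper name and the input string below are made up for illustration.

def shorten_sch_sig(sch_sig: str) -> str:
    # Same replacement table as the updated generate.py above.
    kernel_terse_names_replace = {
        "KernelTmaWarpSpecializedCooperative": "TmaMI_",
        "TmaWarpSpecializedCooperative_": "TmaCoop_",
        "StreamKScheduler": "streamK",
    }
    # Apply the substitutions in insertion order.
    for long_name, terse_name in kernel_terse_names_replace.items():
        sch_sig = sch_sig.replace(long_name, terse_name)
    return sch_sig


# Made-up signature string, just to show the substitutions firing.
print(shorten_sch_sig(
    "sch_128x128_2x1x1_KernelTmaWarpSpecializedCooperative_"
    "TmaWarpSpecializedCooperative_StreamKScheduler"))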

csrc/quantization/machete/machete_collective_builder.cuh

@@ -18,12 +18,10 @@ struct VLLMCollectiveBuilder<
     ElementAccumulator, TileShape_MNK, ClusterShape_MNK, StageCountType,
     KernelScheduleType,
     cute::enable_if_t<(
-        cute::is_same_v<KernelScheduleType,
-                        KernelTmaWarpSpecializedMixedInput> ||
-        cute::is_same_v<KernelScheduleType,
-                        KernelTmaWarpSpecializedPingpongMixedInput> ||
-        cute::is_same_v<KernelScheduleType,
-                        KernelTmaWarpSpecializedCooperativeMixedInput>)>> {
+        cute::is_same_v<KernelScheduleType, KernelTmaWarpSpecialized> ||
+        cute::is_same_v<KernelScheduleType, KernelTmaWarpSpecializedPingpong> ||
+        cute::is_same_v<KernelScheduleType,
+                        KernelTmaWarpSpecializedCooperative>)>> {
   using CollectiveOp = machete::MacheteCollectiveMma<
       ElementPairA_, GmemLayoutA_, AlignmentA, ElementPairB_, GmemLayoutB_,
       AlignmentB, ElementAccumulator, TileShape_MNK, ClusterShape_MNK,

csrc/quantization/machete/machete_mainloop.cuh

@@ -66,13 +66,11 @@ struct MacheteCollectiveMma {
   using Schedule = KernelScheduleType;
   static_assert(
       cute::is_same_v<Schedule, KernelTmaWarpSpecialized> ||
-          cute::is_same_v<Schedule, KernelTmaWarpSpecializedMixedInput> ||
+          cute::is_same_v<Schedule, KernelTmaWarpSpecialized> ||
+          cute::is_same_v<Schedule, KernelTmaWarpSpecializedPingpong> ||
           cute::is_same_v<Schedule, KernelTmaWarpSpecializedPingpong> ||
-          cute::is_same_v<Schedule,
-                          KernelTmaWarpSpecializedPingpongMixedInput> ||
           cute::is_same_v<Schedule, KernelTmaWarpSpecializedCooperative> ||
-          cute::is_same_v<Schedule,
-                          KernelTmaWarpSpecializedCooperativeMixedInput>,
+          cute::is_same_v<Schedule, KernelTmaWarpSpecializedCooperative>,
       "KernelSchedule must be one of the warp specialized policies");
  public:
@@ -113,8 +111,7 @@ struct MacheteCollectiveMma {
   // For coop schedules we have two warp groups cooperatively issuing wgmma
   // instructions so we use 2 atoms along the M dim (one for each warpgroup)
   using AtomLayoutMNK = cute::conditional_t<
-      cute::is_same_v<KernelScheduleType,
-                      KernelTmaWarpSpecializedCooperativeMixedInput>,
+      cute::is_same_v<KernelScheduleType, KernelTmaWarpSpecializedCooperative>,
       Layout<Shape<_2, _1, _1>>, Layout<Shape<_1, _1, _1>>>;
   using TiledMma = decltype(cute::make_tiled_mma(

csrc/quantization/machete/machete_prepacked_layout.cuh

@@ -98,8 +98,7 @@ struct PrepackedLayoutBTemplate {
   // For coop schedules we have two warp groups cooperatively issuing wgmma
   // instructions so we use 2 atoms along the M dim (one for each warpgroup)
   using AtomLayoutMNK = cute::conditional_t<
-      cute::is_same_v<KernelSchedule,
-                      KernelTmaWarpSpecializedCooperativeMixedInput>,
+      cute::is_same_v<KernelSchedule, KernelTmaWarpSpecializedCooperative>,
       Layout<Shape<_2, _1, _1>>, Layout<Shape<_1, _1, _1>>>;
   using TiledMma = decltype(cute::make_tiled_mma(