[Build][Kernel] Update CUTLASS to v3.6.0 (#11607)
Signed-off-by: Tyler Michael Smith <tyler@neuralmagic.com>
This commit is contained in:
parent
628ec6c17b
commit
970d6d0776
@ -223,13 +223,13 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
FetchContent_Declare(
|
||||
cutlass
|
||||
GIT_REPOSITORY https://github.com/nvidia/cutlass.git
|
||||
GIT_TAG 8aa95dbb888be6d81c6fbf7169718c5244b53227
|
||||
GIT_TAG v3.6.0
|
||||
GIT_PROGRESS TRUE
|
||||
|
||||
# Speed up CUTLASS download by retrieving only the specified GIT_TAG instead of the history.
|
||||
# Important: If GIT_SHALLOW is enabled then GIT_TAG works only with branch names and tags.
|
||||
# So if the GIT_TAG above is updated to a commit hash, GIT_SHALLOW must be set to FALSE
|
||||
GIT_SHALLOW FALSE
|
||||
GIT_SHALLOW TRUE
|
||||
)
|
||||
endif()
|
||||
FetchContent_MakeAvailable(cutlass)
|
||||
|
@ -14,9 +14,9 @@ class VLLMDataType(enum.Enum):
|
||||
|
||||
|
||||
class MixedInputKernelScheduleType(enum.Enum):
|
||||
TmaWarpSpecializedMixedInput = enum_auto()
|
||||
TmaWarpSpecializedPingpongMixedInput = enum_auto()
|
||||
TmaWarpSpecializedCooperativeMixedInput = enum_auto()
|
||||
TmaWarpSpecialized = enum_auto()
|
||||
TmaWarpSpecializedPingpong = enum_auto()
|
||||
TmaWarpSpecializedCooperative = enum_auto()
|
||||
|
||||
|
||||
VLLMDataTypeNames: Dict[Union[VLLMDataType, DataType], str] = {
|
||||
@ -68,11 +68,11 @@ VLLMKernelScheduleTag: Dict[Union[
|
||||
MixedInputKernelScheduleType, KernelScheduleType], str] = {
|
||||
**KernelScheduleTag, # type: ignore
|
||||
**{
|
||||
MixedInputKernelScheduleType.TmaWarpSpecializedMixedInput:
|
||||
"cutlass::gemm::KernelTmaWarpSpecializedMixedInput",
|
||||
MixedInputKernelScheduleType.TmaWarpSpecializedPingpongMixedInput:
|
||||
"cutlass::gemm::KernelTmaWarpSpecializedPingpongMixedInput",
|
||||
MixedInputKernelScheduleType.TmaWarpSpecializedCooperativeMixedInput:
|
||||
"cutlass::gemm::KernelTmaWarpSpecializedCooperativeMixedInput",
|
||||
MixedInputKernelScheduleType.TmaWarpSpecialized:
|
||||
"cutlass::gemm::KernelTmaWarpSpecialized",
|
||||
MixedInputKernelScheduleType.TmaWarpSpecializedPingpong:
|
||||
"cutlass::gemm::KernelTmaWarpSpecializedPingpong",
|
||||
MixedInputKernelScheduleType.TmaWarpSpecializedCooperative:
|
||||
"cutlass::gemm::KernelTmaWarpSpecializedCooperative",
|
||||
}
|
||||
}
|
||||
|
@ -189,7 +189,7 @@ using Kernel_{{type_sig}} = MacheteKernelTemplate<
|
||||
{{DataTypeTag[t.b_group_zeropoint]}}, // GroupZeroT
|
||||
{{DataTypeTag[t.b_channel_scale]}}, // ChannelScaleT
|
||||
{{DataTypeTag[t.a_token_scale]}}, // TokenScaleT
|
||||
cutlass::gemm::KernelTmaWarpSpecializedCooperativeMixedInput,
|
||||
cutlass::gemm::KernelTmaWarpSpecializedCooperative,
|
||||
Sch>;
|
||||
|
||||
{% for sch in schs %}
|
||||
@ -223,7 +223,7 @@ torch::Tensor prepack_B_dispatch(PrepackBArgs args) {
|
||||
{{DataTypeTag[t.convert]}}, // ElementConvert
|
||||
{{DataTypeTag[t.accumulator]}}, // Accumulator
|
||||
cutlass::layout::ColumnMajor,
|
||||
cutlass::gemm::KernelTmaWarpSpecializedCooperativeMixedInput>
|
||||
cutlass::gemm::KernelTmaWarpSpecializedCooperative>
|
||||
>(args.B);
|
||||
}
|
||||
{%- endfor %}
|
||||
@ -239,7 +239,7 @@ torch::Tensor prepack_B_dispatch(PrepackBArgs args) {
|
||||
}; // namespace machete
|
||||
"""
|
||||
|
||||
TmaMI = MixedInputKernelScheduleType.TmaWarpSpecializedCooperativeMixedInput
|
||||
TmaMI = MixedInputKernelScheduleType.TmaWarpSpecializedCooperative
|
||||
TmaCoop = EpilogueScheduleType.TmaWarpSpecializedCooperative
|
||||
|
||||
|
||||
@ -300,7 +300,7 @@ def generate_sch_sig(schedule_config: ScheduleConfig) -> str:
|
||||
# mostly unique shorter sch_sig
|
||||
def generate_terse_sch_sig(schedule_config: ScheduleConfig) -> str:
|
||||
kernel_terse_names_replace = {
|
||||
"KernelTmaWarpSpecializedCooperativeMixedInput_": "TmaMI_",
|
||||
"KernelTmaWarpSpecializedCooperative": "TmaMI_",
|
||||
"TmaWarpSpecializedCooperative_": "TmaCoop_",
|
||||
"StreamKScheduler": "streamK",
|
||||
}
|
||||
|
@ -18,16 +18,14 @@ struct VLLMCollectiveBuilder<
|
||||
ElementAccumulator, TileShape_MNK, ClusterShape_MNK, StageCountType,
|
||||
KernelScheduleType,
|
||||
cute::enable_if_t<(
|
||||
cute::is_same_v<KernelScheduleType, KernelTmaWarpSpecialized> ||
|
||||
cute::is_same_v<KernelScheduleType, KernelTmaWarpSpecializedPingpong> ||
|
||||
cute::is_same_v<KernelScheduleType,
|
||||
KernelTmaWarpSpecializedMixedInput> ||
|
||||
cute::is_same_v<KernelScheduleType,
|
||||
KernelTmaWarpSpecializedPingpongMixedInput> ||
|
||||
cute::is_same_v<KernelScheduleType,
|
||||
KernelTmaWarpSpecializedCooperativeMixedInput>)>> {
|
||||
KernelTmaWarpSpecializedCooperative>)>> {
|
||||
using CollectiveOp = machete::MacheteCollectiveMma<
|
||||
ElementPairA_, GmemLayoutA_, AlignmentA, ElementPairB_, GmemLayoutB_,
|
||||
AlignmentB, ElementAccumulator, TileShape_MNK, ClusterShape_MNK,
|
||||
StageCountType, KernelScheduleType>;
|
||||
};
|
||||
|
||||
}; // namespace cutlass::gemm::collective
|
||||
}; // namespace cutlass::gemm::collective
|
||||
|
@ -66,13 +66,11 @@ struct MacheteCollectiveMma {
|
||||
using Schedule = KernelScheduleType;
|
||||
static_assert(
|
||||
cute::is_same_v<Schedule, KernelTmaWarpSpecialized> ||
|
||||
cute::is_same_v<Schedule, KernelTmaWarpSpecializedMixedInput> ||
|
||||
cute::is_same_v<Schedule, KernelTmaWarpSpecialized> ||
|
||||
cute::is_same_v<Schedule, KernelTmaWarpSpecializedPingpong> ||
|
||||
cute::is_same_v<Schedule, KernelTmaWarpSpecializedPingpong> ||
|
||||
cute::is_same_v<Schedule,
|
||||
KernelTmaWarpSpecializedPingpongMixedInput> ||
|
||||
cute::is_same_v<Schedule, KernelTmaWarpSpecializedCooperative> ||
|
||||
cute::is_same_v<Schedule,
|
||||
KernelTmaWarpSpecializedCooperativeMixedInput>,
|
||||
cute::is_same_v<Schedule, KernelTmaWarpSpecializedCooperative>,
|
||||
"KernelSchedule must be one of the warp specialized policies");
|
||||
|
||||
public:
|
||||
@ -113,8 +111,7 @@ struct MacheteCollectiveMma {
|
||||
// For coop schedules we have two warp groups cooperatively issuing wgmma
|
||||
// instructions so we use 2 atoms along the M dim (one for each warpgroup)
|
||||
using AtomLayoutMNK = cute::conditional_t<
|
||||
cute::is_same_v<KernelScheduleType,
|
||||
KernelTmaWarpSpecializedCooperativeMixedInput>,
|
||||
cute::is_same_v<KernelScheduleType, KernelTmaWarpSpecializedCooperative>,
|
||||
Layout<Shape<_2, _1, _1>>, Layout<Shape<_1, _1, _1>>>;
|
||||
|
||||
using TiledMma = decltype(cute::make_tiled_mma(
|
||||
|
@ -98,8 +98,7 @@ struct PrepackedLayoutBTemplate {
|
||||
// For coop schedules we have two warp groups cooperatively issuing wgmma
|
||||
// instructions so we use 2 atoms along the M dim (one for each warpgroup)
|
||||
using AtomLayoutMNK = cute::conditional_t<
|
||||
cute::is_same_v<KernelSchedule,
|
||||
KernelTmaWarpSpecializedCooperativeMixedInput>,
|
||||
cute::is_same_v<KernelSchedule, KernelTmaWarpSpecializedCooperative>,
|
||||
Layout<Shape<_2, _1, _1>>, Layout<Shape<_1, _1, _1>>>;
|
||||
|
||||
using TiledMma = decltype(cute::make_tiled_mma(
|
||||
@ -247,4 +246,4 @@ struct PrepackedLayoutBTemplate {
|
||||
}
|
||||
};
|
||||
|
||||
}; // namespace machete
|
||||
}; // namespace machete
|
||||
|
Loading…
x
Reference in New Issue
Block a user