[Build][Kernel] Update CUTLASS to v3.6.0 (#11607)
Signed-off-by: Tyler Michael Smith <tyler@neuralmagic.com>
commit 970d6d0776 (parent 628ec6c17b)

@@ -223,13 +223,13 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
   FetchContent_Declare(
     cutlass
     GIT_REPOSITORY https://github.com/nvidia/cutlass.git
-    GIT_TAG 8aa95dbb888be6d81c6fbf7169718c5244b53227
+    GIT_TAG v3.6.0
     GIT_PROGRESS TRUE

     # Speed up CUTLASS download by retrieving only the specified GIT_TAG instead of the history.
     # Important: If GIT_SHALLOW is enabled then GIT_TAG works only with branch names and tags.
     # So if the GIT_TAG above is updated to a commit hash, GIT_SHALLOW must be set to FALSE
-    GIT_SHALLOW FALSE
+    GIT_SHALLOW TRUE
   )
 endif()
 FetchContent_MakeAvailable(cutlass)

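Note: the hunk above can flip GIT_SHALLOW from FALSE to TRUE because v3.6.0 is a tag; as the in-file comment says, a shallow fetch only works when GIT_TAG names a branch or tag, not a bare commit hash. A minimal sketch of the equivalent git operations (the clone paths are illustrative only, not part of the build):

import subprocess

# Tag pin (new state): a depth-1 fetch suffices, i.e. GIT_SHALLOW TRUE.
subprocess.run([
    "git", "clone", "--depth", "1", "--branch", "v3.6.0",
    "https://github.com/nvidia/cutlass.git", "cutlass",
], check=True)

# Commit-hash pin (old state): --branch cannot take a hash, so the full
# history must be fetched first, i.e. GIT_SHALLOW FALSE.
subprocess.run([
    "git", "clone", "https://github.com/nvidia/cutlass.git", "cutlass-full",
], check=True)
subprocess.run([
    "git", "-C", "cutlass-full", "checkout",
    "8aa95dbb888be6d81c6fbf7169718c5244b53227",
], check=True)
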
@@ -14,9 +14,9 @@ class VLLMDataType(enum.Enum):


 class MixedInputKernelScheduleType(enum.Enum):
-    TmaWarpSpecializedMixedInput = enum_auto()
-    TmaWarpSpecializedPingpongMixedInput = enum_auto()
-    TmaWarpSpecializedCooperativeMixedInput = enum_auto()
+    TmaWarpSpecialized = enum_auto()
+    TmaWarpSpecializedPingpong = enum_auto()
+    TmaWarpSpecializedCooperative = enum_auto()


 VLLMDataTypeNames: Dict[Union[VLLMDataType, DataType], str] = {

@@ -68,11 +68,11 @@ VLLMKernelScheduleTag: Dict[Union[
     MixedInputKernelScheduleType, KernelScheduleType], str] = {
         **KernelScheduleTag,  # type: ignore
         **{
-            MixedInputKernelScheduleType.TmaWarpSpecializedMixedInput:
-            "cutlass::gemm::KernelTmaWarpSpecializedMixedInput",
-            MixedInputKernelScheduleType.TmaWarpSpecializedPingpongMixedInput:
-            "cutlass::gemm::KernelTmaWarpSpecializedPingpongMixedInput",
-            MixedInputKernelScheduleType.TmaWarpSpecializedCooperativeMixedInput:
-            "cutlass::gemm::KernelTmaWarpSpecializedCooperativeMixedInput",
+            MixedInputKernelScheduleType.TmaWarpSpecialized:
+            "cutlass::gemm::KernelTmaWarpSpecialized",
+            MixedInputKernelScheduleType.TmaWarpSpecializedPingpong:
+            "cutlass::gemm::KernelTmaWarpSpecializedPingpong",
+            MixedInputKernelScheduleType.TmaWarpSpecializedCooperative:
+            "cutlass::gemm::KernelTmaWarpSpecializedCooperative",
         }
 }

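After this rename the enum member names line up one-to-one with CUTLASS's own kernel schedule names, so the tag mapping above becomes purely mechanical. A self-contained sketch of that correspondence (SCHEDULE_TAG is a condensed stand-in for VLLMKernelScheduleTag, not the real table):

import enum
from enum import auto as enum_auto

class MixedInputKernelScheduleType(enum.Enum):
    TmaWarpSpecialized = enum_auto()
    TmaWarpSpecializedPingpong = enum_auto()
    TmaWarpSpecializedCooperative = enum_auto()

# Condensed stand-in: member name -> C++ schedule tag.
SCHEDULE_TAG = {m: f"cutlass::gemm::Kernel{m.name}"
                for m in MixedInputKernelScheduleType}

assert (SCHEDULE_TAG[MixedInputKernelScheduleType.TmaWarpSpecializedCooperative]
        == "cutlass::gemm::KernelTmaWarpSpecializedCooperative")
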
@@ -189,7 +189,7 @@ using Kernel_{{type_sig}} = MacheteKernelTemplate<
     {{DataTypeTag[t.b_group_zeropoint]}},  // GroupZeroT
     {{DataTypeTag[t.b_channel_scale]}},  // ChannelScaleT
     {{DataTypeTag[t.a_token_scale]}},  // TokenScaleT
-    cutlass::gemm::KernelTmaWarpSpecializedCooperativeMixedInput,
+    cutlass::gemm::KernelTmaWarpSpecializedCooperative,
     Sch>;

 {% for sch in schs %}

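This hunk edits a Jinja-style template string inside the machete generator; placeholders such as {{DataTypeTag[t.a_token_scale]}} are substituted at generation time. A simplified, self-contained sketch of that expansion (the template text and rendered values are stand-ins, not the real generator data):

import jinja2

template = jinja2.Template("""
using Kernel_{{ type_sig }} = MacheteKernelTemplate<
    {{ accumulator_tag }},  // Accumulator
    cutlass::gemm::KernelTmaWarpSpecializedCooperative,
    Sch>;
""")

print(template.render(type_sig="f16_u4b8", accumulator_tag="cutlass::half_t"))
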
@@ -223,7 +223,7 @@ torch::Tensor prepack_B_dispatch(PrepackBArgs args) {
         {{DataTypeTag[t.convert]}},  // ElementConvert
         {{DataTypeTag[t.accumulator]}},  // Accumulator
         cutlass::layout::ColumnMajor,
-        cutlass::gemm::KernelTmaWarpSpecializedCooperativeMixedInput>
+        cutlass::gemm::KernelTmaWarpSpecializedCooperative>
     >(args.B);
 }
 {%- endfor %}

@@ -239,7 +239,7 @@ torch::Tensor prepack_B_dispatch(PrepackBArgs args) {
 };  // namespace machete
 """

-TmaMI = MixedInputKernelScheduleType.TmaWarpSpecializedCooperativeMixedInput
+TmaMI = MixedInputKernelScheduleType.TmaWarpSpecializedCooperative
 TmaCoop = EpilogueScheduleType.TmaWarpSpecializedCooperative


@@ -300,7 +300,7 @@ def generate_sch_sig(schedule_config: ScheduleConfig) -> str:
 # mostly unique shorter sch_sig
 def generate_terse_sch_sig(schedule_config: ScheduleConfig) -> str:
     kernel_terse_names_replace = {
-        "KernelTmaWarpSpecializedCooperativeMixedInput_": "TmaMI_",
+        "KernelTmaWarpSpecializedCooperative": "TmaMI_",
         "TmaWarpSpecializedCooperative_": "TmaCoop_",
         "StreamKScheduler": "streamK",
     }

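For reference, the table above is applied as plain substring replacement when shortening a schedule signature; the updated key drops the trailing underscore, so it matches the mainloop schedule name wherever it appears. A self-contained sketch (the driver loop and sample signature are plausible reconstructions, not the exact generator code):

kernel_terse_names_replace = {
    "KernelTmaWarpSpecializedCooperative": "TmaMI_",
    "TmaWarpSpecializedCooperative_": "TmaCoop_",
    "StreamKScheduler": "streamK",
}

def tersify(sch_sig: str) -> str:
    # Plain, order-dependent substring replacement.
    for long_name, short_name in kernel_terse_names_replace.items():
        sch_sig = sch_sig.replace(long_name, short_name)
    return sch_sig

print(tersify("128x128_1x1x1_KernelTmaWarpSpecializedCooperative"
              "_TmaWarpSpecializedCooperative_StreamKScheduler"))
# -> 128x128_1x1x1_TmaMI__TmaCoop_streamK
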
@@ -18,12 +18,10 @@ struct VLLMCollectiveBuilder<
     ElementAccumulator, TileShape_MNK, ClusterShape_MNK, StageCountType,
     KernelScheduleType,
     cute::enable_if_t<(
+        cute::is_same_v<KernelScheduleType, KernelTmaWarpSpecialized> ||
+        cute::is_same_v<KernelScheduleType, KernelTmaWarpSpecializedPingpong> ||
         cute::is_same_v<KernelScheduleType,
-                        KernelTmaWarpSpecializedMixedInput> ||
-        cute::is_same_v<KernelScheduleType,
-                        KernelTmaWarpSpecializedPingpongMixedInput> ||
-        cute::is_same_v<KernelScheduleType,
-                        KernelTmaWarpSpecializedCooperativeMixedInput>)>> {
+                        KernelTmaWarpSpecializedCooperative>)>> {
   using CollectiveOp = machete::MacheteCollectiveMma<
       ElementPairA_, GmemLayoutA_, AlignmentA, ElementPairB_, GmemLayoutB_,
       AlignmentB, ElementAccumulator, TileShape_MNK, ClusterShape_MNK,

@@ -66,13 +66,11 @@ struct MacheteCollectiveMma {
   using Schedule = KernelScheduleType;
   static_assert(
       cute::is_same_v<Schedule, KernelTmaWarpSpecialized> ||
-      cute::is_same_v<Schedule, KernelTmaWarpSpecializedMixedInput> ||
+      cute::is_same_v<Schedule, KernelTmaWarpSpecialized> ||
+      cute::is_same_v<Schedule, KernelTmaWarpSpecializedPingpong> ||
       cute::is_same_v<Schedule, KernelTmaWarpSpecializedPingpong> ||
-      cute::is_same_v<Schedule,
-                      KernelTmaWarpSpecializedPingpongMixedInput> ||
       cute::is_same_v<Schedule, KernelTmaWarpSpecializedCooperative> ||
-      cute::is_same_v<Schedule,
-                      KernelTmaWarpSpecializedCooperativeMixedInput>,
+      cute::is_same_v<Schedule, KernelTmaWarpSpecializedCooperative>,
       "KernelSchedule must be one of the warp specialized policies");

  public:

@@ -113,8 +111,7 @@ struct MacheteCollectiveMma {
   // For coop schedules we have two warp groups cooperatively issuing wgmma
   // instructions so we use 2 atoms along the M dim (one for each warpgroup)
   using AtomLayoutMNK = cute::conditional_t<
-      cute::is_same_v<KernelScheduleType,
-                      KernelTmaWarpSpecializedCooperativeMixedInput>,
+      cute::is_same_v<KernelScheduleType, KernelTmaWarpSpecializedCooperative>,
       Layout<Shape<_2, _1, _1>>, Layout<Shape<_1, _1, _1>>>;

   using TiledMma = decltype(cute::make_tiled_mma(

@@ -98,8 +98,7 @@ struct PrepackedLayoutBTemplate {
   // For coop schedules we have two warp groups cooperatively issuing wgmma
   // instructions so we use 2 atoms along the M dim (one for each warpgroup)
   using AtomLayoutMNK = cute::conditional_t<
-      cute::is_same_v<KernelSchedule,
-                      KernelTmaWarpSpecializedCooperativeMixedInput>,
+      cute::is_same_v<KernelSchedule, KernelTmaWarpSpecializedCooperative>,
       Layout<Shape<_2, _1, _1>>, Layout<Shape<_1, _1, _1>>>;

   using TiledMma = decltype(cute::make_tiled_mma(