2024-08-20 09:09:33 -04:00
|
|
|
#include "machete_mm_launcher.cuh"
|
|
|
|
#include "machete_prepack_launcher.cuh"
|
|
|
|
#include "core/scalar_type.hpp"
|
|
|
|
|
2024-10-03 22:55:25 -04:00
|
|
|
#include "core/registration.h"
|
|
|
|
|
2024-08-20 09:09:33 -04:00
|
|
|
namespace machete {
|
|
|
|
|
|
|
|
using namespace vllm;
|
|
|
|
|
2024-11-18 14:59:29 -05:00
|
|
|
std::vector<std::string> supported_schedules(
|
|
|
|
at::ScalarType a_type, int64_t b_type_id,
|
|
|
|
c10::optional<at::ScalarType> maybe_group_scales_type,
|
|
|
|
c10::optional<at::ScalarType> maybe_group_zeros_type,
|
|
|
|
c10::optional<at::ScalarType> maybe_channel_scales_type,
|
|
|
|
c10::optional<at::ScalarType> maybe_token_scales_type,
|
|
|
|
c10::optional<at::ScalarType> maybe_out_type) {
|
|
|
|
ScalarType const b_type = ScalarType::from_id(b_type_id);
|
|
|
|
return supported_schedules_dispatch({
|
|
|
|
.a_type = a_type,
|
|
|
|
.b_type = b_type,
|
|
|
|
.maybe_group_scales_type = maybe_group_scales_type,
|
|
|
|
.maybe_group_zeros_type = maybe_group_zeros_type,
|
|
|
|
.maybe_channel_scales_type = maybe_channel_scales_type,
|
|
|
|
.maybe_token_scales_type = maybe_token_scales_type,
|
|
|
|
.maybe_out_type = maybe_out_type,
|
2024-08-20 09:09:33 -04:00
|
|
|
});
|
|
|
|
}
|
|
|
|
|
2024-11-18 14:59:29 -05:00
|
|
|
torch::Tensor mm(torch::Tensor const& A, torch::Tensor const& B,
|
|
|
|
int64_t b_type_id,
|
|
|
|
c10::optional<at::ScalarType> const& maybe_out_type,
|
|
|
|
c10::optional<torch::Tensor> const& maybe_group_scales,
|
|
|
|
c10::optional<torch::Tensor> const& maybe_group_zeros,
|
|
|
|
c10::optional<int64_t> maybe_group_size,
|
|
|
|
c10::optional<torch::Tensor> const& maybe_channel_scales,
|
|
|
|
c10::optional<torch::Tensor> const& maybe_token_scales,
|
|
|
|
c10::optional<std::string> maybe_schedule) {
|
|
|
|
ScalarType const b_type = ScalarType::from_id(b_type_id);
|
|
|
|
return mm_dispatch({.A = A,
|
|
|
|
.B = B,
|
|
|
|
.b_type = b_type,
|
|
|
|
.maybe_out_type = maybe_out_type,
|
|
|
|
.maybe_group_scales = maybe_group_scales,
|
|
|
|
.maybe_group_zeros = maybe_group_zeros,
|
|
|
|
.maybe_group_size = maybe_group_size,
|
|
|
|
.maybe_channel_scales = maybe_channel_scales,
|
|
|
|
.maybe_token_scales = maybe_token_scales,
|
|
|
|
.maybe_schedule = maybe_schedule});
|
2024-08-20 09:09:33 -04:00
|
|
|
}
|
|
|
|
|
2024-11-18 14:59:29 -05:00
|
|
|
torch::Tensor prepack_B(
|
|
|
|
torch::Tensor const& B, at::ScalarType const& a_type, int64_t b_type_id,
|
|
|
|
c10::optional<at::ScalarType> const& maybe_group_scales_type) {
|
|
|
|
ScalarType const b_type = ScalarType::from_id(b_type_id);
|
|
|
|
return prepack_B_dispatch(
|
|
|
|
{.B = B,
|
|
|
|
.a_type = a_type,
|
|
|
|
.b_type = b_type,
|
|
|
|
.maybe_group_scales_type = maybe_group_scales_type});
|
2024-10-03 22:55:25 -04:00
|
|
|
}
|
|
|
|
|
|
|
|
TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) {
|
|
|
|
m.impl("machete_prepack_B", &prepack_B);
|
2024-11-18 14:59:29 -05:00
|
|
|
m.impl("machete_mm", &mm);
|
2024-10-10 13:39:56 -04:00
|
|
|
}
|
|
|
|
|
|
|
|
// use CatchAll since supported_schedules has no tensor arguments
|
|
|
|
TORCH_LIBRARY_IMPL(TORCH_EXTENSION_NAME, CatchAll, m) {
|
2024-10-03 22:55:25 -04:00
|
|
|
m.impl("machete_supported_schedules", &supported_schedules);
|
2024-08-20 09:09:33 -04:00
|
|
|
}
|
|
|
|
|
|
|
|
}; // namespace machete
|