2023-09-02 14:59:47 +09:00
|
|
|
/*
 * Adapted from
 * https://github.com/pytorch/pytorch/blob/v2.0.1/aten/src/ATen/Dispatch.h
 */
|
2023-12-14 12:35:58 -05:00
|
|
|
#pragma once
|
|
|
|
|
2023-09-02 14:59:47 +09:00
|
|
|
#include <torch/extension.h>
|
|
|
|
|
2024-05-22 03:18:41 -04:00
|
|
|
// Case list for the floating-point dtypes Float (fp32), Half (fp16) and
// BFloat16 (bf16): one AT_DISPATCH_CASE per dtype, each forwarding the same
// callable (__VA_ARGS__). Intended to be placed inside AT_DISPATCH_SWITCH,
// as VLLM_DISPATCH_FLOATING_TYPES does. Per PyTorch's ATen/Dispatch.h,
// AT_DISPATCH_CASE aliases `scalar_t` to the matching C++ type before calling
// the callable (verify against the PyTorch version in use).
#define VLLM_DISPATCH_CASE_FLOATING_TYPES(...)                \
  AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__)        \
  AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__)         \
  AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__)
|
|
|
|
|
2024-05-22 03:18:41 -04:00
|
|
|
// Switches on SCALAR_TYPE over the dtypes listed in
// VLLM_DISPATCH_CASE_FLOATING_TYPES (Float, Half, BFloat16) and invokes the
// callable passed as __VA_ARGS__ for the matching case. OP_NAME is forwarded
// unchanged to AT_DISPATCH_SWITCH (in PyTorch it labels the error raised for an
// unsupported dtype; confirm for the PyTorch version in use).
//
// Example (sketch):
//   VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "my_op", [&] {
//     my_op_impl<scalar_t>(input);
//   });
#define VLLM_DISPATCH_FLOATING_TYPES(SCALAR_TYPE, OP_NAME, ...) \
  AT_DISPATCH_SWITCH(                                           \
      SCALAR_TYPE, OP_NAME,                                     \
      VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__))
|
2024-01-29 08:43:54 +08:00
|
|
|
|
2024-05-22 03:18:41 -04:00
|
|
|
// Case list for the floating-point dtypes plus Byte (uint8). It expands to
// every case from VLLM_DISPATCH_CASE_FLOATING_TYPES (Float, Half, BFloat16)
// followed by an at::ScalarType::Byte case, each forwarding the same callable
// (__VA_ARGS__). Reusing the floating-point list, rather than repeating it,
// keeps the two macros in sync if a floating-point dtype is ever added or
// removed. Intended to be placed inside AT_DISPATCH_SWITCH, as
// VLLM_DISPATCH_FLOATING_AND_BYTE_TYPES does.
// NOTE: depends on VLLM_DISPATCH_CASE_FLOATING_TYPES (defined in this header)
// being visible wherever this macro is expanded.
#define VLLM_DISPATCH_CASE_FLOATING_AND_BYTE_TYPES(...) \
  VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__)        \
  AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__)
|
|
|
|
|
2024-05-22 03:18:41 -04:00
|
|
|
// Switches on SCALAR_TYPE over the dtypes listed in
// VLLM_DISPATCH_CASE_FLOATING_AND_BYTE_TYPES (Float, Half, BFloat16, Byte) and
// invokes the callable passed as __VA_ARGS__ for the matching case. OP_NAME is
// forwarded unchanged to AT_DISPATCH_SWITCH (in PyTorch it labels the error
// raised for an unsupported dtype; confirm for the PyTorch version in use).
#define VLLM_DISPATCH_FLOATING_AND_BYTE_TYPES(SCALAR_TYPE, OP_NAME, ...) \
  AT_DISPATCH_SWITCH(                                                    \
      SCALAR_TYPE, OP_NAME,                                              \
      VLLM_DISPATCH_CASE_FLOATING_AND_BYTE_TYPES(__VA_ARGS__))
|
|
|
|
|
|
|
|
// Case list for the integer dtypes Byte (uint8), Char (int8), Short (int16),
// Int (int32) and Long (int64): one AT_DISPATCH_CASE per dtype, each
// forwarding the same callable (__VA_ARGS__). Bool is not part of this list.
// Intended to be placed inside AT_DISPATCH_SWITCH, as
// VLLM_DISPATCH_INTEGRAL_TYPES does.
#define VLLM_DISPATCH_CASE_INTEGRAL_TYPES(...)              \
  AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__)       \
  AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__)       \
  AT_DISPATCH_CASE(at::ScalarType::Short, __VA_ARGS__)      \
  AT_DISPATCH_CASE(at::ScalarType::Int, __VA_ARGS__)        \
  AT_DISPATCH_CASE(at::ScalarType::Long, __VA_ARGS__)
|
|
|
|
|
2024-05-22 03:18:41 -04:00
|
|
|
// Switches on SCALAR_TYPE over the dtypes listed in
// VLLM_DISPATCH_CASE_INTEGRAL_TYPES (Byte, Char, Short, Int, Long) and invokes
// the callable passed as __VA_ARGS__ for the matching case. OP_NAME is
// forwarded unchanged to AT_DISPATCH_SWITCH (in PyTorch it labels the error
// raised for an unsupported dtype; confirm for the PyTorch version in use).
#define VLLM_DISPATCH_INTEGRAL_TYPES(SCALAR_TYPE, OP_NAME, ...) \
  AT_DISPATCH_SWITCH(                                           \
      SCALAR_TYPE, OP_NAME,                                     \
      VLLM_DISPATCH_CASE_INTEGRAL_TYPES(__VA_ARGS__))
|