[Kernel][Misc] dynamo support for ScalarType (#7594)

This commit is contained in:
bnellnm 2024-08-16 16:59:49 -04:00 committed by GitHub
parent 9f69856356
commit 7759ae958f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 149 additions and 24 deletions

View File

@ -313,6 +313,8 @@ class ScalarType {
// have ScalarType inherit from torch::CustomClassHolder and have a constexpr
// constructor at the same time (torch::CustomClassHolder does not have a
// constexpr destructor)
// See also:
// https://docs.google.com/document/d/18fBMPuOJ0fY5ZQ6YyrHUppw9FA332CpNtgB6SOIgyuA
class ScalarTypeTorch : public torch::CustomClassHolder, public ScalarType {
public:
ScalarTypeTorch(int64_t exponent, int64_t mantissa, int64_t bias,
@ -382,6 +384,29 @@ class ScalarTypeTorch : public torch::CustomClassHolder, public ScalarType {
exponent, mantissa, finite_values_only, NanRepr(nan_repr)));
}
// This needs to be implemented and throw a TypeError in order for
// PyTorch's opcheck to work on ops that use ScalarTypes.
int64_t len() const {
throw c10::TypeError("__len__ not implemented");
return 0;
}
// Serialize a ScalarType into a tuple of pairs. Where each pair
// is a (fieldname, value).
// For simplicity, we are just going to convert to a ScalarTypeId.
std::tuple<std::tuple<std::string, int64_t>> obj_flatten() const {
return {{"ScalarType", id()}};
}
// Deserialize a scalar type that has been serialized by obj_flatten,
// ostensibly from a tuple of (member name, value) pairs, but in reality
// just a ScalarTypeId.
static SelfPtr obj_unflatten(
std::tuple<std::tuple<std::string, int64_t>> const& flat_type) {
return c10::make_intrusive<Self>(
from_id(std::get<1>(std::get<0>(flat_type))));
}
template <typename T>
static void bind_readonly_property(torch::class_<Self>& cls,
std::string const& name, T Base::*field) {
@ -457,6 +482,7 @@ class ScalarTypeTorch : public torch::CustomClassHolder, public ScalarType {
self.get()->min());
});
bind_function(cls, "__len__", &ScalarTypeTorch::len);
bind_function(cls, "__str__", &Base::str);
bind_function(cls, "__eq__", [](SelfPtr const& self, SelfPtr const& other) {
return *self == *other;
@ -465,6 +491,10 @@ class ScalarTypeTorch : public torch::CustomClassHolder, public ScalarType {
return "ScalarType." + self.get()->str();
});
bind_function(cls, "__obj_flatten__", &ScalarTypeTorch::obj_flatten);
bind_static_function(cls, "__obj_unflatten__",
&ScalarTypeTorch::obj_unflatten);
// Bind static functions (convenience constructors)
bind_static_function(cls, "int_", &ScalarTypeTorch::int_);
bind_static_function(cls, "uint", &ScalarTypeTorch::uint);

View File

@ -1,6 +1,6 @@
import importlib.util
from enum import Enum
from typing import TYPE_CHECKING, Optional, Union
from typing import TYPE_CHECKING, Any, Optional, Tuple, Union
import torch
@ -103,23 +103,23 @@ if TYPE_CHECKING or not core_C_available:
"""
...
def is_floating_point(self):
def is_floating_point(self) -> bool:
"If the type is a floating point type"
return self.exponent != 0
def is_integer(self):
def is_integer(self) -> bool:
"If the type is an integer type"
return self.exponent == 0
def has_bias(self):
def has_bias(self) -> bool:
"If the type has a non-zero bias"
return self.bias != 0
def has_infs(self):
def has_infs(self) -> bool:
"If the type is floating point and supports infinity"
return not self._finite_values_only
def has_nans(self):
def has_nans(self) -> bool:
return self.nan_repr != NanRepr.NONE.value
def is_ieee_754(self) -> bool:
@ -136,6 +136,11 @@ if TYPE_CHECKING or not core_C_available:
def __repr__(self) -> str:
raise NotImplementedError
# __len__ needs to be defined (and has to throw TypeError) for pytorch's
# opcheck to work.
def __len__(self) -> int:
raise TypeError
#
# Convenience Constructors
#
@ -160,7 +165,7 @@ if TYPE_CHECKING or not core_C_available:
@classmethod
def float_(cls, exponent: int, mantissa: int, finite_values_only: bool,
nan_repr: int):
nan_repr: int) -> 'ScalarType':
"""
Create a non-standard floating point type
(i.e. does not follow IEEE 754 conventions).
@ -175,3 +180,93 @@ elif core_C_available:
logger.warning("Failed to import from vllm._core_C with %r", e)
ScalarType = torch.classes._core_C.ScalarType
# Needed for dynamo support of ScalarType.
@torch._library.register_fake_class("_core_C::ScalarType")
class FakeScalarType:
def __init__(self, scalar_type):
self.ScalarType = scalar_type
def bias_getter(self) -> int:
return self.ScalarType.bias
def exponent_getter(self) -> int:
return self.ScalarType.exponent
def mantissa_getter(self) -> int:
return self.ScalarType.mantissa
def signed_getter(self) -> bool:
return self.ScalarType.signed
def size_bits_getter(self) -> int:
return self.ScalarType.size_bits
@property
def size_bits(self) -> int:
return self.ScalarType.size_bits
def min(self) -> Union[int, float]:
return self.ScalarType.min()
def max(self) -> Union[int, float]:
return self.ScalarType.max()
def is_signed(self) -> bool:
return self.ScalarType.is_signed()
def is_floating_point(self) -> bool:
return self.ScalarType.is_floating_point()
def is_integer(self) -> bool:
return self.ScalarType.is_integer()
def has_bias(self) -> bool:
return self.ScalarType.has_bias()
def has_infs(self) -> bool:
return self.ScalarType.has_infs()
def has_nans(self) -> bool:
return self.ScalarType.has_nans()
def is_ieee_754(self) -> bool:
return self.ScalarType.is_ieee_754()
def __str__(self) -> str:
return self.ScalarType.__str__()
def __repr__(self) -> str:
return self.ScalarType.__repr__()
def __len__(self) -> int:
return self.ScalarType.__len__()
def __obj_flatten__(self) -> Tuple[Tuple[str, Any], ...]:
return torch.classes._core_C.ScalarType.__obj_flatten__(
self.ScalarType)
@classmethod
def __obj_unflatten__(
cls, flat_type: Tuple[Tuple[str, Any], ...]) -> 'ScalarType':
return cls(
torch.classes._core_C.ScalarType.__obj_unflatten__(flat_type))
@classmethod
def int_(cls, size_bits: int, bias: Optional[int]) -> 'ScalarType':
return ScalarType.int_(size_bits, bias)
@classmethod
def uint(cls, size_bits: int, bias: Optional[int]) -> 'ScalarType':
return ScalarType.uint(size_bits, bias)
@classmethod
def float_IEEE754(cls, exponent: int, mantissa: int) -> 'ScalarType':
return ScalarType.float_IEEE754(exponent, mantissa)
@classmethod
def float_(cls, exponent: int, mantissa: int, finite_values_only: bool,
nan_repr: int) -> 'ScalarType':
return ScalarType.float_(exponent, mantissa, finite_values_only,
nan_repr)