[Kernel][Misc] dynamo support for ScalarType (#7594)
This commit is contained in:
parent
9f69856356
commit
7759ae958f
@ -313,6 +313,8 @@ class ScalarType {
|
||||
// have ScalarType inherit from torch::CustomClassHolder and have a constexpr
|
||||
// constructor at the same time (torch::CustomClassHolder does not have a
|
||||
// constexpr destructor)
|
||||
// See also:
|
||||
// https://docs.google.com/document/d/18fBMPuOJ0fY5ZQ6YyrHUppw9FA332CpNtgB6SOIgyuA
|
||||
class ScalarTypeTorch : public torch::CustomClassHolder, public ScalarType {
|
||||
public:
|
||||
ScalarTypeTorch(int64_t exponent, int64_t mantissa, int64_t bias,
|
||||
@ -382,6 +384,29 @@ class ScalarTypeTorch : public torch::CustomClassHolder, public ScalarType {
|
||||
exponent, mantissa, finite_values_only, NanRepr(nan_repr)));
|
||||
}
|
||||
|
||||
// This needs to be implemented and throw a TypeError in order for
|
||||
// PyTorch's opcheck to work on ops that use ScalarTypes.
|
||||
int64_t len() const {
|
||||
throw c10::TypeError("__len__ not implemented");
|
||||
return 0;
|
||||
}
|
||||
|
||||
// Serialize a ScalarType into a tuple of pairs. Where each pair
|
||||
// is a (fieldname, value).
|
||||
// For simplicity, we are just going to convert to a ScalarTypeId.
|
||||
std::tuple<std::tuple<std::string, int64_t>> obj_flatten() const {
|
||||
return {{"ScalarType", id()}};
|
||||
}
|
||||
|
||||
// Deserialize a scalar type that has been serialized by obj_flatten,
|
||||
// ostensibly from a tuple of (member name, value) pairs, but in reality
|
||||
// just a ScalarTypeId.
|
||||
static SelfPtr obj_unflatten(
|
||||
std::tuple<std::tuple<std::string, int64_t>> const& flat_type) {
|
||||
return c10::make_intrusive<Self>(
|
||||
from_id(std::get<1>(std::get<0>(flat_type))));
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
static void bind_readonly_property(torch::class_<Self>& cls,
|
||||
std::string const& name, T Base::*field) {
|
||||
@ -457,6 +482,7 @@ class ScalarTypeTorch : public torch::CustomClassHolder, public ScalarType {
|
||||
self.get()->min());
|
||||
});
|
||||
|
||||
bind_function(cls, "__len__", &ScalarTypeTorch::len);
|
||||
bind_function(cls, "__str__", &Base::str);
|
||||
bind_function(cls, "__eq__", [](SelfPtr const& self, SelfPtr const& other) {
|
||||
return *self == *other;
|
||||
@ -465,6 +491,10 @@ class ScalarTypeTorch : public torch::CustomClassHolder, public ScalarType {
|
||||
return "ScalarType." + self.get()->str();
|
||||
});
|
||||
|
||||
bind_function(cls, "__obj_flatten__", &ScalarTypeTorch::obj_flatten);
|
||||
bind_static_function(cls, "__obj_unflatten__",
|
||||
&ScalarTypeTorch::obj_unflatten);
|
||||
|
||||
// Bind static functions (convenience constructors)
|
||||
bind_static_function(cls, "int_", &ScalarTypeTorch::int_);
|
||||
bind_static_function(cls, "uint", &ScalarTypeTorch::uint);
|
||||
|
@ -1,6 +1,6 @@
|
||||
import importlib.util
|
||||
from enum import Enum
|
||||
from typing import TYPE_CHECKING, Optional, Union
|
||||
from typing import TYPE_CHECKING, Any, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
|
||||
@ -103,23 +103,23 @@ if TYPE_CHECKING or not core_C_available:
|
||||
"""
|
||||
...
|
||||
|
||||
def is_floating_point(self):
|
||||
def is_floating_point(self) -> bool:
|
||||
"If the type is a floating point type"
|
||||
return self.exponent != 0
|
||||
|
||||
def is_integer(self):
|
||||
def is_integer(self) -> bool:
|
||||
"If the type is an integer type"
|
||||
return self.exponent == 0
|
||||
|
||||
def has_bias(self):
|
||||
def has_bias(self) -> bool:
|
||||
"If the type has a non-zero bias"
|
||||
return self.bias != 0
|
||||
|
||||
def has_infs(self):
|
||||
def has_infs(self) -> bool:
|
||||
"If the type is floating point and supports infinity"
|
||||
return not self._finite_values_only
|
||||
|
||||
def has_nans(self):
|
||||
def has_nans(self) -> bool:
|
||||
return self.nan_repr != NanRepr.NONE.value
|
||||
|
||||
def is_ieee_754(self) -> bool:
|
||||
@ -136,6 +136,11 @@ if TYPE_CHECKING or not core_C_available:
|
||||
def __repr__(self) -> str:
|
||||
raise NotImplementedError
|
||||
|
||||
# __len__ needs to be defined (and has to throw TypeError) for pytorch's
|
||||
# opcheck to work.
|
||||
def __len__(self) -> int:
|
||||
raise TypeError
|
||||
|
||||
#
|
||||
# Convenience Constructors
|
||||
#
|
||||
@ -160,7 +165,7 @@ if TYPE_CHECKING or not core_C_available:
|
||||
|
||||
@classmethod
|
||||
def float_(cls, exponent: int, mantissa: int, finite_values_only: bool,
|
||||
nan_repr: int):
|
||||
nan_repr: int) -> 'ScalarType':
|
||||
"""
|
||||
Create a non-standard floating point type
|
||||
(i.e. does not follow IEEE 754 conventions).
|
||||
@ -175,3 +180,93 @@ elif core_C_available:
|
||||
logger.warning("Failed to import from vllm._core_C with %r", e)
|
||||
|
||||
ScalarType = torch.classes._core_C.ScalarType
|
||||
|
||||
# Needed for dynamo support of ScalarType.
|
||||
@torch._library.register_fake_class("_core_C::ScalarType")
|
||||
class FakeScalarType:
|
||||
|
||||
def __init__(self, scalar_type):
|
||||
self.ScalarType = scalar_type
|
||||
|
||||
def bias_getter(self) -> int:
|
||||
return self.ScalarType.bias
|
||||
|
||||
def exponent_getter(self) -> int:
|
||||
return self.ScalarType.exponent
|
||||
|
||||
def mantissa_getter(self) -> int:
|
||||
return self.ScalarType.mantissa
|
||||
|
||||
def signed_getter(self) -> bool:
|
||||
return self.ScalarType.signed
|
||||
|
||||
def size_bits_getter(self) -> int:
|
||||
return self.ScalarType.size_bits
|
||||
|
||||
@property
|
||||
def size_bits(self) -> int:
|
||||
return self.ScalarType.size_bits
|
||||
|
||||
def min(self) -> Union[int, float]:
|
||||
return self.ScalarType.min()
|
||||
|
||||
def max(self) -> Union[int, float]:
|
||||
return self.ScalarType.max()
|
||||
|
||||
def is_signed(self) -> bool:
|
||||
return self.ScalarType.is_signed()
|
||||
|
||||
def is_floating_point(self) -> bool:
|
||||
return self.ScalarType.is_floating_point()
|
||||
|
||||
def is_integer(self) -> bool:
|
||||
return self.ScalarType.is_integer()
|
||||
|
||||
def has_bias(self) -> bool:
|
||||
return self.ScalarType.has_bias()
|
||||
|
||||
def has_infs(self) -> bool:
|
||||
return self.ScalarType.has_infs()
|
||||
|
||||
def has_nans(self) -> bool:
|
||||
return self.ScalarType.has_nans()
|
||||
|
||||
def is_ieee_754(self) -> bool:
|
||||
return self.ScalarType.is_ieee_754()
|
||||
|
||||
def __str__(self) -> str:
|
||||
return self.ScalarType.__str__()
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return self.ScalarType.__repr__()
|
||||
|
||||
def __len__(self) -> int:
|
||||
return self.ScalarType.__len__()
|
||||
|
||||
def __obj_flatten__(self) -> Tuple[Tuple[str, Any], ...]:
|
||||
return torch.classes._core_C.ScalarType.__obj_flatten__(
|
||||
self.ScalarType)
|
||||
|
||||
@classmethod
|
||||
def __obj_unflatten__(
|
||||
cls, flat_type: Tuple[Tuple[str, Any], ...]) -> 'ScalarType':
|
||||
return cls(
|
||||
torch.classes._core_C.ScalarType.__obj_unflatten__(flat_type))
|
||||
|
||||
@classmethod
|
||||
def int_(cls, size_bits: int, bias: Optional[int]) -> 'ScalarType':
|
||||
return ScalarType.int_(size_bits, bias)
|
||||
|
||||
@classmethod
|
||||
def uint(cls, size_bits: int, bias: Optional[int]) -> 'ScalarType':
|
||||
return ScalarType.uint(size_bits, bias)
|
||||
|
||||
@classmethod
|
||||
def float_IEEE754(cls, exponent: int, mantissa: int) -> 'ScalarType':
|
||||
return ScalarType.float_IEEE754(exponent, mantissa)
|
||||
|
||||
@classmethod
|
||||
def float_(cls, exponent: int, mantissa: int, finite_values_only: bool,
|
||||
nan_repr: int) -> 'ScalarType':
|
||||
return ScalarType.float_(exponent, mantissa, finite_values_only,
|
||||
nan_repr)
|
||||
|
Loading…
x
Reference in New Issue
Block a user