[Kernel][Misc] dynamo support for ScalarType (#7594)

bnellnm 2024-08-16 16:59:49 -04:00 committed by GitHub
parent 9f69856356
commit 7759ae958f
2 changed files with 149 additions and 24 deletions
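
At a high level, the commit teaches torch.compile (dynamo) to trace through custom ops that take a ScalarType argument: the C++ torchbind class gains __obj_flatten__ / __obj_unflatten__ plus a TypeError-raising __len__, and the Python side registers a FakeScalarType via torch._library.register_fake_class. A minimal sketch of the intended end state, assuming a hypothetical custom op my_quant_op registered under _core_C and the module path vllm.scalar_type:

import torch
from vllm.scalar_type import ScalarType  # assumed module path

# GPTQ-style unsigned 4-bit weight type with a bias of 8.
w_type = ScalarType.uint(4, 8)

def dequant(x: torch.Tensor) -> torch.Tensor:
    # my_quant_op is a stand-in for any custom op taking a ScalarType;
    # it is not defined in this commit.
    return torch.ops._core_C.my_quant_op(x, w_type)

# With the fake-class registration added below, dynamo can trace dequant
# instead of graph-breaking on the torchbind ScalarType object.
compiled = torch.compile(dequant, fullgraph=True)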

csrc/core/scalar_type.hpp

@@ -313,6 +313,8 @@ class ScalarType {
// have ScalarType inherit from torch::CustomClassHolder and have a constexpr
// constructor at the same time (torch::CustomClassHolder does not have a
// constexpr destructor)
// See also:
// https://docs.google.com/document/d/18fBMPuOJ0fY5ZQ6YyrHUppw9FA332CpNtgB6SOIgyuA
class ScalarTypeTorch : public torch::CustomClassHolder, public ScalarType {
public:
ScalarTypeTorch(int64_t exponent, int64_t mantissa, int64_t bias,
@@ -382,6 +384,29 @@ class ScalarTypeTorch : public torch::CustomClassHolder, public ScalarType {
exponent, mantissa, finite_values_only, NanRepr(nan_repr)));
}
// This needs to be implemented and throw a TypeError in order for
// PyTorch's opcheck to work on ops that use ScalarTypes.
int64_t len() const {
throw c10::TypeError("__len__ not implemented");
return 0;
}
// Serialize a ScalarType into a tuple of pairs, where each pair
// is a (fieldname, value).
// For simplicity, we are just going to convert to a ScalarTypeId.
std::tuple<std::tuple<std::string, int64_t>> obj_flatten() const {
return {{"ScalarType", id()}};
}
// Deserialize a scalar type that has been serialized by obj_flatten,
// ostensibly from a tuple of (member name, value) pairs, but in reality
// just a ScalarTypeId.
static SelfPtr obj_unflatten(
std::tuple<std::tuple<std::string, int64_t>> const& flat_type) {
return c10::make_intrusive<Self>(
from_id(std::get<1>(std::get<0>(flat_type))));
}
template <typename T>
static void bind_readonly_property(torch::class_<Self>& cls,
std::string const& name, T Base::*field) {
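
The __obj_flatten__ / __obj_unflatten__ pair above is how PyTorch turns the opaque torchbind object into plain data and back during tracing. From the Python side the round trip looks roughly like this (a sketch; uint(4, 8) is just an example type, and it assumes vllm._core_C has already been loaded so the class is registered):

import torch

st = torch.classes._core_C.ScalarType.uint(4, 8)
flat = st.__obj_flatten__()   # (("ScalarType", <ScalarTypeId>),)
st2 = torch.classes._core_C.ScalarType.__obj_unflatten__(flat)
assert str(st2) == str(st)    # same type after the round trip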
@@ -457,6 +482,7 @@ class ScalarTypeTorch : public torch::CustomClassHolder, public ScalarType {
self.get()->min());
});
bind_function(cls, "__len__", &ScalarTypeTorch::len);
bind_function(cls, "__str__", &Base::str);
bind_function(cls, "__eq__", [](SelfPtr const& self, SelfPtr const& other) {
return *self == *other;
@@ -465,6 +491,10 @@ class ScalarTypeTorch : public torch::CustomClassHolder, public ScalarType {
return "ScalarType." + self.get()->str();
});
bind_function(cls, "__obj_flatten__", &ScalarTypeTorch::obj_flatten);
bind_static_function(cls, "__obj_unflatten__",
&ScalarTypeTorch::obj_unflatten);
// Bind static functions (convenience constructors)
bind_static_function(cls, "int_", &ScalarTypeTorch::int_);
bind_static_function(cls, "uint", &ScalarTypeTorch::uint);
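
The __len__ binding above exists only because PyTorch's opcheck utility probes custom-class arguments; as the comment in the hunk notes, it must throw a TypeError. A hedged sketch of running opcheck against an op that takes a ScalarType (the op name is hypothetical):

import torch
from torch.library import opcheck

w_type = torch.classes._core_C.ScalarType.uint(4, 8)
x = torch.randn(16, 16)

# Per the comment above, ScalarType.__len__ must exist and raise TypeError
# for this to run cleanly on ops whose arguments include a ScalarType.
opcheck(torch.ops._core_C.my_quant_op.default, (x, w_type))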

vllm/scalar_type.py

@@ -1,6 +1,6 @@
import importlib.util
from enum import Enum
from typing import TYPE_CHECKING, Optional, Union
from typing import TYPE_CHECKING, Any, Optional, Tuple, Union
import torch
@@ -31,14 +31,14 @@ if TYPE_CHECKING or not core_C_available:
@dataclass(frozen=True)
class ScalarType:
"""
ScalarType can represent a wide range of floating point and integer
types; in particular, it can be used to represent sub-byte data types
(something that torch.dtype currently does not support). It is also
capable of representing types with a bias, i.e.
`stored_value = value + bias`;
this is useful for quantized types (e.g. standard GPTQ 4-bit uses a bias
of 8). The implementation for this class can be found in
csrc/core/scalar_type.hpp; these type signatures should be kept in sync
with that file.
"""
@@ -51,15 +51,15 @@ if TYPE_CHECKING or not core_C_available:
mantissa: int
"""
Number of bits in the mantissa if this is a floating point type,
or the number of bits representing an integer (excluding the sign bit)
if this is an integer type.
"""
bias: int
"""
bias used to encode the values in this scalar type
(value = stored_value - bias, default 0). For example, if we store the
type as an unsigned integer with a bias of 128, then the value 0 will be
stored as 128, -1 will be stored as 127, and 1 will be stored as 129.
"""
@@ -73,7 +73,7 @@ if TYPE_CHECKING or not core_C_available:
nan_repr: int = NanRepr.IEEE_754.value
"""
How NaNs are represented in this scalar type, as a NanRepr value.
(not applicable for integer types)
"""
@@ -83,14 +83,14 @@ if TYPE_CHECKING or not core_C_available:
def min(self) -> Union[int, float]:
"""
Min representable value for this scalar type.
(accounting for bias if there is one)
"""
raise NotImplementedError
def max(self) -> Union[int, float]:
"""
Max representable value for this scalar type.
(accounting for bias if there is one)
"""
raise NotImplementedError
@@ -103,28 +103,28 @@ if TYPE_CHECKING or not core_C_available:
"""
...
def is_floating_point(self):
def is_floating_point(self) -> bool:
"If the type is a floating point type"
return self.exponent != 0
def is_integer(self):
def is_integer(self) -> bool:
"If the type is an integer type"
return self.exponent == 0
def has_bias(self):
def has_bias(self) -> bool:
"If the type has a non-zero bias"
return self.bias != 0
def has_infs(self):
def has_infs(self) -> bool:
"If the type is floating point and supports infinity"
return not self._finite_values_only
def has_nans(self):
def has_nans(self) -> bool:
return self.nan_repr != NanRepr.NONE.value
def is_ieee_754(self) -> bool:
"""
If the type is a floating point type that follows IEEE 754
conventions
"""
return self.nan_repr == NanRepr.IEEE_754.value and \
    not self._finite_values_only
@@ -136,6 +136,11 @@ if TYPE_CHECKING or not core_C_available:
def __repr__(self) -> str:
raise NotImplementedError
# __len__ needs to be defined (and has to throw TypeError) for pytorch's
# opcheck to work.
def __len__(self) -> int:
raise TypeError
#
# Convenience Constructors
#
@@ -153,16 +158,16 @@ if TYPE_CHECKING or not core_C_available:
@classmethod
def float_IEEE754(cls, exponent: int, mantissa: int) -> 'ScalarType':
"""
Create a standard floating point type
(i.e. follows IEEE 754 conventions).
"""
return cls(exponent, mantissa, 0, True)
@classmethod
def float_(cls, exponent: int, mantissa: int, finite_values_only: bool,
nan_repr: int):
nan_repr: int) -> 'ScalarType':
"""
Create a non-standard floating point type
(i.e. does not follow IEEE 754 conventions).
"""
return cls(exponent, mantissa, 0, True, finite_values_only,
           nan_repr)
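
For reference, the exponent/mantissa arguments map directly onto familiar formats; a short sketch (the IEEE 754 layouts below are standard, and the custom type is purely illustrative):

# float16: 1 sign bit, 5 exponent bits, 10 mantissa bits.
fp16 = ScalarType.float_IEEE754(5, 10)
# bfloat16: 1 sign bit, 8 exponent bits, 7 mantissa bits.
bf16 = ScalarType.float_IEEE754(8, 7)

# A non-standard, finite-only float with no NaN encoding goes through float_:
custom_f8 = ScalarType.float_(4, 3, True, NanRepr.NONE.value)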
@@ -175,3 +180,93 @@ elif core_C_available:
logger.warning("Failed to import from vllm._core_C with %r", e)
ScalarType = torch.classes._core_C.ScalarType
# Needed for dynamo support of ScalarType.
@torch._library.register_fake_class("_core_C::ScalarType")
class FakeScalarType:
def __init__(self, scalar_type):
self.ScalarType = scalar_type
def bias_getter(self) -> int:
return self.ScalarType.bias
def exponent_getter(self) -> int:
return self.ScalarType.exponent
def mantissa_getter(self) -> int:
return self.ScalarType.mantissa
def signed_getter(self) -> bool:
return self.ScalarType.signed
def size_bits_getter(self) -> int:
return self.ScalarType.size_bits
@property
def size_bits(self) -> int:
return self.ScalarType.size_bits
def min(self) -> Union[int, float]:
return self.ScalarType.min()
def max(self) -> Union[int, float]:
return self.ScalarType.max()
def is_signed(self) -> bool:
return self.ScalarType.is_signed()
def is_floating_point(self) -> bool:
return self.ScalarType.is_floating_point()
def is_integer(self) -> bool:
return self.ScalarType.is_integer()
def has_bias(self) -> bool:
return self.ScalarType.has_bias()
def has_infs(self) -> bool:
return self.ScalarType.has_infs()
def has_nans(self) -> bool:
return self.ScalarType.has_nans()
def is_ieee_754(self) -> bool:
return self.ScalarType.is_ieee_754()
def __str__(self) -> str:
return self.ScalarType.__str__()
def __repr__(self) -> str:
return self.ScalarType.__repr__()
def __len__(self) -> int:
return self.ScalarType.__len__()
def __obj_flatten__(self) -> Tuple[Tuple[str, Any], ...]:
return torch.classes._core_C.ScalarType.__obj_flatten__(
self.ScalarType)
@classmethod
def __obj_unflatten__(
cls, flat_type: Tuple[Tuple[str, Any], ...]) -> 'ScalarType':
return cls(
torch.classes._core_C.ScalarType.__obj_unflatten__(flat_type))
@classmethod
def int_(cls, size_bits: int, bias: Optional[int]) -> 'ScalarType':
return ScalarType.int_(size_bits, bias)
@classmethod
def uint(cls, size_bits: int, bias: Optional[int]) -> 'ScalarType':
return ScalarType.uint(size_bits, bias)
@classmethod
def float_IEEE754(cls, exponent: int, mantissa: int) -> 'ScalarType':
return ScalarType.float_IEEE754(exponent, mantissa)
@classmethod
def float_(cls, exponent: int, mantissa: int, finite_values_only: bool,
nan_repr: int) -> 'ScalarType':
return ScalarType.float_(exponent, mantissa, finite_values_only,
nan_repr)
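
FakeScalarType is what fake-tensor tracing sees in place of the real torchbind object: PyTorch flattens the real ScalarType with __obj_flatten__, rebuilds a FakeScalarType with __obj_unflatten__, and hands that to an op's fake (meta) implementation. A sketch of how such a fake impl could use it, with a hypothetical op name:

import torch

# Hypothetical fake/meta implementation for a custom op taking a ScalarType;
# during tracing `w_type` is a FakeScalarType, so the accessors defined above
# (size_bits, min(), max(), ...) all work without touching the real C++ object.
@torch.library.register_fake("_core_C::my_quant_op")
def _my_quant_op_fake(x: torch.Tensor, w_type) -> torch.Tensor:
    assert w_type.size_bits in (4, 8)
    return torch.empty_like(x)  # only shapes/dtypes matter here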