[BugFix] Handle non-contiguous tensors properly when serializing (#16492)
Signed-off-by: Nick Hill <nhill@redhat.com>
This commit is contained in:
parent
57504a4bcf
commit
41cc883c29
@ -22,6 +22,10 @@ class MyType:
|
|||||||
list_of_tensors: list[torch.Tensor]
|
list_of_tensors: list[torch.Tensor]
|
||||||
numpy_array: np.ndarray
|
numpy_array: np.ndarray
|
||||||
unrecognized: UnrecognizedType
|
unrecognized: UnrecognizedType
|
||||||
|
small_f_contig_tensor: torch.Tensor
|
||||||
|
large_f_contig_tensor: torch.Tensor
|
||||||
|
small_non_contig_tensor: torch.Tensor
|
||||||
|
large_non_contig_tensor: torch.Tensor
|
||||||
|
|
||||||
|
|
||||||
def test_encode_decode():
|
def test_encode_decode():
|
||||||
@ -40,6 +44,10 @@ def test_encode_decode():
|
|||||||
],
|
],
|
||||||
numpy_array=np.arange(512),
|
numpy_array=np.arange(512),
|
||||||
unrecognized=UnrecognizedType(33),
|
unrecognized=UnrecognizedType(33),
|
||||||
|
small_f_contig_tensor=torch.rand(5, 4).t(),
|
||||||
|
large_f_contig_tensor=torch.rand(1024, 4).t(),
|
||||||
|
small_non_contig_tensor=torch.rand(2, 4)[:, 1:3],
|
||||||
|
large_non_contig_tensor=torch.rand(1024, 512)[:, 10:20],
|
||||||
)
|
)
|
||||||
|
|
||||||
encoder = MsgpackEncoder()
|
encoder = MsgpackEncoder()
|
||||||
@ -47,10 +55,10 @@ def test_encode_decode():
|
|||||||
|
|
||||||
encoded = encoder.encode(obj)
|
encoded = encoder.encode(obj)
|
||||||
|
|
||||||
# There should be the main buffer + 2 large tensor buffers
|
# There should be the main buffer + 4 large tensor buffers
|
||||||
# + 1 large numpy array. "large" is <= 256 bytes.
|
# + 1 large numpy array. "large" is >= 512 bytes.
|
||||||
# The two small tensors are encoded inline.
|
# The small tensors are encoded inline.
|
||||||
assert len(encoded) == 4
|
assert len(encoded) == 6
|
||||||
|
|
||||||
decoded: MyType = decoder.decode(encoded)
|
decoded: MyType = decoder.decode(encoded)
|
||||||
|
|
||||||
@ -62,7 +70,7 @@ def test_encode_decode():
|
|||||||
|
|
||||||
encoded2 = encoder.encode_into(obj, preallocated)
|
encoded2 = encoder.encode_into(obj, preallocated)
|
||||||
|
|
||||||
assert len(encoded2) == 4
|
assert len(encoded2) == 6
|
||||||
assert encoded2[0] is preallocated
|
assert encoded2[0] is preallocated
|
||||||
|
|
||||||
decoded2: MyType = decoder.decode(encoded2)
|
decoded2: MyType = decoder.decode(encoded2)
|
||||||
@ -78,3 +86,9 @@ def assert_equal(obj1: MyType, obj2: MyType):
|
|||||||
for a, b in zip(obj1.list_of_tensors, obj2.list_of_tensors))
|
for a, b in zip(obj1.list_of_tensors, obj2.list_of_tensors))
|
||||||
assert np.array_equal(obj1.numpy_array, obj2.numpy_array)
|
assert np.array_equal(obj1.numpy_array, obj2.numpy_array)
|
||||||
assert obj1.unrecognized.an_int == obj2.unrecognized.an_int
|
assert obj1.unrecognized.an_int == obj2.unrecognized.an_int
|
||||||
|
assert torch.equal(obj1.small_f_contig_tensor, obj2.small_f_contig_tensor)
|
||||||
|
assert torch.equal(obj1.large_f_contig_tensor, obj2.large_f_contig_tensor)
|
||||||
|
assert torch.equal(obj1.small_non_contig_tensor,
|
||||||
|
obj2.small_non_contig_tensor)
|
||||||
|
assert torch.equal(obj1.large_non_contig_tensor,
|
||||||
|
obj2.large_non_contig_tensor)
|
||||||
|
@ -14,9 +14,10 @@ from msgspec import msgpack
|
|||||||
|
|
||||||
CUSTOM_TYPE_PICKLE = 1
|
CUSTOM_TYPE_PICKLE = 1
|
||||||
CUSTOM_TYPE_CLOUDPICKLE = 2
|
CUSTOM_TYPE_CLOUDPICKLE = 2
|
||||||
|
CUSTOM_TYPE_RAW_VIEW = 3
|
||||||
|
|
||||||
# TODO calibrate this size
|
# TODO calibrate this size
|
||||||
INLINE_BUF_SIZE_THRESHOLD = 256
|
MIN_NOCOPY_BUF_SIZE = 512
|
||||||
|
|
||||||
bytestr = Union[bytes, bytearray, memoryview, zmq.Frame]
|
bytestr = Union[bytes, bytearray, memoryview, zmq.Frame]
|
||||||
|
|
||||||
@ -76,14 +77,16 @@ class MsgpackEncoder:
|
|||||||
self, obj: np.ndarray
|
self, obj: np.ndarray
|
||||||
) -> tuple[str, tuple[int, ...], Union[int, memoryview]]:
|
) -> tuple[str, tuple[int, ...], Union[int, memoryview]]:
|
||||||
assert self.aux_buffers is not None
|
assert self.aux_buffers is not None
|
||||||
if not obj.shape or obj.nbytes < INLINE_BUF_SIZE_THRESHOLD:
|
arr_data = obj.data if obj.data.c_contiguous else obj.tobytes()
|
||||||
# Encode small arrays and scalars inline.
|
if not obj.shape or obj.nbytes < MIN_NOCOPY_BUF_SIZE:
|
||||||
data = obj.data
|
# Encode small arrays and scalars inline. Using this extension type
|
||||||
|
# ensures we can avoid copying when decoding.
|
||||||
|
data = msgpack.Ext(CUSTOM_TYPE_RAW_VIEW, arr_data)
|
||||||
else:
|
else:
|
||||||
# Otherwise encode index of backing buffer.
|
# Otherwise encode index of backing buffer to avoid copy.
|
||||||
obj = np.ascontiguousarray(obj)
|
|
||||||
data = len(self.aux_buffers)
|
data = len(self.aux_buffers)
|
||||||
self.aux_buffers.append(obj.data)
|
self.aux_buffers.append(arr_data)
|
||||||
|
|
||||||
# We serialize the ndarray as a tuple of native types.
|
# We serialize the ndarray as a tuple of native types.
|
||||||
# The data is either inlined if small, or an index into a list of
|
# The data is either inlined if small, or an index into a list of
|
||||||
# backing buffers that we've stashed in `aux_buffers`.
|
# backing buffers that we've stashed in `aux_buffers`.
|
||||||
@ -131,6 +134,8 @@ class MsgpackDecoder:
|
|||||||
return np.ndarray(buffer=buffer, dtype=np.dtype(dtype), shape=shape)
|
return np.ndarray(buffer=buffer, dtype=np.dtype(dtype), shape=shape)
|
||||||
|
|
||||||
def ext_hook(self, code: int, data: memoryview) -> Any:
|
def ext_hook(self, code: int, data: memoryview) -> Any:
|
||||||
|
if code == CUSTOM_TYPE_RAW_VIEW:
|
||||||
|
return data
|
||||||
if code == CUSTOM_TYPE_PICKLE:
|
if code == CUSTOM_TYPE_PICKLE:
|
||||||
return pickle.loads(data)
|
return pickle.loads(data)
|
||||||
if code == CUSTOM_TYPE_CLOUDPICKLE:
|
if code == CUSTOM_TYPE_CLOUDPICKLE:
|
||||||
|
Loading…
x
Reference in New Issue
Block a user