[BugFix] Handle non-contiguous tensors properly when serializing (#16492)

Signed-off-by: Nick Hill <nhill@redhat.com>
Nick Hill 2025-04-11 17:54:06 -07:00 committed by GitHub
parent 57504a4bcf
commit 41cc883c29
2 changed files with 30 additions and 11 deletions

File 1 of 2: serialization round-trip test

@@ -22,6 +22,10 @@ class MyType:
     list_of_tensors: list[torch.Tensor]
     numpy_array: np.ndarray
     unrecognized: UnrecognizedType
+    small_f_contig_tensor: torch.Tensor
+    large_f_contig_tensor: torch.Tensor
+    small_non_contig_tensor: torch.Tensor
+    large_non_contig_tensor: torch.Tensor


 def test_encode_decode():
@@ -40,6 +44,10 @@ def test_encode_decode():
         ],
         numpy_array=np.arange(512),
         unrecognized=UnrecognizedType(33),
+        small_f_contig_tensor=torch.rand(5, 4).t(),
+        large_f_contig_tensor=torch.rand(1024, 4).t(),
+        small_non_contig_tensor=torch.rand(2, 4)[:, 1:3],
+        large_non_contig_tensor=torch.rand(1024, 512)[:, 10:20],
     )

     encoder = MsgpackEncoder()
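
For context, a quick standalone sketch (not part of the commit) of what these new fixtures exercise: .t() returns a transposed view laid out in Fortran order, and a column slice is a strided view that is contiguous in neither order:

    import torch

    f_contig = torch.rand(5, 4).t()        # transpose view: Fortran-order layout
    non_contig = torch.rand(2, 4)[:, 1:3]  # column slice: strided, no flat layout

    print(f_contig.is_contiguous())    # False (not C-contiguous)
    print(non_contig.is_contiguous())  # False
    print(f_contig.contiguous().is_contiguous())  # True: .contiguous() copies to C order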
@@ -47,10 +55,10 @@
     encoded = encoder.encode(obj)

-    # There should be the main buffer + 2 large tensor buffers
-    # + 1 large numpy array. "large" is <= 256 bytes.
+    # There should be the main buffer + 4 large tensor buffers
+    # + 1 large numpy array. "large" is >= 512 bytes.
     # The two small tensors are encoded inline.
-    assert len(encoded) == 4
+    assert len(encoded) == 6

     decoded: MyType = decoder.decode(encoded)
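
Spelling out the arithmetic behind the new assertion (the 2 + 2 split between pre-existing and newly added large tensors is inferred from the old and new comments, not stated in the diff itself):

    main_buffer = 1        # the msgpack payload itself
    large_tensors = 2 + 2  # 2 pre-existing + large_f_contig + large_non_contig
    large_numpy = 1        # np.arange(512) is 4096 bytes, well over the threshold
    assert main_buffer + large_tensors + large_numpy == 6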
@@ -62,7 +70,7 @@
     encoded2 = encoder.encode_into(obj, preallocated)
-    assert len(encoded2) == 4
+    assert len(encoded2) == 6
     assert encoded2[0] is preallocated
     decoded2: MyType = decoder.decode(encoded2)
@@ -78,3 +86,9 @@ def assert_equal(obj1: MyType, obj2: MyType):
                for a, b in zip(obj1.list_of_tensors, obj2.list_of_tensors))
     assert np.array_equal(obj1.numpy_array, obj2.numpy_array)
     assert obj1.unrecognized.an_int == obj2.unrecognized.an_int
+    assert torch.equal(obj1.small_f_contig_tensor, obj2.small_f_contig_tensor)
+    assert torch.equal(obj1.large_f_contig_tensor, obj2.large_f_contig_tensor)
+    assert torch.equal(obj1.small_non_contig_tensor,
+                       obj2.small_non_contig_tensor)
+    assert torch.equal(obj1.large_non_contig_tensor,
+                       obj2.large_non_contig_tensor)

File 2 of 2: msgpack serialization utilities (MsgpackEncoder / MsgpackDecoder)

@@ -14,9 +14,10 @@ from msgspec import msgpack
 CUSTOM_TYPE_PICKLE = 1
 CUSTOM_TYPE_CLOUDPICKLE = 2
+CUSTOM_TYPE_RAW_VIEW = 3

 # TODO calibrate this size
-INLINE_BUF_SIZE_THRESHOLD = 256
+MIN_NOCOPY_BUF_SIZE = 512

 bytestr = Union[bytes, bytearray, memoryview, zmq.Frame]
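
A small sketch (assumed dtypes and sizes, not from the commit) of what the renamed threshold decides: arrays under MIN_NOCOPY_BUF_SIZE are inlined in the main message, anything at or above it is shipped as a separate zero-copy buffer:

    import numpy as np

    MIN_NOCOPY_BUF_SIZE = 512  # value introduced above

    small = np.arange(64, dtype=np.int8)    # 64 bytes   -> inlined
    large = np.arange(1024, dtype=np.int8)  # 1024 bytes -> separate aux buffer

    for arr in (small, large):
        inline = not arr.shape or arr.nbytes < MIN_NOCOPY_BUF_SIZE
        print(arr.nbytes, "inline" if inline else "aux buffer")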
@@ -76,14 +77,16 @@ class MsgpackEncoder:
         self, obj: np.ndarray
     ) -> tuple[str, tuple[int, ...], Union[int, memoryview]]:
         assert self.aux_buffers is not None
-        if not obj.shape or obj.nbytes < INLINE_BUF_SIZE_THRESHOLD:
-            # Encode small arrays and scalars inline.
-            data = obj.data
+        arr_data = obj.data if obj.data.c_contiguous else obj.tobytes()
+        if not obj.shape or obj.nbytes < MIN_NOCOPY_BUF_SIZE:
+            # Encode small arrays and scalars inline. Using this extension type
+            # ensures we can avoid copying when decoding.
+            data = msgpack.Ext(CUSTOM_TYPE_RAW_VIEW, arr_data)
         else:
-            # Otherwise encode index of backing buffer.
-            obj = np.ascontiguousarray(obj)
+            # Otherwise encode index of backing buffer to avoid copy.
             data = len(self.aux_buffers)
-            self.aux_buffers.append(obj.data)
+            self.aux_buffers.append(arr_data)
         # We serialize the ndarray as a tuple of native types.
         # The data is either inlined if small, or an index into a list of
         # backing buffers that we've stashed in `aux_buffers`.
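
Why the new arr_data guard is needed: the memoryview of a non-C-contiguous array cannot be handed to consumers that require one flat buffer, while tobytes() always produces a C-order copy. A standalone illustration (io.BytesIO standing in for the msgpack encoder):

    import io

    import numpy as np

    arr = np.arange(20).reshape(4, 5)[:, 1:3]  # non-contiguous column slice
    print(arr.data.c_contiguous)               # False

    try:
        io.BytesIO().write(arr.data)           # consumer needs one flat buffer
    except BufferError as e:
        print("BufferError:", e)               # non-contiguous view is rejected

    flat = arr.tobytes()                       # explicit C-order copy always works
    assert len(flat) == arr.nbytes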
@@ -131,6 +134,8 @@ class MsgpackDecoder:
         return np.ndarray(buffer=buffer, dtype=np.dtype(dtype), shape=shape)

     def ext_hook(self, code: int, data: memoryview) -> Any:
+        if code == CUSTOM_TYPE_RAW_VIEW:
+            return data
         if code == CUSTOM_TYPE_PICKLE:
             return pickle.loads(data)
         if code == CUSTOM_TYPE_CLOUDPICKLE:
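
And a minimal round-trip of the new ext-type path, as a standalone sketch against msgspec's public API (the type code mirrors the constant above; error handling is elided):

    from msgspec import msgpack

    CUSTOM_TYPE_RAW_VIEW = 3

    def ext_hook(code: int, data: memoryview):
        if code == CUSTOM_TYPE_RAW_VIEW:
            return data  # hand the view back without copying
        raise NotImplementedError(f"unsupported ext code {code}")

    encoder = msgpack.Encoder()
    decoder = msgpack.Decoder(ext_hook=ext_hook)

    payload = encoder.encode(msgpack.Ext(CUSTOM_TYPE_RAW_VIEW, b"raw bytes"))
    print(bytes(decoder.decode(payload)))  # b'raw bytes'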