diff --git a/tests/v1/test_serial_utils.py b/tests/v1/test_serial_utils.py
index bc0e0cbd..e58d3c40 100644
--- a/tests/v1/test_serial_utils.py
+++ b/tests/v1/test_serial_utils.py
@@ -1,10 +1,16 @@
 # SPDX-License-Identifier: Apache-2.0
 from collections import UserDict
 from dataclasses import dataclass
+from typing import Optional
 
+import msgspec
 import numpy as np
 import torch
 
+from vllm.multimodal.inputs import (MultiModalBatchedField,
+                                    MultiModalFieldElem, MultiModalKwargs,
+                                    MultiModalKwargsItem,
+                                    MultiModalSharedField, NestedTensors)
 from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder
 
 
@@ -50,7 +56,7 @@ def test_encode_decode():
         large_non_contig_tensor=torch.rand(1024, 512)[:, 10:20],
     )
 
-    encoder = MsgpackEncoder()
+    encoder = MsgpackEncoder(size_threshold=256)
     decoder = MsgpackDecoder(MyType)
 
     encoded = encoder.encode(obj)
@@ -78,6 +84,97 @@ def test_encode_decode():
     assert_equal(decoded2, obj)
 
 
+class MyRequest(msgspec.Struct):
+    mm: Optional[list[MultiModalKwargs]]
+
+
+def test_multimodal_kwargs():
+    d = {
+        "foo":
+        torch.zeros(20000, dtype=torch.float16),
+        "bar": [torch.zeros(i * 1000, dtype=torch.int8) for i in range(3)],
+        "baz": [
+            torch.rand((256), dtype=torch.float16),
+            [
+                torch.rand((1, 12), dtype=torch.float32),
+                torch.rand((3, 5, 7), dtype=torch.float64),
+            ], [torch.rand((4, 4), dtype=torch.float16)]
+        ],
+    }
+
+    # pack mm kwargs into a mock request so that it can be decoded properly
+    req = MyRequest(mm=[MultiModalKwargs(d)])
+
+    encoder = MsgpackEncoder()
+    decoder = MsgpackDecoder(MyRequest)
+
+    encoded = encoder.encode(req)
+
+    assert len(encoded) == 6
+
+    total_len = sum(memoryview(x).cast("B").nbytes for x in encoded)
+
+    # expected total encoding length, should be 44536, +-20 for minor changes
+    assert total_len >= 44516 and total_len <= 44556
+    decoded: MultiModalKwargs = decoder.decode(encoded).mm[0]
+    assert all(nested_equal(d[k], decoded[k]) for k in d)
+
+
+def test_multimodal_items_by_modality():
+    e1 = MultiModalFieldElem("audio", "a0", torch.zeros(1000,
+                                                        dtype=torch.int16),
+                             MultiModalBatchedField())
+    e2 = MultiModalFieldElem(
+        "video",
+        "v0",
+        [torch.zeros(1000, dtype=torch.int8) for _ in range(4)],
+        MultiModalBatchedField(),
+    )
+    e3 = MultiModalFieldElem("image", "i0", torch.zeros(1000,
+                                                        dtype=torch.int32),
+                             MultiModalSharedField(4))
+    e4 = MultiModalFieldElem("image", "i1", torch.zeros(1000,
+                                                        dtype=torch.int32),
+                             MultiModalBatchedField())
+    audio = MultiModalKwargsItem.from_elems([e1])
+    video = MultiModalKwargsItem.from_elems([e2])
+    image = MultiModalKwargsItem.from_elems([e3, e4])
+    mm = MultiModalKwargs.from_items([audio, video, image])
+
+    # pack mm kwargs into a mock request so that it can be decoded properly
+    req = MyRequest([mm])
+
+    encoder = MsgpackEncoder()
+    decoder = MsgpackDecoder(MyRequest)
+
+    encoded = encoder.encode(req)
+
+    assert len(encoded) == 8
+
+    total_len = sum(memoryview(x).cast("B").nbytes for x in encoded)
+
+    # expected total encoding length, should be 14255, +-20 for minor changes
+    assert total_len >= 14235 and total_len <= 14275
+    decoded: MultiModalKwargs = decoder.decode(encoded).mm[0]
+
+    # check all modalities were recovered and do some basic sanity checks
+    assert len(decoded.modalities) == 3
+    images = decoded.get_items("image")
+    assert len(images) == 1
+    assert len(images[0].items()) == 2
+    assert list(images[0].keys()) == ["i0", "i1"]
+
+    # check the tensor contents and layout in the main dict
+    assert all(nested_equal(mm[k], decoded[k]) for k in mm)
+
+
+def nested_equal(a: NestedTensors, b: NestedTensors):
+    if isinstance(a, torch.Tensor):
+        return torch.equal(a, b)
+    else:
+        return all(nested_equal(x, y) for x, y in zip(a, b))
+
+
 def assert_equal(obj1: MyType, obj2: MyType):
     assert torch.equal(obj1.tensor1, obj2.tensor1)
     assert obj1.a_string == obj2.a_string
diff --git a/vllm/envs.py b/vllm/envs.py
index f80bf878..d32968c3 100644
--- a/vllm/envs.py
+++ b/vllm/envs.py
@@ -107,6 +107,7 @@ if TYPE_CHECKING:
     VLLM_TPU_BUCKET_PADDING_GAP: int = 0
     VLLM_USE_DEEP_GEMM: bool = False
     VLLM_XGRAMMAR_CACHE_MB: int = 0
+    VLLM_MSGPACK_ZERO_COPY_THRESHOLD: int = 256
 
 
 def get_default_cache_root():
@@ -704,6 +705,16 @@ environment_variables: dict[str, Callable[[], Any]] = {
     # It can be changed with this variable if needed for some reason.
     "VLLM_XGRAMMAR_CACHE_MB":
     lambda: int(os.getenv("VLLM_XGRAMMAR_CACHE_MB", "512")),
+
+    # Control the threshold at which msgspec uses 'zero copy' for
+    # serialization/deserialization of tensors. Tensors below
+    # this limit will be encoded into the msgpack buffer, and
+    # tensors above will instead be sent via a separate message.
+    # Note that the sending side still copies the tensor in all
+    # cases; it is only the receiving side that decodes tensors
+    # above this limit without a copy.
+    "VLLM_MSGPACK_ZERO_COPY_THRESHOLD":
+    lambda: int(os.getenv("VLLM_MSGPACK_ZERO_COPY_THRESHOLD", "256")),
 }
 
 # end-env-vars-definition
diff --git a/vllm/v1/serial_utils.py b/vllm/v1/serial_utils.py
index 3af6793f..4f7987ee 100644
--- a/vllm/v1/serial_utils.py
+++ b/vllm/v1/serial_utils.py
@@ -1,5 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 
+import dataclasses
 import pickle
 from collections.abc import Sequence
 from inspect import isclass
@@ -12,12 +13,26 @@
 import torch
 import zmq
 from msgspec import msgpack
 
+from vllm import envs
+from vllm.multimodal.inputs import (BaseMultiModalField,
+                                    MultiModalBatchedField,
+                                    MultiModalFieldConfig, MultiModalFieldElem,
+                                    MultiModalFlatField, MultiModalKwargs,
+                                    MultiModalKwargsItem,
+                                    MultiModalSharedField, NestedTensors)
+
 CUSTOM_TYPE_PICKLE = 1
 CUSTOM_TYPE_CLOUDPICKLE = 2
 CUSTOM_TYPE_RAW_VIEW = 3
 
-# TODO calibrate this size
-MIN_NOCOPY_BUF_SIZE = 512
+# MultiModalField class serialization type map.
+# These need to list all possible field types and match them
+# to factory methods in `MultiModalFieldConfig`.
+MMF_CLASS_TO_FACTORY: dict[type[BaseMultiModalField], str] = {
+    MultiModalFlatField: "flat",
+    MultiModalSharedField: "shared",
+    MultiModalBatchedField: "batched",
+}
 
 bytestr = Union[bytes, bytearray, memoryview, zmq.Frame]
 
@@ -27,14 +42,20 @@ class MsgpackEncoder:
 
     Note that unlike vanilla `msgspec` Encoders, this interface is generally
     not thread-safe when encoding tensors / numpy arrays.
+
+    By default, arrays below 256B are serialized inline. Larger arrays are
+    sent via dedicated messages. Note that this is a per-tensor limit.
     """
 
-    def __init__(self):
+    def __init__(self, size_threshold: Optional[int] = None):
+        if size_threshold is None:
+            size_threshold = envs.VLLM_MSGPACK_ZERO_COPY_THRESHOLD
         self.encoder = msgpack.Encoder(enc_hook=self.enc_hook)
         # This is used as a local stash of buffers that we can then access from
         # our custom `msgspec` hook, `enc_hook`. We don't have a way to
         # pass custom data to the hook otherwise.
         self.aux_buffers: Optional[list[bytestr]] = None
+        self.size_threshold = size_threshold
 
     def encode(self, obj: Any) -> Sequence[bytestr]:
         try:
@@ -65,6 +86,25 @@ class MsgpackEncoder:
         if isinstance(obj, np.ndarray) and obj.dtype.kind not in ('O', 'V'):
             return self._encode_ndarray(obj)
 
+        if isinstance(obj, MultiModalKwargs):
+            mm: MultiModalKwargs = obj
+            if not mm.modalities:
+                # just return the main dict if there are no modalities.
+                return dict(mm)
+
+            # ignore the main dict, it will be re-indexed.
+            # Encode a list of MultiModalKwargsItems as plain dicts
+            # + special handling for .field.
+            # Any tensors *not* indexed by modality will be ignored.
+            return [[{
+                "modality": elem.modality,
+                "key": elem.key,
+                "data": self._encode_nested_tensors(elem.data),
+                "field": self._encode_mm_field(elem.field),
+            } for elem in item.values()]
+                    for itemlist in mm._items_by_modality.values()
+                    for item in itemlist]
+
         if isinstance(obj, FunctionType):
             # `pickle` is generally faster than cloudpickle, but can have
             # problems serializing methods.
@@ -77,8 +117,9 @@
             self, obj: np.ndarray
     ) -> tuple[str, tuple[int, ...], Union[int, memoryview]]:
         assert self.aux_buffers is not None
+        # If the array is non-contiguous, we need to copy it first
         arr_data = obj.data if obj.data.c_contiguous else obj.tobytes()
-        if not obj.shape or obj.nbytes < MIN_NOCOPY_BUF_SIZE:
+        if not obj.shape or obj.nbytes < self.size_threshold:
             # Encode small arrays and scalars inline. Using this extension type
             # ensures we can avoid copying when decoding.
             data = msgpack.Ext(CUSTOM_TYPE_RAW_VIEW, arr_data)
@@ -92,6 +133,26 @@
         # backing buffers that we've stashed in `aux_buffers`.
         return obj.dtype.str, obj.shape, data
 
+    def _encode_nested_tensors(self, nt: NestedTensors) -> Any:
+        if isinstance(nt, torch.Tensor):
+            return self._encode_ndarray(nt.numpy())
+        if isinstance(nt, (int, float)):
+            # Although it violates NestedTensors type, MultiModalKwargs
+            # values are sometimes floats.
+            return nt
+        return [self._encode_nested_tensors(x) for x in nt]
+
+    def _encode_mm_field(self, field: BaseMultiModalField):
+        # Figure out the factory name for the field type.
+        name = MMF_CLASS_TO_FACTORY.get(field.__class__)
+        if not name:
+            raise TypeError(f"Unsupported field type: {field.__class__}")
+        # We just need to copy all of the field values in order,
+        # which will then be used to reconstruct the field.
+        field_values = (getattr(field, f.name)
+                        for f in dataclasses.fields(field))
+        return name, *field_values
+
 
 class MsgpackDecoder:
     """Decoder with custom torch tensor and numpy array serialization.
@@ -126,13 +187,50 @@ class MsgpackDecoder:
             return self._decode_ndarray(obj)
         if issubclass(t, torch.Tensor):
             return torch.from_numpy(self._decode_ndarray(obj))
+        if issubclass(t, MultiModalKwargs):
+            if isinstance(obj, list):
+                return MultiModalKwargs.from_items(
+                    self._decode_mm_items(obj))
+            return MultiModalKwargs({
+                k: self._decode_nested_tensors(v)
+                for k, v in obj.items()
+            })
         return obj
 
     def _decode_ndarray(self, arr: Any) -> np.ndarray:
         dtype, shape, data = arr
-        buffer = self.aux_buffers[data] if isinstance(data, int) else data
+        # Copy from inline representation, otherwise Torch is unhappy since
+        # the returned memory is non-writeable.
+        buffer = self.aux_buffers[data] if isinstance(data, int) \
+            else bytearray(data)
         return np.ndarray(buffer=buffer, dtype=np.dtype(dtype), shape=shape)
 
+    def _decode_mm_items(self, obj: list) -> list[MultiModalKwargsItem]:
+        decoded_items = []
+        for item in obj:
+            elems = []
+            for v in item:
+                v["data"] = self._decode_nested_tensors(v["data"])
+                # Reconstruct the field processor using MultiModalFieldConfig
+                factory_meth_name, *field_args = v["field"]
+                factory_meth = getattr(MultiModalFieldConfig,
+                                       factory_meth_name)
+                v["field"] = factory_meth(None, *field_args).field
+                elems.append(MultiModalFieldElem(**v))
+            decoded_items.append(MultiModalKwargsItem.from_elems(elems))
+        return decoded_items
+
+    def _decode_nested_tensors(self, obj: Any) -> NestedTensors:
+        if isinstance(obj, (int, float)):
+            # Although it violates NestedTensors type, MultiModalKwargs
+            # values are sometimes floats.
+            return obj
+        if not isinstance(obj, list):
+            raise TypeError(f"Unexpected NestedTensors contents: {type(obj)}")
+        if obj and isinstance(obj[0], str):
+            return torch.from_numpy(self._decode_ndarray(obj))
+        return [self._decode_nested_tensors(x) for x in obj]
+
     def ext_hook(self, code: int, data: memoryview) -> Any:
         if code == CUSTOM_TYPE_RAW_VIEW:
             return data
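
Illustrative sketch (not part of the patch): how the per-tensor size
threshold splits an encoded object between inline msgpack data and
dedicated frames that can be zero-copy decoded. `Payload` is a
hypothetical struct defined only for this example; the encoder and
decoder APIs are the ones modified above.

    # Sketch: per-tensor threshold behavior of the encoder/decoder.
    import msgspec
    import torch

    from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder


    class Payload(msgspec.Struct):
        small: torch.Tensor  # 8 bytes, below the threshold
        large: torch.Tensor  # 4 KiB, above the threshold


    payload = Payload(small=torch.zeros(8, dtype=torch.int8),
                      large=torch.zeros(4096, dtype=torch.int8))

    encoder = MsgpackEncoder(size_threshold=256)
    frames = encoder.encode(payload)

    # One frame for the msgpack document itself plus one aux frame for
    # `large`; `small` rides inline as a CUSTOM_TYPE_RAW_VIEW extension.
    assert len(frames) == 2

    # On the receiving side, the aux frame is wrapped by numpy without
    # copying; only the inline representation is copied (bytearray).
    decoded = MsgpackDecoder(Payload).decode(frames)
    assert torch.equal(decoded.large, payload.large)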
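Illustrative round-trip (not part of the patch): how `_encode_mm_field`
and `_decode_mm_items` serialize a field as a factory name plus its
dataclass values, then rebuild it through the matching
`MultiModalFieldConfig` factory. A minimal sketch assuming the
`MMF_CLASS_TO_FACTORY` map above; `field` is an example value, and the
equality check assumes default dataclass `__eq__` semantics.

    # Sketch: field serialization round-trip via MultiModalFieldConfig.
    import dataclasses

    from vllm.multimodal.inputs import (MultiModalFieldConfig,
                                        MultiModalSharedField)

    field = MultiModalSharedField(4)

    # Encode: factory name from MMF_CLASS_TO_FACTORY plus the field's
    # dataclass values, in declaration order.
    name = "shared"
    args = [getattr(field, f.name) for f in dataclasses.fields(field)]

    # Decode: look up the matching factory, call it with a None modality
    # (the elem carries the modality separately), and unwrap .field.
    rebuilt = getattr(MultiModalFieldConfig, name)(None, *args).field
    assert rebuilt == field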