[V1][Performance] Implement custom serializaton for MultiModalKwargs [Rebased] (#16432)

Signed-off-by: Staszek Pasko <staszek@gmail.com>
Signed-off-by: Nick Hill <nhill@redhat.com>
Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
Co-authored-by: Nick Hill <nhill@redhat.com>
This commit is contained in:
Staszek Paśko 2025-04-17 04:28:32 +02:00 committed by GitHub
parent 3cd91dc955
commit 3092375e27
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 212 additions and 6 deletions

View File

@ -1,10 +1,16 @@
# SPDX-License-Identifier: Apache-2.0
from collections import UserDict
from dataclasses import dataclass
from typing import Optional
import msgspec
import numpy as np
import torch
from vllm.multimodal.inputs import (MultiModalBatchedField,
MultiModalFieldElem, MultiModalKwargs,
MultiModalKwargsItem,
MultiModalSharedField, NestedTensors)
from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder
@ -50,7 +56,7 @@ def test_encode_decode():
large_non_contig_tensor=torch.rand(1024, 512)[:, 10:20],
)
encoder = MsgpackEncoder()
encoder = MsgpackEncoder(size_threshold=256)
decoder = MsgpackDecoder(MyType)
encoded = encoder.encode(obj)
@ -78,6 +84,97 @@ def test_encode_decode():
assert_equal(decoded2, obj)
class MyRequest(msgspec.Struct):
mm: Optional[list[MultiModalKwargs]]
def test_multimodal_kwargs():
d = {
"foo":
torch.zeros(20000, dtype=torch.float16),
"bar": [torch.zeros(i * 1000, dtype=torch.int8) for i in range(3)],
"baz": [
torch.rand((256), dtype=torch.float16),
[
torch.rand((1, 12), dtype=torch.float32),
torch.rand((3, 5, 7), dtype=torch.float64),
], [torch.rand((4, 4), dtype=torch.float16)]
],
}
# pack mm kwargs into a mock request so that it can be decoded properly
req = MyRequest(mm=[MultiModalKwargs(d)])
encoder = MsgpackEncoder()
decoder = MsgpackDecoder(MyRequest)
encoded = encoder.encode(req)
assert len(encoded) == 6
total_len = sum(memoryview(x).cast("B").nbytes for x in encoded)
# expected total encoding length, should be 44536, +-20 for minor changes
assert total_len >= 44516 and total_len <= 44556
decoded: MultiModalKwargs = decoder.decode(encoded).mm[0]
assert all(nested_equal(d[k], decoded[k]) for k in d)
def test_multimodal_items_by_modality():
e1 = MultiModalFieldElem("audio", "a0", torch.zeros(1000,
dtype=torch.int16),
MultiModalBatchedField())
e2 = MultiModalFieldElem(
"video",
"v0",
[torch.zeros(1000, dtype=torch.int8) for _ in range(4)],
MultiModalBatchedField(),
)
e3 = MultiModalFieldElem("image", "i0", torch.zeros(1000,
dtype=torch.int32),
MultiModalSharedField(4))
e4 = MultiModalFieldElem("image", "i1", torch.zeros(1000,
dtype=torch.int32),
MultiModalBatchedField())
audio = MultiModalKwargsItem.from_elems([e1])
video = MultiModalKwargsItem.from_elems([e2])
image = MultiModalKwargsItem.from_elems([e3, e4])
mm = MultiModalKwargs.from_items([audio, video, image])
# pack mm kwargs into a mock request so that it can be decoded properly
req = MyRequest([mm])
encoder = MsgpackEncoder()
decoder = MsgpackDecoder(MyRequest)
encoded = encoder.encode(req)
assert len(encoded) == 8
total_len = sum(memoryview(x).cast("B").nbytes for x in encoded)
# expected total encoding length, should be 14255, +-20 for minor changes
assert total_len >= 14235 and total_len <= 14275
decoded: MultiModalKwargs = decoder.decode(encoded).mm[0]
# check all modalities were recovered and do some basic sanity checks
assert len(decoded.modalities) == 3
images = decoded.get_items("image")
assert len(images) == 1
assert len(images[0].items()) == 2
assert list(images[0].keys()) == ["i0", "i1"]
# check the tensor contents and layout in the main dict
assert all(nested_equal(mm[k], decoded[k]) for k in mm)
def nested_equal(a: NestedTensors, b: NestedTensors):
if isinstance(a, torch.Tensor):
return torch.equal(a, b)
else:
return all(nested_equal(x, y) for x, y in zip(a, b))
def assert_equal(obj1: MyType, obj2: MyType):
assert torch.equal(obj1.tensor1, obj2.tensor1)
assert obj1.a_string == obj2.a_string

View File

@ -107,6 +107,7 @@ if TYPE_CHECKING:
VLLM_TPU_BUCKET_PADDING_GAP: int = 0
VLLM_USE_DEEP_GEMM: bool = False
VLLM_XGRAMMAR_CACHE_MB: int = 0
VLLM_MSGPACK_ZERO_COPY_THRESHOLD: int = 256
def get_default_cache_root():
@ -704,6 +705,16 @@ environment_variables: dict[str, Callable[[], Any]] = {
# It can be changed with this variable if needed for some reason.
"VLLM_XGRAMMAR_CACHE_MB":
lambda: int(os.getenv("VLLM_XGRAMMAR_CACHE_MB", "512")),
# Control the threshold for msgspec to use 'zero copy' for
# serialization/deserialization of tensors. Tensors below
# this limit will be encoded into the msgpack buffer, and
# tensors above will instead be sent via a separate message.
# While the sending side still actually copies the tensor
# in all cases, on the receiving side, tensors above this
# limit will actually be zero-copy decoded.
"VLLM_MSGPACK_ZERO_COPY_THRESHOLD":
lambda: int(os.getenv("VLLM_MSGPACK_ZERO_COPY_THRESHOLD", "256")),
}
# end-env-vars-definition

View File

@ -1,5 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
import dataclasses
import pickle
from collections.abc import Sequence
from inspect import isclass
@ -12,12 +13,26 @@ import torch
import zmq
from msgspec import msgpack
from vllm import envs
from vllm.multimodal.inputs import (BaseMultiModalField,
MultiModalBatchedField,
MultiModalFieldConfig, MultiModalFieldElem,
MultiModalFlatField, MultiModalKwargs,
MultiModalKwargsItem,
MultiModalSharedField, NestedTensors)
CUSTOM_TYPE_PICKLE = 1
CUSTOM_TYPE_CLOUDPICKLE = 2
CUSTOM_TYPE_RAW_VIEW = 3
# TODO calibrate this size
MIN_NOCOPY_BUF_SIZE = 512
# MultiModalField class serialization type map.
# These need to list all possible field types and match them
# to factory methods in `MultiModalFieldConfig`.
MMF_CLASS_TO_FACTORY: dict[type[BaseMultiModalField], str] = {
MultiModalFlatField: "flat",
MultiModalSharedField: "shared",
MultiModalBatchedField: "batched",
}
bytestr = Union[bytes, bytearray, memoryview, zmq.Frame]
@ -27,14 +42,20 @@ class MsgpackEncoder:
Note that unlike vanilla `msgspec` Encoders, this interface is generally
not thread-safe when encoding tensors / numpy arrays.
By default, arrays below 256B are serialized inline Larger will get sent
via dedicated messages. Note that this is a per-tensor limit.
"""
def __init__(self):
def __init__(self, size_threshold: Optional[int] = None):
if size_threshold is None:
size_threshold = envs.VLLM_MSGPACK_ZERO_COPY_THRESHOLD
self.encoder = msgpack.Encoder(enc_hook=self.enc_hook)
# This is used as a local stash of buffers that we can then access from
# our custom `msgspec` hook, `enc_hook`. We don't have a way to
# pass custom data to the hook otherwise.
self.aux_buffers: Optional[list[bytestr]] = None
self.size_threshold = size_threshold
def encode(self, obj: Any) -> Sequence[bytestr]:
try:
@ -65,6 +86,25 @@ class MsgpackEncoder:
if isinstance(obj, np.ndarray) and obj.dtype.kind not in ('O', 'V'):
return self._encode_ndarray(obj)
if isinstance(obj, MultiModalKwargs):
mm: MultiModalKwargs = obj
if not mm.modalities:
# just return the main dict if there are no modalities.
return dict(mm)
# ignore the main dict, it will be re-indexed.
# Encode a list of MultiModalKwargsItems as plain dicts
# + special handling for .field.
# Any tensors *not* indexed by modality will be ignored.
return [[{
"modality": elem.modality,
"key": elem.key,
"data": self._encode_nested_tensors(elem.data),
"field": self._encode_mm_field(elem.field),
} for elem in item.values()]
for itemlist in mm._items_by_modality.values()
for item in itemlist]
if isinstance(obj, FunctionType):
# `pickle` is generally faster than cloudpickle, but can have
# problems serializing methods.
@ -77,8 +117,9 @@ class MsgpackEncoder:
self, obj: np.ndarray
) -> tuple[str, tuple[int, ...], Union[int, memoryview]]:
assert self.aux_buffers is not None
# If the array is non-contiguous, we need to copy it first
arr_data = obj.data if obj.data.c_contiguous else obj.tobytes()
if not obj.shape or obj.nbytes < MIN_NOCOPY_BUF_SIZE:
if not obj.shape or obj.nbytes < self.size_threshold:
# Encode small arrays and scalars inline. Using this extension type
# ensures we can avoid copying when decoding.
data = msgpack.Ext(CUSTOM_TYPE_RAW_VIEW, arr_data)
@ -92,6 +133,26 @@ class MsgpackEncoder:
# backing buffers that we've stashed in `aux_buffers`.
return obj.dtype.str, obj.shape, data
def _encode_nested_tensors(self, nt: NestedTensors) -> Any:
if isinstance(nt, torch.Tensor):
return self._encode_ndarray(nt.numpy())
if isinstance(nt, (int, float)):
# Although it violates NestedTensors type, MultiModalKwargs
# values are sometimes floats.
return nt
return [self._encode_nested_tensors(x) for x in nt]
def _encode_mm_field(self, field: BaseMultiModalField):
# Figure out the factory name for the field type.
name = MMF_CLASS_TO_FACTORY.get(field.__class__)
if not name:
raise TypeError(f"Unsupported field type: {field.__class__}")
# We just need to copy all of the field values in order
# which will be then used to reconstruct the field.
field_values = (getattr(field, f.name)
for f in dataclasses.fields(field))
return name, *field_values
class MsgpackDecoder:
"""Decoder with custom torch tensor and numpy array serialization.
@ -126,13 +187,50 @@ class MsgpackDecoder:
return self._decode_ndarray(obj)
if issubclass(t, torch.Tensor):
return torch.from_numpy(self._decode_ndarray(obj))
if issubclass(t, MultiModalKwargs):
if isinstance(obj, list):
return MultiModalKwargs.from_items(
self._decode_mm_items(obj))
return MultiModalKwargs({
k: self._decode_nested_tensors(v)
for k, v in obj.items()
})
return obj
def _decode_ndarray(self, arr: Any) -> np.ndarray:
dtype, shape, data = arr
buffer = self.aux_buffers[data] if isinstance(data, int) else data
# Copy from inline representation, otherwise Torch is unhappy since
# the returned memory is non-writeable.
buffer = self.aux_buffers[data] if isinstance(data, int) \
else bytearray(data)
return np.ndarray(buffer=buffer, dtype=np.dtype(dtype), shape=shape)
def _decode_mm_items(self, obj: list) -> list[MultiModalKwargsItem]:
decoded_items = []
for item in obj:
elems = []
for v in item:
v["data"] = self._decode_nested_tensors(v["data"])
# Reconstruct the field processor using MultiModalFieldConfig
factory_meth_name, *field_args = v["field"]
factory_meth = getattr(MultiModalFieldConfig,
factory_meth_name)
v["field"] = factory_meth(None, *field_args).field
elems.append(MultiModalFieldElem(**v))
decoded_items.append(MultiModalKwargsItem.from_elems(elems))
return decoded_items
def _decode_nested_tensors(self, obj: Any) -> NestedTensors:
if isinstance(obj, (int, float)):
# Although it violates NestedTensors type, MultiModalKwargs
# values are sometimes floats.
return obj
if not isinstance(obj, list):
raise TypeError(f"Unexpected NestedTensors contents: {type(obj)}")
if obj and isinstance(obj[0], str):
return torch.from_numpy(self._decode_ndarray(obj))
return [self._decode_nested_tensors(x) for x in obj]
def ext_hook(self, code: int, data: memoryview) -> Any:
if code == CUSTOM_TYPE_RAW_VIEW:
return data