[V1][Performance] Implement custom serializaton for MultiModalKwargs [Rebased] (#16432)
Signed-off-by: Staszek Pasko <staszek@gmail.com> Signed-off-by: Nick Hill <nhill@redhat.com> Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com> Co-authored-by: Nick Hill <nhill@redhat.com>
This commit is contained in:
parent
3cd91dc955
commit
3092375e27
@ -1,10 +1,16 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
from collections import UserDict
|
from collections import UserDict
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import msgspec
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
from vllm.multimodal.inputs import (MultiModalBatchedField,
|
||||||
|
MultiModalFieldElem, MultiModalKwargs,
|
||||||
|
MultiModalKwargsItem,
|
||||||
|
MultiModalSharedField, NestedTensors)
|
||||||
from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder
|
from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder
|
||||||
|
|
||||||
|
|
||||||
@ -50,7 +56,7 @@ def test_encode_decode():
|
|||||||
large_non_contig_tensor=torch.rand(1024, 512)[:, 10:20],
|
large_non_contig_tensor=torch.rand(1024, 512)[:, 10:20],
|
||||||
)
|
)
|
||||||
|
|
||||||
encoder = MsgpackEncoder()
|
encoder = MsgpackEncoder(size_threshold=256)
|
||||||
decoder = MsgpackDecoder(MyType)
|
decoder = MsgpackDecoder(MyType)
|
||||||
|
|
||||||
encoded = encoder.encode(obj)
|
encoded = encoder.encode(obj)
|
||||||
@ -78,6 +84,97 @@ def test_encode_decode():
|
|||||||
assert_equal(decoded2, obj)
|
assert_equal(decoded2, obj)
|
||||||
|
|
||||||
|
|
||||||
|
class MyRequest(msgspec.Struct):
|
||||||
|
mm: Optional[list[MultiModalKwargs]]
|
||||||
|
|
||||||
|
|
||||||
|
def test_multimodal_kwargs():
|
||||||
|
d = {
|
||||||
|
"foo":
|
||||||
|
torch.zeros(20000, dtype=torch.float16),
|
||||||
|
"bar": [torch.zeros(i * 1000, dtype=torch.int8) for i in range(3)],
|
||||||
|
"baz": [
|
||||||
|
torch.rand((256), dtype=torch.float16),
|
||||||
|
[
|
||||||
|
torch.rand((1, 12), dtype=torch.float32),
|
||||||
|
torch.rand((3, 5, 7), dtype=torch.float64),
|
||||||
|
], [torch.rand((4, 4), dtype=torch.float16)]
|
||||||
|
],
|
||||||
|
}
|
||||||
|
|
||||||
|
# pack mm kwargs into a mock request so that it can be decoded properly
|
||||||
|
req = MyRequest(mm=[MultiModalKwargs(d)])
|
||||||
|
|
||||||
|
encoder = MsgpackEncoder()
|
||||||
|
decoder = MsgpackDecoder(MyRequest)
|
||||||
|
|
||||||
|
encoded = encoder.encode(req)
|
||||||
|
|
||||||
|
assert len(encoded) == 6
|
||||||
|
|
||||||
|
total_len = sum(memoryview(x).cast("B").nbytes for x in encoded)
|
||||||
|
|
||||||
|
# expected total encoding length, should be 44536, +-20 for minor changes
|
||||||
|
assert total_len >= 44516 and total_len <= 44556
|
||||||
|
decoded: MultiModalKwargs = decoder.decode(encoded).mm[0]
|
||||||
|
assert all(nested_equal(d[k], decoded[k]) for k in d)
|
||||||
|
|
||||||
|
|
||||||
|
def test_multimodal_items_by_modality():
|
||||||
|
e1 = MultiModalFieldElem("audio", "a0", torch.zeros(1000,
|
||||||
|
dtype=torch.int16),
|
||||||
|
MultiModalBatchedField())
|
||||||
|
e2 = MultiModalFieldElem(
|
||||||
|
"video",
|
||||||
|
"v0",
|
||||||
|
[torch.zeros(1000, dtype=torch.int8) for _ in range(4)],
|
||||||
|
MultiModalBatchedField(),
|
||||||
|
)
|
||||||
|
e3 = MultiModalFieldElem("image", "i0", torch.zeros(1000,
|
||||||
|
dtype=torch.int32),
|
||||||
|
MultiModalSharedField(4))
|
||||||
|
e4 = MultiModalFieldElem("image", "i1", torch.zeros(1000,
|
||||||
|
dtype=torch.int32),
|
||||||
|
MultiModalBatchedField())
|
||||||
|
audio = MultiModalKwargsItem.from_elems([e1])
|
||||||
|
video = MultiModalKwargsItem.from_elems([e2])
|
||||||
|
image = MultiModalKwargsItem.from_elems([e3, e4])
|
||||||
|
mm = MultiModalKwargs.from_items([audio, video, image])
|
||||||
|
|
||||||
|
# pack mm kwargs into a mock request so that it can be decoded properly
|
||||||
|
req = MyRequest([mm])
|
||||||
|
|
||||||
|
encoder = MsgpackEncoder()
|
||||||
|
decoder = MsgpackDecoder(MyRequest)
|
||||||
|
|
||||||
|
encoded = encoder.encode(req)
|
||||||
|
|
||||||
|
assert len(encoded) == 8
|
||||||
|
|
||||||
|
total_len = sum(memoryview(x).cast("B").nbytes for x in encoded)
|
||||||
|
|
||||||
|
# expected total encoding length, should be 14255, +-20 for minor changes
|
||||||
|
assert total_len >= 14235 and total_len <= 14275
|
||||||
|
decoded: MultiModalKwargs = decoder.decode(encoded).mm[0]
|
||||||
|
|
||||||
|
# check all modalities were recovered and do some basic sanity checks
|
||||||
|
assert len(decoded.modalities) == 3
|
||||||
|
images = decoded.get_items("image")
|
||||||
|
assert len(images) == 1
|
||||||
|
assert len(images[0].items()) == 2
|
||||||
|
assert list(images[0].keys()) == ["i0", "i1"]
|
||||||
|
|
||||||
|
# check the tensor contents and layout in the main dict
|
||||||
|
assert all(nested_equal(mm[k], decoded[k]) for k in mm)
|
||||||
|
|
||||||
|
|
||||||
|
def nested_equal(a: NestedTensors, b: NestedTensors):
|
||||||
|
if isinstance(a, torch.Tensor):
|
||||||
|
return torch.equal(a, b)
|
||||||
|
else:
|
||||||
|
return all(nested_equal(x, y) for x, y in zip(a, b))
|
||||||
|
|
||||||
|
|
||||||
def assert_equal(obj1: MyType, obj2: MyType):
|
def assert_equal(obj1: MyType, obj2: MyType):
|
||||||
assert torch.equal(obj1.tensor1, obj2.tensor1)
|
assert torch.equal(obj1.tensor1, obj2.tensor1)
|
||||||
assert obj1.a_string == obj2.a_string
|
assert obj1.a_string == obj2.a_string
|
||||||
|
11
vllm/envs.py
11
vllm/envs.py
@ -107,6 +107,7 @@ if TYPE_CHECKING:
|
|||||||
VLLM_TPU_BUCKET_PADDING_GAP: int = 0
|
VLLM_TPU_BUCKET_PADDING_GAP: int = 0
|
||||||
VLLM_USE_DEEP_GEMM: bool = False
|
VLLM_USE_DEEP_GEMM: bool = False
|
||||||
VLLM_XGRAMMAR_CACHE_MB: int = 0
|
VLLM_XGRAMMAR_CACHE_MB: int = 0
|
||||||
|
VLLM_MSGPACK_ZERO_COPY_THRESHOLD: int = 256
|
||||||
|
|
||||||
|
|
||||||
def get_default_cache_root():
|
def get_default_cache_root():
|
||||||
@ -704,6 +705,16 @@ environment_variables: dict[str, Callable[[], Any]] = {
|
|||||||
# It can be changed with this variable if needed for some reason.
|
# It can be changed with this variable if needed for some reason.
|
||||||
"VLLM_XGRAMMAR_CACHE_MB":
|
"VLLM_XGRAMMAR_CACHE_MB":
|
||||||
lambda: int(os.getenv("VLLM_XGRAMMAR_CACHE_MB", "512")),
|
lambda: int(os.getenv("VLLM_XGRAMMAR_CACHE_MB", "512")),
|
||||||
|
|
||||||
|
# Control the threshold for msgspec to use 'zero copy' for
|
||||||
|
# serialization/deserialization of tensors. Tensors below
|
||||||
|
# this limit will be encoded into the msgpack buffer, and
|
||||||
|
# tensors above will instead be sent via a separate message.
|
||||||
|
# While the sending side still actually copies the tensor
|
||||||
|
# in all cases, on the receiving side, tensors above this
|
||||||
|
# limit will actually be zero-copy decoded.
|
||||||
|
"VLLM_MSGPACK_ZERO_COPY_THRESHOLD":
|
||||||
|
lambda: int(os.getenv("VLLM_MSGPACK_ZERO_COPY_THRESHOLD", "256")),
|
||||||
}
|
}
|
||||||
|
|
||||||
# end-env-vars-definition
|
# end-env-vars-definition
|
||||||
|
@ -1,5 +1,6 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
|
import dataclasses
|
||||||
import pickle
|
import pickle
|
||||||
from collections.abc import Sequence
|
from collections.abc import Sequence
|
||||||
from inspect import isclass
|
from inspect import isclass
|
||||||
@ -12,12 +13,26 @@ import torch
|
|||||||
import zmq
|
import zmq
|
||||||
from msgspec import msgpack
|
from msgspec import msgpack
|
||||||
|
|
||||||
|
from vllm import envs
|
||||||
|
from vllm.multimodal.inputs import (BaseMultiModalField,
|
||||||
|
MultiModalBatchedField,
|
||||||
|
MultiModalFieldConfig, MultiModalFieldElem,
|
||||||
|
MultiModalFlatField, MultiModalKwargs,
|
||||||
|
MultiModalKwargsItem,
|
||||||
|
MultiModalSharedField, NestedTensors)
|
||||||
|
|
||||||
CUSTOM_TYPE_PICKLE = 1
|
CUSTOM_TYPE_PICKLE = 1
|
||||||
CUSTOM_TYPE_CLOUDPICKLE = 2
|
CUSTOM_TYPE_CLOUDPICKLE = 2
|
||||||
CUSTOM_TYPE_RAW_VIEW = 3
|
CUSTOM_TYPE_RAW_VIEW = 3
|
||||||
|
|
||||||
# TODO calibrate this size
|
# MultiModalField class serialization type map.
|
||||||
MIN_NOCOPY_BUF_SIZE = 512
|
# These need to list all possible field types and match them
|
||||||
|
# to factory methods in `MultiModalFieldConfig`.
|
||||||
|
MMF_CLASS_TO_FACTORY: dict[type[BaseMultiModalField], str] = {
|
||||||
|
MultiModalFlatField: "flat",
|
||||||
|
MultiModalSharedField: "shared",
|
||||||
|
MultiModalBatchedField: "batched",
|
||||||
|
}
|
||||||
|
|
||||||
bytestr = Union[bytes, bytearray, memoryview, zmq.Frame]
|
bytestr = Union[bytes, bytearray, memoryview, zmq.Frame]
|
||||||
|
|
||||||
@ -27,14 +42,20 @@ class MsgpackEncoder:
|
|||||||
|
|
||||||
Note that unlike vanilla `msgspec` Encoders, this interface is generally
|
Note that unlike vanilla `msgspec` Encoders, this interface is generally
|
||||||
not thread-safe when encoding tensors / numpy arrays.
|
not thread-safe when encoding tensors / numpy arrays.
|
||||||
|
|
||||||
|
By default, arrays below 256B are serialized inline Larger will get sent
|
||||||
|
via dedicated messages. Note that this is a per-tensor limit.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self, size_threshold: Optional[int] = None):
|
||||||
|
if size_threshold is None:
|
||||||
|
size_threshold = envs.VLLM_MSGPACK_ZERO_COPY_THRESHOLD
|
||||||
self.encoder = msgpack.Encoder(enc_hook=self.enc_hook)
|
self.encoder = msgpack.Encoder(enc_hook=self.enc_hook)
|
||||||
# This is used as a local stash of buffers that we can then access from
|
# This is used as a local stash of buffers that we can then access from
|
||||||
# our custom `msgspec` hook, `enc_hook`. We don't have a way to
|
# our custom `msgspec` hook, `enc_hook`. We don't have a way to
|
||||||
# pass custom data to the hook otherwise.
|
# pass custom data to the hook otherwise.
|
||||||
self.aux_buffers: Optional[list[bytestr]] = None
|
self.aux_buffers: Optional[list[bytestr]] = None
|
||||||
|
self.size_threshold = size_threshold
|
||||||
|
|
||||||
def encode(self, obj: Any) -> Sequence[bytestr]:
|
def encode(self, obj: Any) -> Sequence[bytestr]:
|
||||||
try:
|
try:
|
||||||
@ -65,6 +86,25 @@ class MsgpackEncoder:
|
|||||||
if isinstance(obj, np.ndarray) and obj.dtype.kind not in ('O', 'V'):
|
if isinstance(obj, np.ndarray) and obj.dtype.kind not in ('O', 'V'):
|
||||||
return self._encode_ndarray(obj)
|
return self._encode_ndarray(obj)
|
||||||
|
|
||||||
|
if isinstance(obj, MultiModalKwargs):
|
||||||
|
mm: MultiModalKwargs = obj
|
||||||
|
if not mm.modalities:
|
||||||
|
# just return the main dict if there are no modalities.
|
||||||
|
return dict(mm)
|
||||||
|
|
||||||
|
# ignore the main dict, it will be re-indexed.
|
||||||
|
# Encode a list of MultiModalKwargsItems as plain dicts
|
||||||
|
# + special handling for .field.
|
||||||
|
# Any tensors *not* indexed by modality will be ignored.
|
||||||
|
return [[{
|
||||||
|
"modality": elem.modality,
|
||||||
|
"key": elem.key,
|
||||||
|
"data": self._encode_nested_tensors(elem.data),
|
||||||
|
"field": self._encode_mm_field(elem.field),
|
||||||
|
} for elem in item.values()]
|
||||||
|
for itemlist in mm._items_by_modality.values()
|
||||||
|
for item in itemlist]
|
||||||
|
|
||||||
if isinstance(obj, FunctionType):
|
if isinstance(obj, FunctionType):
|
||||||
# `pickle` is generally faster than cloudpickle, but can have
|
# `pickle` is generally faster than cloudpickle, but can have
|
||||||
# problems serializing methods.
|
# problems serializing methods.
|
||||||
@ -77,8 +117,9 @@ class MsgpackEncoder:
|
|||||||
self, obj: np.ndarray
|
self, obj: np.ndarray
|
||||||
) -> tuple[str, tuple[int, ...], Union[int, memoryview]]:
|
) -> tuple[str, tuple[int, ...], Union[int, memoryview]]:
|
||||||
assert self.aux_buffers is not None
|
assert self.aux_buffers is not None
|
||||||
|
# If the array is non-contiguous, we need to copy it first
|
||||||
arr_data = obj.data if obj.data.c_contiguous else obj.tobytes()
|
arr_data = obj.data if obj.data.c_contiguous else obj.tobytes()
|
||||||
if not obj.shape or obj.nbytes < MIN_NOCOPY_BUF_SIZE:
|
if not obj.shape or obj.nbytes < self.size_threshold:
|
||||||
# Encode small arrays and scalars inline. Using this extension type
|
# Encode small arrays and scalars inline. Using this extension type
|
||||||
# ensures we can avoid copying when decoding.
|
# ensures we can avoid copying when decoding.
|
||||||
data = msgpack.Ext(CUSTOM_TYPE_RAW_VIEW, arr_data)
|
data = msgpack.Ext(CUSTOM_TYPE_RAW_VIEW, arr_data)
|
||||||
@ -92,6 +133,26 @@ class MsgpackEncoder:
|
|||||||
# backing buffers that we've stashed in `aux_buffers`.
|
# backing buffers that we've stashed in `aux_buffers`.
|
||||||
return obj.dtype.str, obj.shape, data
|
return obj.dtype.str, obj.shape, data
|
||||||
|
|
||||||
|
def _encode_nested_tensors(self, nt: NestedTensors) -> Any:
|
||||||
|
if isinstance(nt, torch.Tensor):
|
||||||
|
return self._encode_ndarray(nt.numpy())
|
||||||
|
if isinstance(nt, (int, float)):
|
||||||
|
# Although it violates NestedTensors type, MultiModalKwargs
|
||||||
|
# values are sometimes floats.
|
||||||
|
return nt
|
||||||
|
return [self._encode_nested_tensors(x) for x in nt]
|
||||||
|
|
||||||
|
def _encode_mm_field(self, field: BaseMultiModalField):
|
||||||
|
# Figure out the factory name for the field type.
|
||||||
|
name = MMF_CLASS_TO_FACTORY.get(field.__class__)
|
||||||
|
if not name:
|
||||||
|
raise TypeError(f"Unsupported field type: {field.__class__}")
|
||||||
|
# We just need to copy all of the field values in order
|
||||||
|
# which will be then used to reconstruct the field.
|
||||||
|
field_values = (getattr(field, f.name)
|
||||||
|
for f in dataclasses.fields(field))
|
||||||
|
return name, *field_values
|
||||||
|
|
||||||
|
|
||||||
class MsgpackDecoder:
|
class MsgpackDecoder:
|
||||||
"""Decoder with custom torch tensor and numpy array serialization.
|
"""Decoder with custom torch tensor and numpy array serialization.
|
||||||
@ -126,13 +187,50 @@ class MsgpackDecoder:
|
|||||||
return self._decode_ndarray(obj)
|
return self._decode_ndarray(obj)
|
||||||
if issubclass(t, torch.Tensor):
|
if issubclass(t, torch.Tensor):
|
||||||
return torch.from_numpy(self._decode_ndarray(obj))
|
return torch.from_numpy(self._decode_ndarray(obj))
|
||||||
|
if issubclass(t, MultiModalKwargs):
|
||||||
|
if isinstance(obj, list):
|
||||||
|
return MultiModalKwargs.from_items(
|
||||||
|
self._decode_mm_items(obj))
|
||||||
|
return MultiModalKwargs({
|
||||||
|
k: self._decode_nested_tensors(v)
|
||||||
|
for k, v in obj.items()
|
||||||
|
})
|
||||||
return obj
|
return obj
|
||||||
|
|
||||||
def _decode_ndarray(self, arr: Any) -> np.ndarray:
|
def _decode_ndarray(self, arr: Any) -> np.ndarray:
|
||||||
dtype, shape, data = arr
|
dtype, shape, data = arr
|
||||||
buffer = self.aux_buffers[data] if isinstance(data, int) else data
|
# Copy from inline representation, otherwise Torch is unhappy since
|
||||||
|
# the returned memory is non-writeable.
|
||||||
|
buffer = self.aux_buffers[data] if isinstance(data, int) \
|
||||||
|
else bytearray(data)
|
||||||
return np.ndarray(buffer=buffer, dtype=np.dtype(dtype), shape=shape)
|
return np.ndarray(buffer=buffer, dtype=np.dtype(dtype), shape=shape)
|
||||||
|
|
||||||
|
def _decode_mm_items(self, obj: list) -> list[MultiModalKwargsItem]:
|
||||||
|
decoded_items = []
|
||||||
|
for item in obj:
|
||||||
|
elems = []
|
||||||
|
for v in item:
|
||||||
|
v["data"] = self._decode_nested_tensors(v["data"])
|
||||||
|
# Reconstruct the field processor using MultiModalFieldConfig
|
||||||
|
factory_meth_name, *field_args = v["field"]
|
||||||
|
factory_meth = getattr(MultiModalFieldConfig,
|
||||||
|
factory_meth_name)
|
||||||
|
v["field"] = factory_meth(None, *field_args).field
|
||||||
|
elems.append(MultiModalFieldElem(**v))
|
||||||
|
decoded_items.append(MultiModalKwargsItem.from_elems(elems))
|
||||||
|
return decoded_items
|
||||||
|
|
||||||
|
def _decode_nested_tensors(self, obj: Any) -> NestedTensors:
|
||||||
|
if isinstance(obj, (int, float)):
|
||||||
|
# Although it violates NestedTensors type, MultiModalKwargs
|
||||||
|
# values are sometimes floats.
|
||||||
|
return obj
|
||||||
|
if not isinstance(obj, list):
|
||||||
|
raise TypeError(f"Unexpected NestedTensors contents: {type(obj)}")
|
||||||
|
if obj and isinstance(obj[0], str):
|
||||||
|
return torch.from_numpy(self._decode_ndarray(obj))
|
||||||
|
return [self._decode_nested_tensors(x) for x in obj]
|
||||||
|
|
||||||
def ext_hook(self, code: int, data: memoryview) -> Any:
|
def ext_hook(self, code: int, data: memoryview) -> Any:
|
||||||
if code == CUSTOM_TYPE_RAW_VIEW:
|
if code == CUSTOM_TYPE_RAW_VIEW:
|
||||||
return data
|
return data
|
||||||
|
Loading…
x
Reference in New Issue
Block a user