[V1][Performance] Implement custom serializaton for MultiModalKwargs [Rebased] (#16432)
Signed-off-by: Staszek Pasko <staszek@gmail.com> Signed-off-by: Nick Hill <nhill@redhat.com> Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com> Co-authored-by: Nick Hill <nhill@redhat.com>
This commit is contained in:
parent
3cd91dc955
commit
3092375e27
@ -1,10 +1,16 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
from collections import UserDict
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
import msgspec
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from vllm.multimodal.inputs import (MultiModalBatchedField,
|
||||
MultiModalFieldElem, MultiModalKwargs,
|
||||
MultiModalKwargsItem,
|
||||
MultiModalSharedField, NestedTensors)
|
||||
from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder
|
||||
|
||||
|
||||
@ -50,7 +56,7 @@ def test_encode_decode():
|
||||
large_non_contig_tensor=torch.rand(1024, 512)[:, 10:20],
|
||||
)
|
||||
|
||||
encoder = MsgpackEncoder()
|
||||
encoder = MsgpackEncoder(size_threshold=256)
|
||||
decoder = MsgpackDecoder(MyType)
|
||||
|
||||
encoded = encoder.encode(obj)
|
||||
@ -78,6 +84,97 @@ def test_encode_decode():
|
||||
assert_equal(decoded2, obj)
|
||||
|
||||
|
||||
class MyRequest(msgspec.Struct):
|
||||
mm: Optional[list[MultiModalKwargs]]
|
||||
|
||||
|
||||
def test_multimodal_kwargs():
|
||||
d = {
|
||||
"foo":
|
||||
torch.zeros(20000, dtype=torch.float16),
|
||||
"bar": [torch.zeros(i * 1000, dtype=torch.int8) for i in range(3)],
|
||||
"baz": [
|
||||
torch.rand((256), dtype=torch.float16),
|
||||
[
|
||||
torch.rand((1, 12), dtype=torch.float32),
|
||||
torch.rand((3, 5, 7), dtype=torch.float64),
|
||||
], [torch.rand((4, 4), dtype=torch.float16)]
|
||||
],
|
||||
}
|
||||
|
||||
# pack mm kwargs into a mock request so that it can be decoded properly
|
||||
req = MyRequest(mm=[MultiModalKwargs(d)])
|
||||
|
||||
encoder = MsgpackEncoder()
|
||||
decoder = MsgpackDecoder(MyRequest)
|
||||
|
||||
encoded = encoder.encode(req)
|
||||
|
||||
assert len(encoded) == 6
|
||||
|
||||
total_len = sum(memoryview(x).cast("B").nbytes for x in encoded)
|
||||
|
||||
# expected total encoding length, should be 44536, +-20 for minor changes
|
||||
assert total_len >= 44516 and total_len <= 44556
|
||||
decoded: MultiModalKwargs = decoder.decode(encoded).mm[0]
|
||||
assert all(nested_equal(d[k], decoded[k]) for k in d)
|
||||
|
||||
|
||||
def test_multimodal_items_by_modality():
|
||||
e1 = MultiModalFieldElem("audio", "a0", torch.zeros(1000,
|
||||
dtype=torch.int16),
|
||||
MultiModalBatchedField())
|
||||
e2 = MultiModalFieldElem(
|
||||
"video",
|
||||
"v0",
|
||||
[torch.zeros(1000, dtype=torch.int8) for _ in range(4)],
|
||||
MultiModalBatchedField(),
|
||||
)
|
||||
e3 = MultiModalFieldElem("image", "i0", torch.zeros(1000,
|
||||
dtype=torch.int32),
|
||||
MultiModalSharedField(4))
|
||||
e4 = MultiModalFieldElem("image", "i1", torch.zeros(1000,
|
||||
dtype=torch.int32),
|
||||
MultiModalBatchedField())
|
||||
audio = MultiModalKwargsItem.from_elems([e1])
|
||||
video = MultiModalKwargsItem.from_elems([e2])
|
||||
image = MultiModalKwargsItem.from_elems([e3, e4])
|
||||
mm = MultiModalKwargs.from_items([audio, video, image])
|
||||
|
||||
# pack mm kwargs into a mock request so that it can be decoded properly
|
||||
req = MyRequest([mm])
|
||||
|
||||
encoder = MsgpackEncoder()
|
||||
decoder = MsgpackDecoder(MyRequest)
|
||||
|
||||
encoded = encoder.encode(req)
|
||||
|
||||
assert len(encoded) == 8
|
||||
|
||||
total_len = sum(memoryview(x).cast("B").nbytes for x in encoded)
|
||||
|
||||
# expected total encoding length, should be 14255, +-20 for minor changes
|
||||
assert total_len >= 14235 and total_len <= 14275
|
||||
decoded: MultiModalKwargs = decoder.decode(encoded).mm[0]
|
||||
|
||||
# check all modalities were recovered and do some basic sanity checks
|
||||
assert len(decoded.modalities) == 3
|
||||
images = decoded.get_items("image")
|
||||
assert len(images) == 1
|
||||
assert len(images[0].items()) == 2
|
||||
assert list(images[0].keys()) == ["i0", "i1"]
|
||||
|
||||
# check the tensor contents and layout in the main dict
|
||||
assert all(nested_equal(mm[k], decoded[k]) for k in mm)
|
||||
|
||||
|
||||
def nested_equal(a: NestedTensors, b: NestedTensors):
|
||||
if isinstance(a, torch.Tensor):
|
||||
return torch.equal(a, b)
|
||||
else:
|
||||
return all(nested_equal(x, y) for x, y in zip(a, b))
|
||||
|
||||
|
||||
def assert_equal(obj1: MyType, obj2: MyType):
|
||||
assert torch.equal(obj1.tensor1, obj2.tensor1)
|
||||
assert obj1.a_string == obj2.a_string
|
||||
|
11
vllm/envs.py
11
vllm/envs.py
@ -107,6 +107,7 @@ if TYPE_CHECKING:
|
||||
VLLM_TPU_BUCKET_PADDING_GAP: int = 0
|
||||
VLLM_USE_DEEP_GEMM: bool = False
|
||||
VLLM_XGRAMMAR_CACHE_MB: int = 0
|
||||
VLLM_MSGPACK_ZERO_COPY_THRESHOLD: int = 256
|
||||
|
||||
|
||||
def get_default_cache_root():
|
||||
@ -704,6 +705,16 @@ environment_variables: dict[str, Callable[[], Any]] = {
|
||||
# It can be changed with this variable if needed for some reason.
|
||||
"VLLM_XGRAMMAR_CACHE_MB":
|
||||
lambda: int(os.getenv("VLLM_XGRAMMAR_CACHE_MB", "512")),
|
||||
|
||||
# Control the threshold for msgspec to use 'zero copy' for
|
||||
# serialization/deserialization of tensors. Tensors below
|
||||
# this limit will be encoded into the msgpack buffer, and
|
||||
# tensors above will instead be sent via a separate message.
|
||||
# While the sending side still actually copies the tensor
|
||||
# in all cases, on the receiving side, tensors above this
|
||||
# limit will actually be zero-copy decoded.
|
||||
"VLLM_MSGPACK_ZERO_COPY_THRESHOLD":
|
||||
lambda: int(os.getenv("VLLM_MSGPACK_ZERO_COPY_THRESHOLD", "256")),
|
||||
}
|
||||
|
||||
# end-env-vars-definition
|
||||
|
@ -1,5 +1,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import dataclasses
|
||||
import pickle
|
||||
from collections.abc import Sequence
|
||||
from inspect import isclass
|
||||
@ -12,12 +13,26 @@ import torch
|
||||
import zmq
|
||||
from msgspec import msgpack
|
||||
|
||||
from vllm import envs
|
||||
from vllm.multimodal.inputs import (BaseMultiModalField,
|
||||
MultiModalBatchedField,
|
||||
MultiModalFieldConfig, MultiModalFieldElem,
|
||||
MultiModalFlatField, MultiModalKwargs,
|
||||
MultiModalKwargsItem,
|
||||
MultiModalSharedField, NestedTensors)
|
||||
|
||||
CUSTOM_TYPE_PICKLE = 1
|
||||
CUSTOM_TYPE_CLOUDPICKLE = 2
|
||||
CUSTOM_TYPE_RAW_VIEW = 3
|
||||
|
||||
# TODO calibrate this size
|
||||
MIN_NOCOPY_BUF_SIZE = 512
|
||||
# MultiModalField class serialization type map.
|
||||
# These need to list all possible field types and match them
|
||||
# to factory methods in `MultiModalFieldConfig`.
|
||||
MMF_CLASS_TO_FACTORY: dict[type[BaseMultiModalField], str] = {
|
||||
MultiModalFlatField: "flat",
|
||||
MultiModalSharedField: "shared",
|
||||
MultiModalBatchedField: "batched",
|
||||
}
|
||||
|
||||
bytestr = Union[bytes, bytearray, memoryview, zmq.Frame]
|
||||
|
||||
@ -27,14 +42,20 @@ class MsgpackEncoder:
|
||||
|
||||
Note that unlike vanilla `msgspec` Encoders, this interface is generally
|
||||
not thread-safe when encoding tensors / numpy arrays.
|
||||
|
||||
By default, arrays below 256B are serialized inline Larger will get sent
|
||||
via dedicated messages. Note that this is a per-tensor limit.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
def __init__(self, size_threshold: Optional[int] = None):
|
||||
if size_threshold is None:
|
||||
size_threshold = envs.VLLM_MSGPACK_ZERO_COPY_THRESHOLD
|
||||
self.encoder = msgpack.Encoder(enc_hook=self.enc_hook)
|
||||
# This is used as a local stash of buffers that we can then access from
|
||||
# our custom `msgspec` hook, `enc_hook`. We don't have a way to
|
||||
# pass custom data to the hook otherwise.
|
||||
self.aux_buffers: Optional[list[bytestr]] = None
|
||||
self.size_threshold = size_threshold
|
||||
|
||||
def encode(self, obj: Any) -> Sequence[bytestr]:
|
||||
try:
|
||||
@ -65,6 +86,25 @@ class MsgpackEncoder:
|
||||
if isinstance(obj, np.ndarray) and obj.dtype.kind not in ('O', 'V'):
|
||||
return self._encode_ndarray(obj)
|
||||
|
||||
if isinstance(obj, MultiModalKwargs):
|
||||
mm: MultiModalKwargs = obj
|
||||
if not mm.modalities:
|
||||
# just return the main dict if there are no modalities.
|
||||
return dict(mm)
|
||||
|
||||
# ignore the main dict, it will be re-indexed.
|
||||
# Encode a list of MultiModalKwargsItems as plain dicts
|
||||
# + special handling for .field.
|
||||
# Any tensors *not* indexed by modality will be ignored.
|
||||
return [[{
|
||||
"modality": elem.modality,
|
||||
"key": elem.key,
|
||||
"data": self._encode_nested_tensors(elem.data),
|
||||
"field": self._encode_mm_field(elem.field),
|
||||
} for elem in item.values()]
|
||||
for itemlist in mm._items_by_modality.values()
|
||||
for item in itemlist]
|
||||
|
||||
if isinstance(obj, FunctionType):
|
||||
# `pickle` is generally faster than cloudpickle, but can have
|
||||
# problems serializing methods.
|
||||
@ -77,8 +117,9 @@ class MsgpackEncoder:
|
||||
self, obj: np.ndarray
|
||||
) -> tuple[str, tuple[int, ...], Union[int, memoryview]]:
|
||||
assert self.aux_buffers is not None
|
||||
# If the array is non-contiguous, we need to copy it first
|
||||
arr_data = obj.data if obj.data.c_contiguous else obj.tobytes()
|
||||
if not obj.shape or obj.nbytes < MIN_NOCOPY_BUF_SIZE:
|
||||
if not obj.shape or obj.nbytes < self.size_threshold:
|
||||
# Encode small arrays and scalars inline. Using this extension type
|
||||
# ensures we can avoid copying when decoding.
|
||||
data = msgpack.Ext(CUSTOM_TYPE_RAW_VIEW, arr_data)
|
||||
@ -92,6 +133,26 @@ class MsgpackEncoder:
|
||||
# backing buffers that we've stashed in `aux_buffers`.
|
||||
return obj.dtype.str, obj.shape, data
|
||||
|
||||
def _encode_nested_tensors(self, nt: NestedTensors) -> Any:
|
||||
if isinstance(nt, torch.Tensor):
|
||||
return self._encode_ndarray(nt.numpy())
|
||||
if isinstance(nt, (int, float)):
|
||||
# Although it violates NestedTensors type, MultiModalKwargs
|
||||
# values are sometimes floats.
|
||||
return nt
|
||||
return [self._encode_nested_tensors(x) for x in nt]
|
||||
|
||||
def _encode_mm_field(self, field: BaseMultiModalField):
|
||||
# Figure out the factory name for the field type.
|
||||
name = MMF_CLASS_TO_FACTORY.get(field.__class__)
|
||||
if not name:
|
||||
raise TypeError(f"Unsupported field type: {field.__class__}")
|
||||
# We just need to copy all of the field values in order
|
||||
# which will be then used to reconstruct the field.
|
||||
field_values = (getattr(field, f.name)
|
||||
for f in dataclasses.fields(field))
|
||||
return name, *field_values
|
||||
|
||||
|
||||
class MsgpackDecoder:
|
||||
"""Decoder with custom torch tensor and numpy array serialization.
|
||||
@ -126,13 +187,50 @@ class MsgpackDecoder:
|
||||
return self._decode_ndarray(obj)
|
||||
if issubclass(t, torch.Tensor):
|
||||
return torch.from_numpy(self._decode_ndarray(obj))
|
||||
if issubclass(t, MultiModalKwargs):
|
||||
if isinstance(obj, list):
|
||||
return MultiModalKwargs.from_items(
|
||||
self._decode_mm_items(obj))
|
||||
return MultiModalKwargs({
|
||||
k: self._decode_nested_tensors(v)
|
||||
for k, v in obj.items()
|
||||
})
|
||||
return obj
|
||||
|
||||
def _decode_ndarray(self, arr: Any) -> np.ndarray:
|
||||
dtype, shape, data = arr
|
||||
buffer = self.aux_buffers[data] if isinstance(data, int) else data
|
||||
# Copy from inline representation, otherwise Torch is unhappy since
|
||||
# the returned memory is non-writeable.
|
||||
buffer = self.aux_buffers[data] if isinstance(data, int) \
|
||||
else bytearray(data)
|
||||
return np.ndarray(buffer=buffer, dtype=np.dtype(dtype), shape=shape)
|
||||
|
||||
def _decode_mm_items(self, obj: list) -> list[MultiModalKwargsItem]:
|
||||
decoded_items = []
|
||||
for item in obj:
|
||||
elems = []
|
||||
for v in item:
|
||||
v["data"] = self._decode_nested_tensors(v["data"])
|
||||
# Reconstruct the field processor using MultiModalFieldConfig
|
||||
factory_meth_name, *field_args = v["field"]
|
||||
factory_meth = getattr(MultiModalFieldConfig,
|
||||
factory_meth_name)
|
||||
v["field"] = factory_meth(None, *field_args).field
|
||||
elems.append(MultiModalFieldElem(**v))
|
||||
decoded_items.append(MultiModalKwargsItem.from_elems(elems))
|
||||
return decoded_items
|
||||
|
||||
def _decode_nested_tensors(self, obj: Any) -> NestedTensors:
|
||||
if isinstance(obj, (int, float)):
|
||||
# Although it violates NestedTensors type, MultiModalKwargs
|
||||
# values are sometimes floats.
|
||||
return obj
|
||||
if not isinstance(obj, list):
|
||||
raise TypeError(f"Unexpected NestedTensors contents: {type(obj)}")
|
||||
if obj and isinstance(obj[0], str):
|
||||
return torch.from_numpy(self._decode_ndarray(obj))
|
||||
return [self._decode_nested_tensors(x) for x in obj]
|
||||
|
||||
def ext_hook(self, code: int, data: memoryview) -> Any:
|
||||
if code == CUSTOM_TYPE_RAW_VIEW:
|
||||
return data
|
||||
|
Loading…
x
Reference in New Issue
Block a user