[V1] Zero-copy tensor/ndarray serialization/transmission (#13790)

Signed-off-by: Nick Hill <nhill@redhat.com>
This commit is contained in:
Nick Hill 2025-04-10 12:23:14 -07:00 committed by GitHub
parent daefed052c
commit dd143ef541
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 209 additions and 50 deletions

View File

@ -0,0 +1,80 @@
# SPDX-License-Identifier: Apache-2.0
from collections import UserDict
from dataclasses import dataclass
import numpy as np
import torch
from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder
class UnrecognizedType(UserDict):
    """Minimal ``UserDict`` subclass carrying a single integer attribute.

    The test below stores an instance in ``MyType.unrecognized`` and checks
    that ``an_int`` survives an encode/decode round trip.
    """

    def __init__(self, an_int: int):
        # Set up the (empty) UserDict storage before attaching the payload.
        UserDict.__init__(self)
        self.an_int = an_int
@dataclass
class MyType:
    """Sample payload used by the encode/decode round-trip test.

    Mixes torch tensors, a numpy array, a plain string and an
    ``UnrecognizedType`` instance in a single object.
    """

    # Single tensor field.
    tensor1: torch.Tensor
    # Plain string field.
    a_string: str
    # Several tensors held in a list.
    list_of_tensors: list[torch.Tensor]
    # Raw numpy array field.
    numpy_array: np.ndarray
    # Instance of the custom UserDict subclass defined above.
    unrecognized: UnrecognizedType
def test_encode_decode():
    """Test encode/decode loop with zero-copy tensors.

    Round-trips a ``MyType`` instance through ``MsgpackEncoder`` and
    ``MsgpackDecoder`` using both ``encode`` and ``encode_into``, checking
    the number of buffers produced and that the decoded object matches the
    original.
    """
    obj = MyType(
        tensor1=torch.randint(low=0,
                              high=100,
                              size=(1024, ),
                              dtype=torch.int32),
        a_string="hello",
        list_of_tensors=[
            torch.rand((1, 10), dtype=torch.float32),
            torch.rand((3, 5, 4000), dtype=torch.float64),
            torch.tensor(1984),  # test scalar too
        ],
        numpy_array=np.arange(512),
        unrecognized=UnrecognizedType(33),
    )

    encoder = MsgpackEncoder()
    decoder = MsgpackDecoder(MyType)

    encoded = encoder.encode(obj)

    # There should be the main buffer + 2 large tensor buffers
    # + 1 large numpy array. "large" is >= 256 bytes (arrays smaller than
    # the inline threshold are embedded in the main buffer).
    # The two small tensors (the (1, 10) one and the scalar) are encoded
    # inline.
    assert len(encoded) == 4

    decoded: MyType = decoder.decode(encoded)

    assert_equal(decoded, obj)

    # Test encode_into case: the caller-supplied buffer must be reused as
    # the first returned frame.
    preallocated = bytearray()

    encoded2 = encoder.encode_into(obj, preallocated)

    assert len(encoded2) == 4
    assert encoded2[0] is preallocated

    decoded2: MyType = decoder.decode(encoded2)

    assert_equal(decoded2, obj)
def assert_equal(obj1: MyType, obj2: MyType):
    """Assert that two ``MyType`` instances hold equal contents.

    Compares every field: tensors and the numpy array by dtype and value,
    the string by equality, and ``unrecognized`` by its ``an_int``.

    Raises:
        AssertionError: if any field differs.
    """
    assert obj1.tensor1.dtype == obj2.tensor1.dtype
    assert torch.equal(obj1.tensor1, obj2.tensor1)
    assert obj1.a_string == obj2.a_string
    # zip() stops at the shorter input, so a list that lost elements during
    # the round trip would otherwise compare equal. Check lengths first.
    assert len(obj1.list_of_tensors) == len(obj2.list_of_tensors)
    assert all(
        a.dtype == b.dtype and torch.equal(a, b)
        for a, b in zip(obj1.list_of_tensors, obj2.list_of_tensors))
    # np.array_equal compares values only; check the dtype separately.
    assert obj1.numpy_array.dtype == obj2.numpy_array.dtype
    assert np.array_equal(obj1.numpy_array, obj2.numpy_array)
    assert obj1.unrecognized.an_int == obj2.unrecognized.an_int

View File

@ -490,14 +490,14 @@ class EngineCoreProc(EngineCore):
while True:
# (RequestType, RequestData)
type_frame, data_frame = socket.recv_multipart(copy=False)
type_frame, *data_frames = socket.recv_multipart(copy=False)
request_type = EngineCoreRequestType(bytes(type_frame.buffer))
# Deserialize the request data.
decoder = add_request_decoder if (
request_type
== EngineCoreRequestType.ADD) else generic_decoder
request = decoder.decode(data_frame.buffer)
request = decoder.decode(data_frames)
# Push to input queue for core busy loop.
self.input_queue.put_nowait((request_type, request))
@ -514,8 +514,8 @@ class EngineCoreProc(EngineCore):
while True:
outputs = self.output_queue.get()
outputs.engine_index = engine_index
encoder.encode_into(outputs, buffer)
socket.send(buffer, copy=False)
buffers = encoder.encode_into(outputs, buffer)
socket.send_multipart(buffers, copy=False)
ENGINE_PAUSED_OUTPUTS = EngineCoreOutputs(engine_paused=True)

View File

@ -26,7 +26,7 @@ from vllm.v1.engine import (EngineCoreOutputs, EngineCoreRequest,
EngineCoreRequestType, UtilityOutput)
from vllm.v1.engine.core import EngineCore, EngineCoreProc
from vllm.v1.executor.abstract import Executor
from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder
from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder, bytestr
from vllm.v1.utils import BackgroundProcHandle
logger = init_logger(__name__)
@ -505,8 +505,8 @@ class SyncMPClient(MPClient):
# shutdown signal, exit thread.
break
frame = out_socket.recv(copy=False)
outputs = decoder.decode(frame.buffer)
frames = out_socket.recv_multipart(copy=False)
outputs = decoder.decode(frames)
if outputs.utility_output:
_process_utility_output(outputs.utility_output,
utility_results)
@ -529,7 +529,7 @@ class SyncMPClient(MPClient):
def _send_input(self, request_type: EngineCoreRequestType, request: Any):
# (Identity, RequestType, SerializedRequest)
msg = (self.core_engine.identity, request_type.value,
self.encoder.encode(request))
*self.encoder.encode(request))
self.input_socket.send_multipart(msg, copy=False)
def call_utility(self, method: str, *args) -> Any:
@ -633,8 +633,8 @@ class AsyncMPClient(MPClient):
async def process_outputs_socket():
while True:
(frame, ) = await output_socket.recv_multipart(copy=False)
outputs: EngineCoreOutputs = decoder.decode(frame.buffer)
frames = await output_socket.recv_multipart(copy=False)
outputs: EngineCoreOutputs = decoder.decode(frames)
if outputs.utility_output:
_process_utility_output(outputs.utility_output,
utility_results)
@ -666,12 +666,12 @@ class AsyncMPClient(MPClient):
if engine is None:
engine = self.core_engine
message = (request_type.value, self.encoder.encode(request))
message = (request_type.value, *self.encoder.encode(request))
return self._send_input_message(message, engine)
def _send_input_message(self, message: tuple[bytes, bytes],
def _send_input_message(self, message: tuple[bytestr, ...],
engine: CoreEngine) -> Awaitable[None]:
message = (engine.identity, ) + message # type: ignore[assignment]
message = (engine.identity, ) + message
return self.input_socket.send_multipart(message, copy=False)
async def call_utility_async(self, method: str, *args) -> Any:
@ -684,8 +684,8 @@ class AsyncMPClient(MPClient):
call_id = uuid.uuid1().int >> 64
future = asyncio.get_running_loop().create_future()
self.utility_results[call_id] = future
message = (EngineCoreRequestType.UTILITY.value,
self.encoder.encode((call_id, method, args)))
message = (EngineCoreRequestType.UTILITY.value, *self.encoder.encode(
(call_id, method, args)))
await self._send_input_message(message, engine)
self._ensure_output_queue_task()
return await future
@ -760,7 +760,7 @@ class DPAsyncMPClient(AsyncMPClient):
# Control message used for triggering dp idle mode loop.
self.start_dp_msg = (EngineCoreRequestType.START_DP.value,
self.encoder.encode(None))
*self.encoder.encode(None))
self.num_engines_running = 0
self.reqs_in_flight: dict[str, CoreEngine] = {}
@ -794,7 +794,7 @@ class DPAsyncMPClient(AsyncMPClient):
# tokenized.
request.prompt = None
msg = (EngineCoreRequestType.ADD.value, self.encoder.encode(request))
msg = (EngineCoreRequestType.ADD.value, *self.encoder.encode(request))
chosen_engine = self.get_core_engine_for_request()
self.reqs_in_flight[request.request_id] = chosen_engine

View File

@ -1,61 +1,140 @@
# SPDX-License-Identifier: Apache-2.0
import pickle
from collections.abc import Sequence
from inspect import isclass
from types import FunctionType
from typing import Any, Optional
from typing import Any, Optional, Union
import cloudpickle
import numpy as np
import torch
import zmq
from msgspec import msgpack
CUSTOM_TYPE_TENSOR = 1
CUSTOM_TYPE_PICKLE = 2
CUSTOM_TYPE_CLOUDPICKLE = 3
CUSTOM_TYPE_PICKLE = 1
CUSTOM_TYPE_CLOUDPICKLE = 2
# TODO calibrate this size
INLINE_BUF_SIZE_THRESHOLD = 256
bytestr = Union[bytes, bytearray, memoryview, zmq.Frame]
class MsgpackEncoder:
"""Encoder with custom torch tensor serialization."""
"""Encoder with custom torch tensor and numpy array serialization.
Note that unlike vanilla `msgspec` Encoders, this interface is generally
not thread-safe when encoding tensors / numpy arrays.
"""
def __init__(self):
self.encoder = msgpack.Encoder(enc_hook=custom_enc_hook)
self.encoder = msgpack.Encoder(enc_hook=self.enc_hook)
# This is used as a local stash of buffers that we can then access from
# our custom `msgspec` hook, `enc_hook`. We don't have a way to
# pass custom data to the hook otherwise.
self.aux_buffers: Optional[list[bytestr]] = None
def encode(self, obj: Any) -> bytes:
return self.encoder.encode(obj)
def encode(self, obj: Any) -> Sequence[bytestr]:
    """Encode ``obj`` and return the resulting list of buffers.

    The first element is the top-level msgpack payload. Any further
    elements are buffers that were appended to ``self.aux_buffers`` while
    encoding, returned alongside the payload instead of being copied into
    it. ``self.aux_buffers`` is reset to ``None`` on exit, including when
    encoding raises.
    """
    try:
        # Slot 0 is a placeholder until the main payload is known; slots
        # after it collect direct references to tensor / ndarray backing
        # buffers stashed during encoding.
        frames = [b'']
        self.aux_buffers = frames
        frames[0] = self.encoder.encode(obj)
    finally:
        self.aux_buffers = None
    return frames
def encode_into(self, obj: Any, buf: bytearray) -> None:
self.encoder.encode_into(obj, buf)
def encode_into(self, obj: Any, buf: bytearray) -> Sequence[bytestr]:
    """Encode ``obj`` into ``buf`` and return the resulting list of buffers.

    ``buf`` itself is always the first element of the returned list. Any
    further elements are buffers appended to ``self.aux_buffers`` while
    encoding. ``self.aux_buffers`` is reset to ``None`` on exit, including
    when encoding raises.
    """
    try:
        # The caller's buffer occupies slot 0; buffers stashed during
        # encoding are appended after it.
        frames = [buf]
        self.aux_buffers = frames
        self.encoder.encode_into(obj, buf)
    finally:
        self.aux_buffers = None
    return frames
def enc_hook(self, obj: Any) -> Any:
    """Convert ``obj`` into something the msgpack encoder can serialize.

    Tensors and ndarrays are turned into a tuple of native types via
    ``_encode_ndarray``; plain functions are cloudpickled; everything else
    is pickled. Pickled payloads are wrapped in a ``msgpack.Ext`` tagged
    with the matching custom type code.
    """
    if isinstance(obj, torch.Tensor):
        return self._encode_ndarray(obj.numpy())

    # Fall back to pickle for object or void kind ndarrays.
    is_plain_ndarray = (isinstance(obj, np.ndarray)
                        and obj.dtype.kind not in ('O', 'V'))
    if is_plain_ndarray:
        return self._encode_ndarray(obj)

    if isinstance(obj, FunctionType):
        # `pickle` is generally faster than cloudpickle, but can have
        # problems serializing methods.
        payload = cloudpickle.dumps(obj)
        return msgpack.Ext(CUSTOM_TYPE_CLOUDPICKLE, payload)

    payload = pickle.dumps(obj, protocol=pickle.HIGHEST_PROTOCOL)
    return msgpack.Ext(CUSTOM_TYPE_PICKLE, payload)
def _encode_ndarray(
    self, obj: np.ndarray
) -> tuple[str, tuple[int, ...], Union[int, memoryview]]:
    """Serialize an ndarray as a ``(dtype, shape, data)`` tuple.

    ``data`` is the array's buffer itself for scalars and arrays smaller
    than ``INLINE_BUF_SIZE_THRESHOLD`` bytes. Otherwise it is the index at
    which the buffer was appended to ``self.aux_buffers``.

    Must only be called while ``self.aux_buffers`` is set, i.e. during
    ``encode`` / ``encode_into``.
    """
    assert self.aux_buffers is not None
    if not obj.flags.c_contiguous:
        # The receiving side rebuilds the array from a flat buffer plus
        # dtype and shape, so the bytes must be in C order on BOTH the
        # inline and the aux-buffer path. A strided view (for example a
        # transposed tensor) would otherwise be sent with bytes that do not
        # match its header. 0-d arrays are always contiguous, so scalars
        # never reach this copy and keep their shape.
        obj = np.ascontiguousarray(obj)
    if not obj.shape or obj.nbytes < INLINE_BUF_SIZE_THRESHOLD:
        # Encode small arrays and scalars inline.
        data = obj.data
    else:
        # Otherwise encode index of backing buffer.
        data = len(self.aux_buffers)
        self.aux_buffers.append(obj.data)
    # We serialize the ndarray as a tuple of native types.
    # The data is either inlined if small, or an index into a list of
    # backing buffers that we've stashed in `aux_buffers`.
    return obj.dtype.str, obj.shape, data
class MsgpackDecoder:
"""Decoder with custom torch tensor serialization."""
"""Decoder with custom torch tensor and numpy array serialization.
Note that unlike vanilla `msgspec` Decoders, this interface is generally
not thread-safe when decoding tensors / numpy arrays.
"""
def __init__(self, t: Optional[Any] = None):
args = () if t is None else (t, )
self.decoder = msgpack.Decoder(*args, ext_hook=custom_ext_hook)
self.decoder = msgpack.Decoder(*args,
ext_hook=self.ext_hook,
dec_hook=self.dec_hook)
self.aux_buffers: Sequence[bytestr] = ()
def decode(self, obj: Any):
return self.decoder.decode(obj)
def decode(self, bufs: Union[bytestr, Sequence[bytestr]]) -> Any:
    """Decode either a single buffer or a sequence of buffers.

    A single buffer is passed straight to the underlying decoder. For a
    sequence, the first element is decoded as the main payload while the
    whole sequence is exposed through ``self.aux_buffers`` for the duration
    of the call; ``self.aux_buffers`` is reset to ``()`` afterwards, even
    if decoding raises.
    """
    # TODO - This check can become `isinstance(bufs, bytestr)`
    # as of Python 3.10.
    single_buffer_types = (bytes, bytearray, memoryview, zmq.Frame)
    if not isinstance(bufs, single_buffer_types):
        self.aux_buffers = bufs
        try:
            return self.decoder.decode(bufs[0])
        finally:
            self.aux_buffers = ()
    return self.decoder.decode(bufs)
def custom_enc_hook(obj: Any) -> Any:
if isinstance(obj, torch.Tensor):
# NOTE(rob): it is fastest to use numpy + pickle
# when serializing torch tensors.
# https://gist.github.com/tlrmchlsmth/8067f1b24a82b6e2f90450e7764fa103 # noqa: E501
return msgpack.Ext(CUSTOM_TYPE_TENSOR, pickle.dumps(obj.numpy()))
def dec_hook(self, t: type, obj: Any) -> Any:
    """Convert the native value ``obj`` to the requested type ``t``.

    ndarray subclasses are rebuilt via ``_decode_ndarray``; torch.Tensor
    subclasses are rebuilt the same way and then wrapped with
    ``torch.from_numpy``. Any other ``t`` (including non-class type hints)
    returns ``obj`` unchanged.
    """
    # Given native types in `obj`, convert to type `t`.
    if not isclass(t):
        return obj
    if issubclass(t, np.ndarray):
        return self._decode_ndarray(obj)
    if issubclass(t, torch.Tensor):
        return torch.from_numpy(self._decode_ndarray(obj))
    return obj
if isinstance(obj, FunctionType):
return msgpack.Ext(CUSTOM_TYPE_CLOUDPICKLE, cloudpickle.dumps(obj))
def _decode_ndarray(self, arr: Any) -> np.ndarray:
dtype, shape, data = arr
buffer = self.aux_buffers[data] if isinstance(data, int) else data
return np.ndarray(buffer=buffer, dtype=np.dtype(dtype), shape=shape)
return msgpack.Ext(CUSTOM_TYPE_PICKLE, pickle.dumps(obj))
def ext_hook(self, code: int, data: memoryview) -> Any:
if code == CUSTOM_TYPE_PICKLE:
return pickle.loads(data)
if code == CUSTOM_TYPE_CLOUDPICKLE:
return cloudpickle.loads(data)
def custom_ext_hook(code: int, data: memoryview) -> Any:
if code == CUSTOM_TYPE_TENSOR:
return torch.from_numpy(pickle.loads(data))
if code == CUSTOM_TYPE_PICKLE:
return pickle.loads(data)
if code == CUSTOM_TYPE_CLOUDPICKLE:
return cloudpickle.loads(data)
raise NotImplementedError(f"Extension type code {code} is not supported")
raise NotImplementedError(
f"Extension type code {code} is not supported")