[V1] Zero-copy tensor/ndarray serialization/transmission (#13790)
Signed-off-by: Nick Hill <nhill@redhat.com>
This commit is contained in:
parent
daefed052c
commit
dd143ef541
80
tests/v1/test_serial_utils.py
Normal file
80
tests/v1/test_serial_utils.py
Normal file
@ -0,0 +1,80 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
from collections import UserDict
|
||||
from dataclasses import dataclass
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder
|
||||
|
||||
|
||||
class UnrecognizedType(UserDict):
|
||||
|
||||
def __init__(self, an_int: int):
|
||||
super().__init__()
|
||||
self.an_int = an_int
|
||||
|
||||
|
||||
@dataclass
|
||||
class MyType:
|
||||
tensor1: torch.Tensor
|
||||
a_string: str
|
||||
list_of_tensors: list[torch.Tensor]
|
||||
numpy_array: np.ndarray
|
||||
unrecognized: UnrecognizedType
|
||||
|
||||
|
||||
def test_encode_decode():
|
||||
"""Test encode/decode loop with zero-copy tensors."""
|
||||
|
||||
obj = MyType(
|
||||
tensor1=torch.randint(low=0,
|
||||
high=100,
|
||||
size=(1024, ),
|
||||
dtype=torch.int32),
|
||||
a_string="hello",
|
||||
list_of_tensors=[
|
||||
torch.rand((1, 10), dtype=torch.float32),
|
||||
torch.rand((3, 5, 4000), dtype=torch.float64),
|
||||
torch.tensor(1984), # test scalar too
|
||||
],
|
||||
numpy_array=np.arange(512),
|
||||
unrecognized=UnrecognizedType(33),
|
||||
)
|
||||
|
||||
encoder = MsgpackEncoder()
|
||||
decoder = MsgpackDecoder(MyType)
|
||||
|
||||
encoded = encoder.encode(obj)
|
||||
|
||||
# There should be the main buffer + 2 large tensor buffers
|
||||
# + 1 large numpy array. "large" is <= 256 bytes.
|
||||
# The two small tensors are encoded inline.
|
||||
assert len(encoded) == 4
|
||||
|
||||
decoded: MyType = decoder.decode(encoded)
|
||||
|
||||
assert_equal(decoded, obj)
|
||||
|
||||
# Test encode_into case
|
||||
|
||||
preallocated = bytearray()
|
||||
|
||||
encoded2 = encoder.encode_into(obj, preallocated)
|
||||
|
||||
assert len(encoded2) == 4
|
||||
assert encoded2[0] is preallocated
|
||||
|
||||
decoded2: MyType = decoder.decode(encoded2)
|
||||
|
||||
assert_equal(decoded2, obj)
|
||||
|
||||
|
||||
def assert_equal(obj1: MyType, obj2: MyType):
|
||||
assert torch.equal(obj1.tensor1, obj2.tensor1)
|
||||
assert obj1.a_string == obj2.a_string
|
||||
assert all(
|
||||
torch.equal(a, b)
|
||||
for a, b in zip(obj1.list_of_tensors, obj2.list_of_tensors))
|
||||
assert np.array_equal(obj1.numpy_array, obj2.numpy_array)
|
||||
assert obj1.unrecognized.an_int == obj2.unrecognized.an_int
|
@ -490,14 +490,14 @@ class EngineCoreProc(EngineCore):
|
||||
|
||||
while True:
|
||||
# (RequestType, RequestData)
|
||||
type_frame, data_frame = socket.recv_multipart(copy=False)
|
||||
type_frame, *data_frames = socket.recv_multipart(copy=False)
|
||||
request_type = EngineCoreRequestType(bytes(type_frame.buffer))
|
||||
|
||||
# Deserialize the request data.
|
||||
decoder = add_request_decoder if (
|
||||
request_type
|
||||
== EngineCoreRequestType.ADD) else generic_decoder
|
||||
request = decoder.decode(data_frame.buffer)
|
||||
request = decoder.decode(data_frames)
|
||||
|
||||
# Push to input queue for core busy loop.
|
||||
self.input_queue.put_nowait((request_type, request))
|
||||
@ -514,8 +514,8 @@ class EngineCoreProc(EngineCore):
|
||||
while True:
|
||||
outputs = self.output_queue.get()
|
||||
outputs.engine_index = engine_index
|
||||
encoder.encode_into(outputs, buffer)
|
||||
socket.send(buffer, copy=False)
|
||||
buffers = encoder.encode_into(outputs, buffer)
|
||||
socket.send_multipart(buffers, copy=False)
|
||||
|
||||
|
||||
ENGINE_PAUSED_OUTPUTS = EngineCoreOutputs(engine_paused=True)
|
||||
|
@ -26,7 +26,7 @@ from vllm.v1.engine import (EngineCoreOutputs, EngineCoreRequest,
|
||||
EngineCoreRequestType, UtilityOutput)
|
||||
from vllm.v1.engine.core import EngineCore, EngineCoreProc
|
||||
from vllm.v1.executor.abstract import Executor
|
||||
from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder
|
||||
from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder, bytestr
|
||||
from vllm.v1.utils import BackgroundProcHandle
|
||||
|
||||
logger = init_logger(__name__)
|
||||
@ -505,8 +505,8 @@ class SyncMPClient(MPClient):
|
||||
# shutdown signal, exit thread.
|
||||
break
|
||||
|
||||
frame = out_socket.recv(copy=False)
|
||||
outputs = decoder.decode(frame.buffer)
|
||||
frames = out_socket.recv_multipart(copy=False)
|
||||
outputs = decoder.decode(frames)
|
||||
if outputs.utility_output:
|
||||
_process_utility_output(outputs.utility_output,
|
||||
utility_results)
|
||||
@ -529,7 +529,7 @@ class SyncMPClient(MPClient):
|
||||
def _send_input(self, request_type: EngineCoreRequestType, request: Any):
|
||||
# (Identity, RequestType, SerializedRequest)
|
||||
msg = (self.core_engine.identity, request_type.value,
|
||||
self.encoder.encode(request))
|
||||
*self.encoder.encode(request))
|
||||
self.input_socket.send_multipart(msg, copy=False)
|
||||
|
||||
def call_utility(self, method: str, *args) -> Any:
|
||||
@ -633,8 +633,8 @@ class AsyncMPClient(MPClient):
|
||||
|
||||
async def process_outputs_socket():
|
||||
while True:
|
||||
(frame, ) = await output_socket.recv_multipart(copy=False)
|
||||
outputs: EngineCoreOutputs = decoder.decode(frame.buffer)
|
||||
frames = await output_socket.recv_multipart(copy=False)
|
||||
outputs: EngineCoreOutputs = decoder.decode(frames)
|
||||
if outputs.utility_output:
|
||||
_process_utility_output(outputs.utility_output,
|
||||
utility_results)
|
||||
@ -666,12 +666,12 @@ class AsyncMPClient(MPClient):
|
||||
if engine is None:
|
||||
engine = self.core_engine
|
||||
|
||||
message = (request_type.value, self.encoder.encode(request))
|
||||
message = (request_type.value, *self.encoder.encode(request))
|
||||
return self._send_input_message(message, engine)
|
||||
|
||||
def _send_input_message(self, message: tuple[bytes, bytes],
|
||||
def _send_input_message(self, message: tuple[bytestr, ...],
|
||||
engine: CoreEngine) -> Awaitable[None]:
|
||||
message = (engine.identity, ) + message # type: ignore[assignment]
|
||||
message = (engine.identity, ) + message
|
||||
return self.input_socket.send_multipart(message, copy=False)
|
||||
|
||||
async def call_utility_async(self, method: str, *args) -> Any:
|
||||
@ -684,8 +684,8 @@ class AsyncMPClient(MPClient):
|
||||
call_id = uuid.uuid1().int >> 64
|
||||
future = asyncio.get_running_loop().create_future()
|
||||
self.utility_results[call_id] = future
|
||||
message = (EngineCoreRequestType.UTILITY.value,
|
||||
self.encoder.encode((call_id, method, args)))
|
||||
message = (EngineCoreRequestType.UTILITY.value, *self.encoder.encode(
|
||||
(call_id, method, args)))
|
||||
await self._send_input_message(message, engine)
|
||||
self._ensure_output_queue_task()
|
||||
return await future
|
||||
@ -760,7 +760,7 @@ class DPAsyncMPClient(AsyncMPClient):
|
||||
|
||||
# Control message used for triggering dp idle mode loop.
|
||||
self.start_dp_msg = (EngineCoreRequestType.START_DP.value,
|
||||
self.encoder.encode(None))
|
||||
*self.encoder.encode(None))
|
||||
|
||||
self.num_engines_running = 0
|
||||
self.reqs_in_flight: dict[str, CoreEngine] = {}
|
||||
@ -794,7 +794,7 @@ class DPAsyncMPClient(AsyncMPClient):
|
||||
# tokenized.
|
||||
request.prompt = None
|
||||
|
||||
msg = (EngineCoreRequestType.ADD.value, self.encoder.encode(request))
|
||||
msg = (EngineCoreRequestType.ADD.value, *self.encoder.encode(request))
|
||||
|
||||
chosen_engine = self.get_core_engine_for_request()
|
||||
self.reqs_in_flight[request.request_id] = chosen_engine
|
||||
|
@ -1,61 +1,140 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import pickle
|
||||
from collections.abc import Sequence
|
||||
from inspect import isclass
|
||||
from types import FunctionType
|
||||
from typing import Any, Optional
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
import cloudpickle
|
||||
import numpy as np
|
||||
import torch
|
||||
import zmq
|
||||
from msgspec import msgpack
|
||||
|
||||
CUSTOM_TYPE_TENSOR = 1
|
||||
CUSTOM_TYPE_PICKLE = 2
|
||||
CUSTOM_TYPE_CLOUDPICKLE = 3
|
||||
CUSTOM_TYPE_PICKLE = 1
|
||||
CUSTOM_TYPE_CLOUDPICKLE = 2
|
||||
|
||||
# TODO calibrate this size
|
||||
INLINE_BUF_SIZE_THRESHOLD = 256
|
||||
|
||||
bytestr = Union[bytes, bytearray, memoryview, zmq.Frame]
|
||||
|
||||
|
||||
class MsgpackEncoder:
|
||||
"""Encoder with custom torch tensor serialization."""
|
||||
"""Encoder with custom torch tensor and numpy array serialization.
|
||||
|
||||
Note that unlike vanilla `msgspec` Encoders, this interface is generally
|
||||
not thread-safe when encoding tensors / numpy arrays.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.encoder = msgpack.Encoder(enc_hook=custom_enc_hook)
|
||||
self.encoder = msgpack.Encoder(enc_hook=self.enc_hook)
|
||||
# This is used as a local stash of buffers that we can then access from
|
||||
# our custom `msgspec` hook, `enc_hook`. We don't have a way to
|
||||
# pass custom data to the hook otherwise.
|
||||
self.aux_buffers: Optional[list[bytestr]] = None
|
||||
|
||||
def encode(self, obj: Any) -> bytes:
|
||||
return self.encoder.encode(obj)
|
||||
def encode(self, obj: Any) -> Sequence[bytestr]:
|
||||
try:
|
||||
self.aux_buffers = bufs = [b'']
|
||||
bufs[0] = self.encoder.encode(obj)
|
||||
# This `bufs` list allows us to collect direct pointers to backing
|
||||
# buffers of tensors and np arrays, and return them along with the
|
||||
# top-level encoded buffer instead of copying their data into the
|
||||
# new buffer.
|
||||
return bufs
|
||||
finally:
|
||||
self.aux_buffers = None
|
||||
|
||||
def encode_into(self, obj: Any, buf: bytearray) -> None:
|
||||
self.encoder.encode_into(obj, buf)
|
||||
def encode_into(self, obj: Any, buf: bytearray) -> Sequence[bytestr]:
|
||||
try:
|
||||
self.aux_buffers = [buf]
|
||||
bufs = self.aux_buffers
|
||||
self.encoder.encode_into(obj, buf)
|
||||
return bufs
|
||||
finally:
|
||||
self.aux_buffers = None
|
||||
|
||||
def enc_hook(self, obj: Any) -> Any:
|
||||
if isinstance(obj, torch.Tensor):
|
||||
return self._encode_ndarray(obj.numpy())
|
||||
|
||||
# Fall back to pickle for object or void kind ndarrays.
|
||||
if isinstance(obj, np.ndarray) and obj.dtype.kind not in ('O', 'V'):
|
||||
return self._encode_ndarray(obj)
|
||||
|
||||
if isinstance(obj, FunctionType):
|
||||
# `pickle` is generally faster than cloudpickle, but can have
|
||||
# problems serializing methods.
|
||||
return msgpack.Ext(CUSTOM_TYPE_CLOUDPICKLE, cloudpickle.dumps(obj))
|
||||
|
||||
return msgpack.Ext(CUSTOM_TYPE_PICKLE,
|
||||
pickle.dumps(obj, protocol=pickle.HIGHEST_PROTOCOL))
|
||||
|
||||
def _encode_ndarray(
|
||||
self, obj: np.ndarray
|
||||
) -> tuple[str, tuple[int, ...], Union[int, memoryview]]:
|
||||
assert self.aux_buffers is not None
|
||||
if not obj.shape or obj.nbytes < INLINE_BUF_SIZE_THRESHOLD:
|
||||
# Encode small arrays and scalars inline.
|
||||
data = obj.data
|
||||
else:
|
||||
# Otherwise encode index of backing buffer.
|
||||
obj = np.ascontiguousarray(obj)
|
||||
data = len(self.aux_buffers)
|
||||
self.aux_buffers.append(obj.data)
|
||||
# We serialize the ndarray as a tuple of native types.
|
||||
# The data is either inlined if small, or an index into a list of
|
||||
# backing buffers that we've stashed in `aux_buffers`.
|
||||
return obj.dtype.str, obj.shape, data
|
||||
|
||||
|
||||
class MsgpackDecoder:
|
||||
"""Decoder with custom torch tensor serialization."""
|
||||
"""Decoder with custom torch tensor and numpy array serialization.
|
||||
|
||||
Note that unlike vanilla `msgspec` Decoders, this interface is generally
|
||||
not thread-safe when encoding tensors / numpy arrays.
|
||||
"""
|
||||
|
||||
def __init__(self, t: Optional[Any] = None):
|
||||
args = () if t is None else (t, )
|
||||
self.decoder = msgpack.Decoder(*args, ext_hook=custom_ext_hook)
|
||||
self.decoder = msgpack.Decoder(*args,
|
||||
ext_hook=self.ext_hook,
|
||||
dec_hook=self.dec_hook)
|
||||
self.aux_buffers: Sequence[bytestr] = ()
|
||||
|
||||
def decode(self, obj: Any):
|
||||
return self.decoder.decode(obj)
|
||||
def decode(self, bufs: Union[bytestr, Sequence[bytestr]]) -> Any:
|
||||
if isinstance(bufs, (bytes, bytearray, memoryview, zmq.Frame)):
|
||||
# TODO - This check can become `isinstance(bufs, bytestr)`
|
||||
# as of Python 3.10.
|
||||
return self.decoder.decode(bufs)
|
||||
|
||||
self.aux_buffers = bufs
|
||||
try:
|
||||
return self.decoder.decode(bufs[0])
|
||||
finally:
|
||||
self.aux_buffers = ()
|
||||
|
||||
def custom_enc_hook(obj: Any) -> Any:
|
||||
if isinstance(obj, torch.Tensor):
|
||||
# NOTE(rob): it is fastest to use numpy + pickle
|
||||
# when serializing torch tensors.
|
||||
# https://gist.github.com/tlrmchlsmth/8067f1b24a82b6e2f90450e7764fa103 # noqa: E501
|
||||
return msgpack.Ext(CUSTOM_TYPE_TENSOR, pickle.dumps(obj.numpy()))
|
||||
def dec_hook(self, t: type, obj: Any) -> Any:
|
||||
# Given native types in `obj`, convert to type `t`.
|
||||
if isclass(t):
|
||||
if issubclass(t, np.ndarray):
|
||||
return self._decode_ndarray(obj)
|
||||
if issubclass(t, torch.Tensor):
|
||||
return torch.from_numpy(self._decode_ndarray(obj))
|
||||
return obj
|
||||
|
||||
if isinstance(obj, FunctionType):
|
||||
return msgpack.Ext(CUSTOM_TYPE_CLOUDPICKLE, cloudpickle.dumps(obj))
|
||||
def _decode_ndarray(self, arr: Any) -> np.ndarray:
|
||||
dtype, shape, data = arr
|
||||
buffer = self.aux_buffers[data] if isinstance(data, int) else data
|
||||
return np.ndarray(buffer=buffer, dtype=np.dtype(dtype), shape=shape)
|
||||
|
||||
return msgpack.Ext(CUSTOM_TYPE_PICKLE, pickle.dumps(obj))
|
||||
def ext_hook(self, code: int, data: memoryview) -> Any:
|
||||
if code == CUSTOM_TYPE_PICKLE:
|
||||
return pickle.loads(data)
|
||||
if code == CUSTOM_TYPE_CLOUDPICKLE:
|
||||
return cloudpickle.loads(data)
|
||||
|
||||
|
||||
def custom_ext_hook(code: int, data: memoryview) -> Any:
|
||||
if code == CUSTOM_TYPE_TENSOR:
|
||||
return torch.from_numpy(pickle.loads(data))
|
||||
if code == CUSTOM_TYPE_PICKLE:
|
||||
return pickle.loads(data)
|
||||
if code == CUSTOM_TYPE_CLOUDPICKLE:
|
||||
return cloudpickle.loads(data)
|
||||
|
||||
raise NotImplementedError(f"Extension type code {code} is not supported")
|
||||
raise NotImplementedError(
|
||||
f"Extension type code {code} is not supported")
|
||||
|
Loading…
x
Reference in New Issue
Block a user