[V1] Zero-copy tensor/ndarray serialization/transmission (#13790)
Signed-off-by: Nick Hill <nhill@redhat.com>
This commit is contained in:
parent
daefed052c
commit
dd143ef541
80
tests/v1/test_serial_utils.py
Normal file
80
tests/v1/test_serial_utils.py
Normal file
@ -0,0 +1,80 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
from collections import UserDict
|
||||||
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder
|
||||||
|
|
||||||
|
|
||||||
|
class UnrecognizedType(UserDict):
    """Custom `UserDict` subclass used as a field value in the round-trip test.

    The instance stays an empty mapping; its only payload is the plain
    attribute `an_int`, which the test compares after decoding.
    """

    def __init__(self, an_int: int):
        # Initialise the (empty) underlying mapping first.
        UserDict.__init__(self)
        # Stored as an instance attribute, not as a key in the mapping.
        self.an_int = an_int
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class MyType:
    """Composite payload exercised by the encode/decode round-trip test.

    Mixes tensor, string, list-of-tensor, numpy and custom-class fields so
    that a single object covers each kind of value the test compares.
    """

    # A single torch tensor.
    tensor1: torch.Tensor
    # A plain string.
    a_string: str
    # Several tensors held in a list.
    list_of_tensors: list[torch.Tensor]
    # A numpy array.
    numpy_array: np.ndarray
    # An instance of the custom UserDict subclass declared in this file.
    unrecognized: UnrecognizedType
|
||||||
|
|
||||||
|
|
||||||
|
def test_encode_decode():
    """Test encode/decode loop with zero-copy tensors.

    Builds a `MyType` holding large and small tensors, a numpy array and a
    custom object, then round-trips it through both `encode` and
    `encode_into`, checking the number of returned buffers and the decoded
    field values.
    """

    obj = MyType(
        tensor1=torch.randint(low=0,
                              high=100,
                              size=(1024, ),
                              dtype=torch.int32),
        a_string="hello",
        list_of_tensors=[
            torch.rand((1, 10), dtype=torch.float32),
            torch.rand((3, 5, 4000), dtype=torch.float64),
            torch.tensor(1984),  # test scalar too
        ],
        numpy_array=np.arange(512),
        unrecognized=UnrecognizedType(33),
    )

    encoder = MsgpackEncoder()
    decoder = MsgpackDecoder(MyType)

    encoded = encoder.encode(obj)

    # There should be the main buffer + 2 large tensor buffers
    # + 1 large numpy array. "large" is >= 256 bytes: the encoder only
    # inlines arrays whose size is strictly below its inline threshold.
    # The two small tensors (the (1, 10) float32 tensor and the scalar)
    # are encoded inline in the main buffer.
    assert len(encoded) == 4

    decoded: MyType = decoder.decode(encoded)

    assert_equal(decoded, obj)

    # Test encode_into case

    preallocated = bytearray()

    encoded2 = encoder.encode_into(obj, preallocated)

    assert len(encoded2) == 4
    # The caller-supplied buffer must be returned as the first buffer,
    # not a copy of it.
    assert encoded2[0] is preallocated

    decoded2: MyType = decoder.decode(encoded2)

    assert_equal(decoded2, obj)
|
||||||
|
|
||||||
|
|
||||||
|
def assert_equal(obj1: MyType, obj2: MyType):
    """Assert that two `MyType` objects carry equal field values.

    Tensors are compared with `torch.equal`, the numpy array with
    `np.array_equal`, the string with `==`, and the custom object by its
    `an_int` attribute.

    Raises:
        AssertionError: at the first field that differs, including when the
            two tensor lists have different lengths.
    """
    assert torch.equal(obj1.tensor1, obj2.tensor1)
    assert obj1.a_string == obj2.a_string
    # `zip` stops at the shorter input, so compare the lengths first;
    # otherwise missing or extra tensors on one side would go unnoticed.
    assert len(obj1.list_of_tensors) == len(obj2.list_of_tensors)
    assert all(
        torch.equal(a, b)
        for a, b in zip(obj1.list_of_tensors, obj2.list_of_tensors))
    assert np.array_equal(obj1.numpy_array, obj2.numpy_array)
    assert obj1.unrecognized.an_int == obj2.unrecognized.an_int
|
@ -490,14 +490,14 @@ class EngineCoreProc(EngineCore):
|
|||||||
|
|
||||||
while True:
|
while True:
|
||||||
# (RequestType, RequestData)
|
# (RequestType, RequestData)
|
||||||
type_frame, data_frame = socket.recv_multipart(copy=False)
|
type_frame, *data_frames = socket.recv_multipart(copy=False)
|
||||||
request_type = EngineCoreRequestType(bytes(type_frame.buffer))
|
request_type = EngineCoreRequestType(bytes(type_frame.buffer))
|
||||||
|
|
||||||
# Deserialize the request data.
|
# Deserialize the request data.
|
||||||
decoder = add_request_decoder if (
|
decoder = add_request_decoder if (
|
||||||
request_type
|
request_type
|
||||||
== EngineCoreRequestType.ADD) else generic_decoder
|
== EngineCoreRequestType.ADD) else generic_decoder
|
||||||
request = decoder.decode(data_frame.buffer)
|
request = decoder.decode(data_frames)
|
||||||
|
|
||||||
# Push to input queue for core busy loop.
|
# Push to input queue for core busy loop.
|
||||||
self.input_queue.put_nowait((request_type, request))
|
self.input_queue.put_nowait((request_type, request))
|
||||||
@ -514,8 +514,8 @@ class EngineCoreProc(EngineCore):
|
|||||||
while True:
|
while True:
|
||||||
outputs = self.output_queue.get()
|
outputs = self.output_queue.get()
|
||||||
outputs.engine_index = engine_index
|
outputs.engine_index = engine_index
|
||||||
encoder.encode_into(outputs, buffer)
|
buffers = encoder.encode_into(outputs, buffer)
|
||||||
socket.send(buffer, copy=False)
|
socket.send_multipart(buffers, copy=False)
|
||||||
|
|
||||||
|
|
||||||
ENGINE_PAUSED_OUTPUTS = EngineCoreOutputs(engine_paused=True)
|
ENGINE_PAUSED_OUTPUTS = EngineCoreOutputs(engine_paused=True)
|
||||||
|
@ -26,7 +26,7 @@ from vllm.v1.engine import (EngineCoreOutputs, EngineCoreRequest,
|
|||||||
EngineCoreRequestType, UtilityOutput)
|
EngineCoreRequestType, UtilityOutput)
|
||||||
from vllm.v1.engine.core import EngineCore, EngineCoreProc
|
from vllm.v1.engine.core import EngineCore, EngineCoreProc
|
||||||
from vllm.v1.executor.abstract import Executor
|
from vllm.v1.executor.abstract import Executor
|
||||||
from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder
|
from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder, bytestr
|
||||||
from vllm.v1.utils import BackgroundProcHandle
|
from vllm.v1.utils import BackgroundProcHandle
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
@ -505,8 +505,8 @@ class SyncMPClient(MPClient):
|
|||||||
# shutdown signal, exit thread.
|
# shutdown signal, exit thread.
|
||||||
break
|
break
|
||||||
|
|
||||||
frame = out_socket.recv(copy=False)
|
frames = out_socket.recv_multipart(copy=False)
|
||||||
outputs = decoder.decode(frame.buffer)
|
outputs = decoder.decode(frames)
|
||||||
if outputs.utility_output:
|
if outputs.utility_output:
|
||||||
_process_utility_output(outputs.utility_output,
|
_process_utility_output(outputs.utility_output,
|
||||||
utility_results)
|
utility_results)
|
||||||
@ -529,7 +529,7 @@ class SyncMPClient(MPClient):
|
|||||||
def _send_input(self, request_type: EngineCoreRequestType, request: Any):
|
def _send_input(self, request_type: EngineCoreRequestType, request: Any):
|
||||||
# (Identity, RequestType, SerializedRequest)
|
# (Identity, RequestType, SerializedRequest)
|
||||||
msg = (self.core_engine.identity, request_type.value,
|
msg = (self.core_engine.identity, request_type.value,
|
||||||
self.encoder.encode(request))
|
*self.encoder.encode(request))
|
||||||
self.input_socket.send_multipart(msg, copy=False)
|
self.input_socket.send_multipart(msg, copy=False)
|
||||||
|
|
||||||
def call_utility(self, method: str, *args) -> Any:
|
def call_utility(self, method: str, *args) -> Any:
|
||||||
@ -633,8 +633,8 @@ class AsyncMPClient(MPClient):
|
|||||||
|
|
||||||
async def process_outputs_socket():
|
async def process_outputs_socket():
|
||||||
while True:
|
while True:
|
||||||
(frame, ) = await output_socket.recv_multipart(copy=False)
|
frames = await output_socket.recv_multipart(copy=False)
|
||||||
outputs: EngineCoreOutputs = decoder.decode(frame.buffer)
|
outputs: EngineCoreOutputs = decoder.decode(frames)
|
||||||
if outputs.utility_output:
|
if outputs.utility_output:
|
||||||
_process_utility_output(outputs.utility_output,
|
_process_utility_output(outputs.utility_output,
|
||||||
utility_results)
|
utility_results)
|
||||||
@ -666,12 +666,12 @@ class AsyncMPClient(MPClient):
|
|||||||
if engine is None:
|
if engine is None:
|
||||||
engine = self.core_engine
|
engine = self.core_engine
|
||||||
|
|
||||||
message = (request_type.value, self.encoder.encode(request))
|
message = (request_type.value, *self.encoder.encode(request))
|
||||||
return self._send_input_message(message, engine)
|
return self._send_input_message(message, engine)
|
||||||
|
|
||||||
def _send_input_message(self, message: tuple[bytes, bytes],
|
def _send_input_message(self, message: tuple[bytestr, ...],
|
||||||
engine: CoreEngine) -> Awaitable[None]:
|
engine: CoreEngine) -> Awaitable[None]:
|
||||||
message = (engine.identity, ) + message # type: ignore[assignment]
|
message = (engine.identity, ) + message
|
||||||
return self.input_socket.send_multipart(message, copy=False)
|
return self.input_socket.send_multipart(message, copy=False)
|
||||||
|
|
||||||
async def call_utility_async(self, method: str, *args) -> Any:
|
async def call_utility_async(self, method: str, *args) -> Any:
|
||||||
@ -684,8 +684,8 @@ class AsyncMPClient(MPClient):
|
|||||||
call_id = uuid.uuid1().int >> 64
|
call_id = uuid.uuid1().int >> 64
|
||||||
future = asyncio.get_running_loop().create_future()
|
future = asyncio.get_running_loop().create_future()
|
||||||
self.utility_results[call_id] = future
|
self.utility_results[call_id] = future
|
||||||
message = (EngineCoreRequestType.UTILITY.value,
|
message = (EngineCoreRequestType.UTILITY.value, *self.encoder.encode(
|
||||||
self.encoder.encode((call_id, method, args)))
|
(call_id, method, args)))
|
||||||
await self._send_input_message(message, engine)
|
await self._send_input_message(message, engine)
|
||||||
self._ensure_output_queue_task()
|
self._ensure_output_queue_task()
|
||||||
return await future
|
return await future
|
||||||
@ -760,7 +760,7 @@ class DPAsyncMPClient(AsyncMPClient):
|
|||||||
|
|
||||||
# Control message used for triggering dp idle mode loop.
|
# Control message used for triggering dp idle mode loop.
|
||||||
self.start_dp_msg = (EngineCoreRequestType.START_DP.value,
|
self.start_dp_msg = (EngineCoreRequestType.START_DP.value,
|
||||||
self.encoder.encode(None))
|
*self.encoder.encode(None))
|
||||||
|
|
||||||
self.num_engines_running = 0
|
self.num_engines_running = 0
|
||||||
self.reqs_in_flight: dict[str, CoreEngine] = {}
|
self.reqs_in_flight: dict[str, CoreEngine] = {}
|
||||||
@ -794,7 +794,7 @@ class DPAsyncMPClient(AsyncMPClient):
|
|||||||
# tokenized.
|
# tokenized.
|
||||||
request.prompt = None
|
request.prompt = None
|
||||||
|
|
||||||
msg = (EngineCoreRequestType.ADD.value, self.encoder.encode(request))
|
msg = (EngineCoreRequestType.ADD.value, *self.encoder.encode(request))
|
||||||
|
|
||||||
chosen_engine = self.get_core_engine_for_request()
|
chosen_engine = self.get_core_engine_for_request()
|
||||||
self.reqs_in_flight[request.request_id] = chosen_engine
|
self.reqs_in_flight[request.request_id] = chosen_engine
|
||||||
|
@ -1,61 +1,140 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
import pickle
|
import pickle
|
||||||
|
from collections.abc import Sequence
|
||||||
|
from inspect import isclass
|
||||||
from types import FunctionType
|
from types import FunctionType
|
||||||
from typing import Any, Optional
|
from typing import Any, Optional, Union
|
||||||
|
|
||||||
import cloudpickle
|
import cloudpickle
|
||||||
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
import zmq
|
||||||
from msgspec import msgpack
|
from msgspec import msgpack
|
||||||
|
|
||||||
CUSTOM_TYPE_TENSOR = 1
|
CUSTOM_TYPE_PICKLE = 1
|
||||||
CUSTOM_TYPE_PICKLE = 2
|
CUSTOM_TYPE_CLOUDPICKLE = 2
|
||||||
CUSTOM_TYPE_CLOUDPICKLE = 3
|
|
||||||
|
# TODO calibrate this size
|
||||||
|
INLINE_BUF_SIZE_THRESHOLD = 256
|
||||||
|
|
||||||
|
bytestr = Union[bytes, bytearray, memoryview, zmq.Frame]
|
||||||
|
|
||||||
|
|
||||||
class MsgpackEncoder:
    """Encoder with custom torch tensor and numpy array serialization.

    `encode` / `encode_into` return a sequence of buffers: the first is the
    msgpack-encoded message, the rest are the backing buffers of large
    tensors / arrays, referenced from the message by index.

    Note that unlike vanilla `msgspec` Encoders, this interface is generally
    not thread-safe when encoding tensors / numpy arrays.
    """

    def __init__(self):
        self.encoder = msgpack.Encoder(enc_hook=self.enc_hook)
        # This is used as a local stash of buffers that we can then access from
        # our custom `msgspec` hook, `enc_hook`. We don't have a way to
        # pass custom data to the hook otherwise.
        self.aux_buffers: Optional[list[bytestr]] = None

    def encode(self, obj: Any) -> Sequence[bytestr]:
        """Encode `obj` and return the main buffer plus any aux buffers."""
        try:
            # Slot 0 is a placeholder so that aux-buffer indices recorded by
            # the hook during encoding already account for the main buffer.
            self.aux_buffers = bufs = [b'']
            bufs[0] = self.encoder.encode(obj)
            # This `bufs` list allows us to collect direct pointers to backing
            # buffers of tensors and np arrays, and return them along with the
            # top-level encoded buffer instead of copying their data into the
            # new buffer.
            return bufs
        finally:
            # Always drop the stash, even if encoding raised.
            self.aux_buffers = None

    def encode_into(self, obj: Any, buf: bytearray) -> Sequence[bytestr]:
        """Encode `obj` into `buf`; `buf` itself is the first returned buffer."""
        try:
            self.aux_buffers = [buf]
            bufs = self.aux_buffers
            self.encoder.encode_into(obj, buf)
            return bufs
        finally:
            # Always drop the stash, even if encoding raised.
            self.aux_buffers = None

    def enc_hook(self, obj: Any) -> Any:
        """`msgspec` hook for objects the encoder cannot handle natively."""
        if isinstance(obj, torch.Tensor):
            return self._encode_ndarray(obj.numpy())

        # Fall back to pickle for object or void kind ndarrays.
        if isinstance(obj, np.ndarray) and obj.dtype.kind not in ('O', 'V'):
            return self._encode_ndarray(obj)

        if isinstance(obj, FunctionType):
            # `pickle` is generally faster than cloudpickle, but can have
            # problems serializing methods.
            return msgpack.Ext(CUSTOM_TYPE_CLOUDPICKLE, cloudpickle.dumps(obj))

        return msgpack.Ext(CUSTOM_TYPE_PICKLE,
                           pickle.dumps(obj, protocol=pickle.HIGHEST_PROTOCOL))

    def _encode_ndarray(
        self, obj: np.ndarray
    ) -> tuple[str, tuple[int, ...], Union[int, memoryview]]:
        """Return `(dtype string, shape, data)` for `obj`.

        `data` is the array's own buffer for scalars and arrays smaller than
        `INLINE_BUF_SIZE_THRESHOLD` bytes; otherwise it is the index at which
        the array's buffer was appended to `self.aux_buffers`.
        """
        assert self.aux_buffers is not None
        # The decoder rebuilds the array from a dense C-order buffer plus
        # dtype and shape, so strided views must be made contiguous on both
        # the inline and the aux-buffer path. The flag check keeps 0-d and
        # already-contiguous arrays untouched (no copy, shape unchanged).
        if not obj.flags.c_contiguous:
            obj = np.ascontiguousarray(obj)
        if not obj.shape or obj.nbytes < INLINE_BUF_SIZE_THRESHOLD:
            # Encode small arrays and scalars inline.
            data = obj.data
        else:
            # Otherwise encode index of backing buffer.
            data = len(self.aux_buffers)
            self.aux_buffers.append(obj.data)
        # We serialize the ndarray as a tuple of native types.
        # The data is either inlined if small, or an index into a list of
        # backing buffers that we've stashed in `aux_buffers`.
        return obj.dtype.str, obj.shape, data
|
||||||
|
|
||||||
|
|
||||||
class MsgpackDecoder:
    """Decoder with custom torch tensor and numpy array serialization.

    Accepts either a single encoded buffer or a sequence of buffers whose
    first element is the msgpack message and whose remaining elements are
    backing buffers referenced from it by index.

    Note that unlike vanilla `msgspec` Decoders, this interface is generally
    not thread-safe when decoding tensors / numpy arrays.
    """

    def __init__(self, t: Optional[Any] = None):
        # Only pass a target type to msgspec when one was given.
        args = () if t is None else (t, )
        self.decoder = msgpack.Decoder(*args,
                                       ext_hook=self.ext_hook,
                                       dec_hook=self.dec_hook)
        # Buffers of the message currently being decoded; read by
        # `_decode_ndarray` and reset after every `decode` call.
        self.aux_buffers: Sequence[bytestr] = ()

    def decode(self, bufs: Union[bytestr, Sequence[bytestr]]) -> Any:
        """Decode a single buffer, or a sequence of main + aux buffers."""
        if isinstance(bufs, (bytes, bytearray, memoryview, zmq.Frame)):
            # TODO - This check can become `isinstance(bufs, bytestr)`
            # as of Python 3.10.
            return self.decoder.decode(bufs)

        self.aux_buffers = bufs
        try:
            return self.decoder.decode(bufs[0])
        finally:
            # Do not keep references to the message buffers after decoding.
            self.aux_buffers = ()

    def dec_hook(self, t: type, obj: Any) -> Any:
        """`msgspec` hook converting native `obj` to the requested type `t`."""
        # Given native types in `obj`, convert to type `t`.
        if isclass(t):
            if issubclass(t, np.ndarray):
                return self._decode_ndarray(obj)
            if issubclass(t, torch.Tensor):
                return torch.from_numpy(self._decode_ndarray(obj))
        return obj

    def _decode_ndarray(self, arr: Any) -> np.ndarray:
        """Rebuild an ndarray from its `(dtype, shape, data)` triple."""
        dtype, shape, data = arr
        # An int is an index into the aux buffers; anything else is the
        # inlined data itself.
        buffer = self.aux_buffers[data] if isinstance(data, int) else data
        return np.ndarray(buffer=buffer, dtype=np.dtype(dtype), shape=shape)

    def ext_hook(self, code: int, data: memoryview) -> Any:
        """`msgspec` hook for extension types produced by the encoder.

        Raises:
            NotImplementedError: if `code` is not a known extension type.
        """
        # NOTE(security): unpickling executes arbitrary code. This is only
        # safe if every peer able to send messages to this decoder is
        # trusted - TODO confirm for each transport this is used with.
        if code == CUSTOM_TYPE_PICKLE:
            return pickle.loads(data)
        if code == CUSTOM_TYPE_CLOUDPICKLE:
            return cloudpickle.loads(data)

        raise NotImplementedError(
            f"Extension type code {code} is not supported")
|
||||||
|
Loading…
x
Reference in New Issue
Block a user