vllm/vllm/v1/serial_utils.py
wwl2755 94744ba41a
[V1] [Feature] Collective RPC (#15444)
Signed-off-by: wwl2755 <wangwenlong2755@gmail.com>
2025-03-29 03:39:14 -07:00

62 lines
1.8 KiB
Python

# SPDX-License-Identifier: Apache-2.0
import pickle
from types import FunctionType
from typing import Any, Optional
import cloudpickle
import torch
from msgspec import msgpack
# msgpack extension type codes. custom_enc_hook tags every Ext payload with
# one of these, and custom_ext_hook uses the tag to choose how to decode it.
# Tensor payload: a pickled numpy array.
CUSTOM_TYPE_TENSOR = 1
# Generic payload: any other object, serialized with the stdlib pickle module.
CUSTOM_TYPE_PICKLE = 2
# Function payload: a types.FunctionType instance, serialized with cloudpickle.
CUSTOM_TYPE_CLOUDPICKLE = 3
class MsgpackEncoder:
    """msgpack encoder that routes unsupported types through custom_enc_hook.

    Wraps a single ``msgpack.Encoder`` instance, stored on ``self.encoder``,
    whose ``enc_hook`` is ``custom_enc_hook``.
    """

    def __init__(self):
        # Build the underlying encoder once; both methods below reuse it.
        self.encoder = msgpack.Encoder(enc_hook=custom_enc_hook)

    def encode(self, obj: Any) -> bytes:
        """Serialize ``obj`` and return the encoded message."""
        encoded = self.encoder.encode(obj)
        return encoded

    def encode_into(self, obj: Any, buf: bytearray) -> None:
        """Serialize ``obj`` into the caller-provided ``buf``.

        Nothing is returned; the result is written to ``buf`` by the
        underlying encoder.
        """
        self.encoder.encode_into(obj, buf)
class MsgpackDecoder:
    """msgpack decoder that resolves extension payloads via custom_ext_hook.

    Wraps a single ``msgpack.Decoder`` instance, stored on ``self.decoder``,
    whose ``ext_hook`` is ``custom_ext_hook``.
    """

    def __init__(self, t: Optional[Any] = None):
        """Create the underlying decoder.

        Args:
            t: Optional value forwarded as the first positional argument of
                ``msgpack.Decoder``. When ``None`` the argument is omitted
                entirely, so the decoder falls back to its own default.
        """
        if t is None:
            self.decoder = msgpack.Decoder(ext_hook=custom_ext_hook)
        else:
            self.decoder = msgpack.Decoder(t, ext_hook=custom_ext_hook)

    def decode(self, obj: Any):
        """Decode the encoded message ``obj`` and return the result."""
        decoded = self.decoder.decode(obj)
        return decoded
def custom_enc_hook(obj: Any) -> Any:
    """Wrap ``obj`` in a ``msgpack.Ext`` tagged with how it was serialized.

    Installed as the ``enc_hook`` of ``MsgpackEncoder``'s encoder.

    * ``torch.Tensor``: converted with ``.numpy()`` and pickled, tagged
      ``CUSTOM_TYPE_TENSOR``.
    * ``types.FunctionType``: serialized with cloudpickle, tagged
      ``CUSTOM_TYPE_CLOUDPICKLE``.
    * anything else: serialized with pickle, tagged ``CUSTOM_TYPE_PICKLE``.

    Any exception raised by ``.numpy()``, ``pickle.dumps`` or
    ``cloudpickle.dumps`` propagates to the caller.
    """
    if isinstance(obj, torch.Tensor):
        # NOTE(rob): it is fastest to use numpy + pickle
        # when serializing torch tensors.
        # https://gist.github.com/tlrmchlsmth/8067f1b24a82b6e2f90450e7764fa103 # noqa: E501
        ext_code = CUSTOM_TYPE_TENSOR
        payload = pickle.dumps(obj.numpy())
    elif isinstance(obj, FunctionType):
        ext_code = CUSTOM_TYPE_CLOUDPICKLE
        payload = cloudpickle.dumps(obj)
    else:
        ext_code = CUSTOM_TYPE_PICKLE
        payload = pickle.dumps(obj)
    return msgpack.Ext(ext_code, payload)
def custom_ext_hook(code: int, data: memoryview) -> Any:
    """Rebuild the object stored in a msgpack extension payload.

    Installed as the ``ext_hook`` of ``MsgpackDecoder``'s decoder; it is the
    inverse of ``custom_enc_hook``.

    Args:
        code: Extension type code attached to the payload.
        data: Raw payload bytes for that extension.

    Returns:
        A ``torch.Tensor`` built from the unpickled numpy array for
        ``CUSTOM_TYPE_TENSOR``; the unpickled object for
        ``CUSTOM_TYPE_PICKLE``; the cloudpickle-loaded object for
        ``CUSTOM_TYPE_CLOUDPICKLE``.

    Raises:
        NotImplementedError: If ``code`` is not one of the three known codes.

    NOTE(review): ``pickle.loads`` / ``cloudpickle.loads`` can execute
    arbitrary code while loading. Presumably payloads only ever come from
    trusted peer processes -- confirm before exposing this to external input.
    """
    if code == CUSTOM_TYPE_TENSOR:
        ndarray = pickle.loads(data)
        return torch.from_numpy(ndarray)
    elif code == CUSTOM_TYPE_PICKLE:
        return pickle.loads(data)
    elif code == CUSTOM_TYPE_CLOUDPICKLE:
        return cloudpickle.loads(data)
    else:
        raise NotImplementedError(
            f"Extension type code {code} is not supported")