[Bugfix] Allow ScalarType to be compiled with pytorch 2.3 and add checks for registering FakeScalarType and dynamo support. (#7886)
This commit is contained in:
parent bc6e42a9b1
commit c166e7e43e
@@ -387,7 +387,8 @@ class ScalarTypeTorch : public torch::CustomClassHolder, public ScalarType {
   // This needs to be implemented and throw a TypeError in order for
   // PyTorch's opcheck to work on ops that use ScalarTypes.
   int64_t len() const {
-    throw c10::TypeError("__len__ not implemented");
+    throw c10::TypeError({__func__, __FILE__, static_cast<uint32_t>(__LINE__)},
+                         "__len__ not implemented");
     return 0;
   }

@@ -181,92 +181,98 @@ elif core_C_available:

     ScalarType = torch.classes._core_C.ScalarType

-    # Needed for dynamo support of ScalarType.
-    @torch._library.register_fake_class("_core_C::ScalarType")
-    class FakeScalarType:
+    if (hasattr(torch, "_library")
+            and hasattr(torch._library, "register_fake_class")):
+        # Needed for dynamo support of ScalarType.
+        @torch._library.register_fake_class("_core_C::ScalarType")
+        class FakeScalarType:

-        def __init__(self, scalar_type):
-            self.ScalarType = scalar_type
+            def __init__(self, scalar_type):
+                self.ScalarType = scalar_type

-        def bias_getter(self) -> int:
-            return self.ScalarType.bias
+            def bias_getter(self) -> int:
+                return self.ScalarType.bias

-        def exponent_getter(self) -> int:
-            return self.ScalarType.exponent
+            def exponent_getter(self) -> int:
+                return self.ScalarType.exponent

-        def mantissa_getter(self) -> int:
-            return self.ScalarType.mantissa
+            def mantissa_getter(self) -> int:
+                return self.ScalarType.mantissa

-        def signed_getter(self) -> bool:
-            return self.ScalarType.signed
+            def signed_getter(self) -> bool:
+                return self.ScalarType.signed

-        def size_bits_getter(self) -> int:
-            return self.ScalarType.size_bits
+            def size_bits_getter(self) -> int:
+                return self.ScalarType.size_bits

-        @property
-        def size_bits(self) -> int:
-            return self.ScalarType.size_bits
+            @property
+            def size_bits(self) -> int:
+                return self.ScalarType.size_bits

-        def min(self) -> Union[int, float]:
-            return self.ScalarType.min()
+            def min(self) -> Union[int, float]:
+                return self.ScalarType.min()

-        def max(self) -> Union[int, float]:
-            return self.ScalarType.max()
+            def max(self) -> Union[int, float]:
+                return self.ScalarType.max()

-        def is_signed(self) -> bool:
-            return self.ScalarType.is_signed()
+            def is_signed(self) -> bool:
+                return self.ScalarType.is_signed()

-        def is_floating_point(self) -> bool:
-            return self.ScalarType.is_floating_point()
+            def is_floating_point(self) -> bool:
+                return self.ScalarType.is_floating_point()

-        def is_integer(self) -> bool:
-            return self.ScalarType.is_integer()
+            def is_integer(self) -> bool:
+                return self.ScalarType.is_integer()

-        def has_bias(self) -> bool:
-            return self.ScalarType.has_bias()
+            def has_bias(self) -> bool:
+                return self.ScalarType.has_bias()

-        def has_infs(self) -> bool:
-            return self.ScalarType.has_infs()
+            def has_infs(self) -> bool:
+                return self.ScalarType.has_infs()

-        def has_nans(self) -> bool:
-            return self.ScalarType.has_nans()
+            def has_nans(self) -> bool:
+                return self.ScalarType.has_nans()

-        def is_ieee_754(self) -> bool:
-            return self.ScalarType.is_ieee_754()
+            def is_ieee_754(self) -> bool:
+                return self.ScalarType.is_ieee_754()

-        def __str__(self) -> str:
-            return self.ScalarType.__str__()
+            def __str__(self) -> str:
+                return self.ScalarType.__str__()

-        def __repr__(self) -> str:
-            return self.ScalarType.__repr__()
+            def __repr__(self) -> str:
+                return self.ScalarType.__repr__()

-        def __len__(self) -> int:
-            return self.ScalarType.__len__()
+            def __len__(self) -> int:
+                return self.ScalarType.__len__()

-        def __obj_flatten__(self) -> Tuple[Tuple[str, Any], ...]:
-            return torch.classes._core_C.ScalarType.__obj_flatten__(
-                self.ScalarType)
+            def __obj_flatten__(self) -> Tuple[Tuple[str, Any], ...]:
+                return torch.classes._core_C.ScalarType.__obj_flatten__(
+                    self.ScalarType)

-        @classmethod
-        def __obj_unflatten__(
-                cls, flat_type: Tuple[Tuple[str, Any], ...]) -> 'ScalarType':
-            return cls(
-                torch.classes._core_C.ScalarType.__obj_unflatten__(flat_type))
+            @classmethod
+            def __obj_unflatten__(
+                    cls, flat_type: Tuple[Tuple[str, Any],
+                                          ...]) -> 'ScalarType':
+                return cls(
+                    torch.classes._core_C.ScalarType.__obj_unflatten__(
+                        flat_type))

-        @classmethod
-        def int_(cls, size_bits: int, bias: Optional[int]) -> 'ScalarType':
-            return ScalarType.int_(size_bits, bias)
+            @classmethod
+            def int_(cls, size_bits: int, bias: Optional[int]) -> 'ScalarType':
+                return ScalarType.int_(size_bits, bias)

-        @classmethod
-        def uint(cls, size_bits: int, bias: Optional[int]) -> 'ScalarType':
-            return ScalarType.uint(size_bits, bias)
+            @classmethod
+            def uint(cls, size_bits: int, bias: Optional[int]) -> 'ScalarType':
+                return ScalarType.uint(size_bits, bias)

-        @classmethod
-        def float_IEEE754(cls, exponent: int, mantissa: int) -> 'ScalarType':
-            return ScalarType.float_IEEE754(exponent, mantissa)
+            @classmethod
+            def float_IEEE754(cls, exponent: int,
+                              mantissa: int) -> 'ScalarType':
+                return ScalarType.float_IEEE754(exponent, mantissa)

-        @classmethod
-        def float_(cls, exponent: int, mantissa: int, finite_values_only: bool,
-                   nan_repr: int) -> 'ScalarType':
-            return ScalarType.float_(exponent, mantissa, finite_values_only,
-                                     nan_repr)
+            @classmethod
+            def float_(cls, exponent: int, mantissa: int,
+                       finite_values_only: bool,
+                       nan_repr: int) -> 'ScalarType':
+                return ScalarType.float_(exponent, mantissa,
+                                         finite_values_only, nan_repr)
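The new hasattr guard above is a feature probe rather than a version check: torch._library.register_fake_class only exists in newer PyTorch builds, so on PyTorch 2.3 the registration block is skipped and the module still imports, while on 2.4+ the FakeScalarType proxy is registered for dynamo and opcheck. Below is a minimal, hedged sketch of the same probe, runnable on its own; the printed messages are illustrative and not vLLM output.

import torch

# Probe for the fake-class registration API before touching it, mirroring
# the guard added in this hunk.
if (hasattr(torch, "_library")
        and hasattr(torch._library, "register_fake_class")):
    # PyTorch >= 2.4: a fake TorchBind class could be registered here to
    # give dynamo/opcheck a Python-side stand-in for ScalarType.
    print("register_fake_class available; fake ScalarType would be registered")
else:
    # PyTorch 2.3: skip registration; the eager ScalarType still works.
    print("register_fake_class not available; skipping registration")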
@@ -25,6 +25,7 @@ import numpy.typing as npt
 import psutil
 import torch
 import torch.types
+from packaging.version import Version
 from typing_extensions import ParamSpec, TypeIs, assert_never

 import vllm.envs as envs
@@ -1114,3 +1115,11 @@ async def _run_task_with_lock(task: Callable, lock: asyncio.Lock, *args,
     """Utility function to run async task in a lock"""
     async with lock:
         return await task(*args, **kwargs)
+
+
+# Using dynamo with vLLM doesn't really work well with PyTorch versions < 2.4.0.
+# In particular, the FakeScalarType is not supported for earlier versions of
+# PyTorch which breaks dynamo for any ops registered using ScalarType.
+def supports_dynamo() -> bool:
+    base_torch_version = Version(Version(torch.__version__).base_version)
+    return base_torch_version >= Version("2.4.0")
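As an aside on the comparison in supports_dynamo() (illustration only, not part of the commit): wrapping base_version in a second Version strips dev/pre-release and local build tags, so nightly or CUDA-suffixed builds compare on their plain release number. A small sketch with hypothetical version strings:

from packaging.version import Version

def _supports_dynamo(torch_version: str) -> bool:
    # Same comparison as supports_dynamo(), parameterized for illustration.
    return Version(Version(torch_version).base_version) >= Version("2.4.0")

# Hypothetical build strings: the dev/local suffixes are dropped by
# base_version before the comparison.
assert _supports_dynamo("2.4.0.dev20240613+cu121")
assert not _supports_dynamo("2.3.1+cu121")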
@@ -44,7 +44,8 @@ from vllm.sampling_params import SamplingParams
 from vllm.sequence import (IntermediateTensors, SamplerOutput,
                            SequenceGroupMetadata)
 from vllm.utils import (CudaMemoryProfiler, PyObjectCache, async_tensor_h2d,
-                        flatten_2d_lists, is_hip, is_pin_memory_available)
+                        flatten_2d_lists, is_hip, is_pin_memory_available,
+                        supports_dynamo)
 from vllm.worker.model_runner_base import (
     ModelRunnerBase, ModelRunnerInputBase, ModelRunnerInputBuilderBase,
     _add_attn_metadata_broadcastable_dict,
@@ -946,7 +947,7 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
                 "provided. Defaulting to scaling factors of 1.0. "
                 "This may lead to less accurate results!")

-        if envs.VLLM_TEST_DYNAMO_GRAPH_CAPTURE:
+        if envs.VLLM_TEST_DYNAMO_GRAPH_CAPTURE and supports_dynamo():
             self.model = torch.compile(self.model,
                                        fullgraph=True,
                                        backend="eager")
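For context on the gated line (a hedged aside, not from the commit): backend="eager" makes dynamo capture the graph but run it with ordinary eager kernels, which is what the VLLM_TEST_DYNAMO_GRAPH_CAPTURE flag is meant to exercise, and supports_dynamo() now keeps older PyTorch installs from attempting it at all. A standalone sketch of the same gating pattern, assuming a vLLM checkout that includes this commit; the helper name and toy model are illustrative.

import torch
from vllm.utils import supports_dynamo

def maybe_capture(model: torch.nn.Module, test_capture: bool) -> torch.nn.Module:
    # Mirror of the guard above: compile only when the test flag is set and
    # the installed PyTorch is new enough for FakeScalarType/dynamo.
    if test_capture and supports_dynamo():
        return torch.compile(model, fullgraph=True, backend="eager")
    return model

model = maybe_capture(torch.nn.Linear(4, 4), test_capture=True)
print(model(torch.randn(2, 4)).shape)  # torch.Size([2, 4]) either way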