[Bugfix] Allow ScalarType to be compiled with pytorch 2.3 and add checks for registering FakeScalarType and dynamo support. (#7886)

Author:  bnellnm
Date:    2024-08-27 23:13:45 -04:00 (committed by GitHub)
Commit:  c166e7e43e
Parent:  bc6e42a9b1
4 changed files with 84 additions and 67 deletions


@@ -387,7 +387,8 @@ class ScalarTypeTorch : public torch::CustomClassHolder, public ScalarType {
   // This needs to be implemented and throw a TypeError in order for
   // PyTorch's opcheck to work on ops that use ScalarTypes.
   int64_t len() const {
-    throw c10::TypeError("__len__ not implemented");
+    throw c10::TypeError({__func__, __FILE__, static_cast<uint32_t>(__LINE__)},
+                         "__len__ not implemented");
     return 0;
   }
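
For context on the comment above: PyTorch's torch.library.opcheck utility is what probes these ops, and it is the reason len() has to throw a TypeError rather than be left unimplemented. A minimal sketch of such a check, assuming a hypothetical op _C::my_quant_op and that the preset types live in vllm.scalar_type:

# Hedged sketch, not part of this commit. `my_quant_op` and its arguments are
# hypothetical; a real vLLM op that accepts a ScalarType would be exercised
# the same way.
import torch

from vllm.scalar_type import scalar_types  # assumed location of the preset types

weight = torch.randn(16, 16)
torch.library.opcheck(torch.ops._C.my_quant_op.default,
                      (weight, scalar_types.uint4b8))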


@@ -181,6 +181,8 @@ elif core_C_available:
     ScalarType = torch.classes._core_C.ScalarType
 
+    if (hasattr(torch, "_library")
+            and hasattr(torch._library, "register_fake_class")):
         # Needed for dynamo support of ScalarType.
         @torch._library.register_fake_class("_core_C::ScalarType")
         class FakeScalarType:
@@ -249,9 +251,11 @@ elif core_C_available:
             @classmethod
             def __obj_unflatten__(
-                    cls, flat_type: Tuple[Tuple[str, Any], ...]) -> 'ScalarType':
+                    cls, flat_type: Tuple[Tuple[str, Any],
+                                          ...]) -> 'ScalarType':
                 return cls(
-                    torch.classes._core_C.ScalarType.__obj_unflatten__(flat_type))
+                    torch.classes._core_C.ScalarType.__obj_unflatten__(
+                        flat_type))
 
             @classmethod
             def int_(cls, size_bits: int, bias: Optional[int]) -> 'ScalarType':
@@ -262,11 +266,13 @@ elif core_C_available:
                 return ScalarType.uint(size_bits, bias)
 
             @classmethod
-            def float_IEEE754(cls, exponent: int, mantissa: int) -> 'ScalarType':
+            def float_IEEE754(cls, exponent: int,
+                              mantissa: int) -> 'ScalarType':
                 return ScalarType.float_IEEE754(exponent, mantissa)
 
             @classmethod
-            def float_(cls, exponent: int, mantissa: int, finite_values_only: bool,
+            def float_(cls, exponent: int, mantissa: int,
+                       finite_values_only: bool,
                        nan_repr: int) -> 'ScalarType':
-                return ScalarType.float_(exponent, mantissa, finite_values_only,
-                                         nan_repr)
+                return ScalarType.float_(exponent, mantissa,
+                                         finite_values_only, nan_repr)
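
The FakeScalarType registration above leans on the TorchBind flatten/unflatten protocol: the real class's C++ __obj_flatten__ produces a tuple of (name, value)-style pairs, and __obj_unflatten__ rebuilds an instance from it. A small round-trip sketch, assuming a vLLM build where the compiled _core_C extension is importable:

# Hedged sketch, not part of this commit: the flatten/unflatten round trip that
# FakeScalarType.__obj_unflatten__ delegates to.
import torch
import vllm  # assumption: importing vllm loads the compiled _core_C extension

ScalarType = torch.classes._core_C.ScalarType

t = ScalarType.uint(4, 8)                  # a biased 4-bit unsigned type
flat = t.__obj_flatten__()                 # tuple of (attribute name, value) pairs
same = ScalarType.__obj_unflatten__(flat)  # rebuild an equivalent instance

assert same.__obj_flatten__() == flat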


@@ -25,6 +25,7 @@ import numpy.typing as npt
 import psutil
 import torch
 import torch.types
+from packaging.version import Version
 from typing_extensions import ParamSpec, TypeIs, assert_never
 
 import vllm.envs as envs
@@ -1114,3 +1115,11 @@ async def _run_task_with_lock(task: Callable, lock: asyncio.Lock, *args,
     """Utility function to run async task in a lock"""
     async with lock:
         return await task(*args, **kwargs)
+
+
+# Using dynamo with vLLM doesn't really work well with PyTorch versions < 2.4.0.
+# In particular, the FakeScalarType is not supported for earlier versions of
+# PyTorch which breaks dynamo for any ops registered using ScalarType.
+def supports_dynamo() -> bool:
+    base_torch_version = Version(Version(torch.__version__).base_version)
+    return base_torch_version >= Version("2.4.0")
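
supports_dynamo() compares the base version rather than the raw torch.__version__ string, which has the effect of treating 2.4 nightlies and local builds as 2.4.0. A quick illustration (not part of the commit) using packaging.version directly; the version string below is a made-up nightly:

from packaging.version import Version

nightly = Version("2.4.0.dev20240613+cu121")  # hypothetical nightly build string
print(nightly < Version("2.4.0"))             # True: .dev builds sort before the release
print(nightly.base_version)                   # "2.4.0"
print(Version(nightly.base_version) >= Version("2.4.0"))  # True: the nightly counts as 2.4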


@@ -44,7 +44,8 @@ from vllm.sampling_params import SamplingParams
 from vllm.sequence import (IntermediateTensors, SamplerOutput,
                            SequenceGroupMetadata)
 from vllm.utils import (CudaMemoryProfiler, PyObjectCache, async_tensor_h2d,
-                        flatten_2d_lists, is_hip, is_pin_memory_available)
+                        flatten_2d_lists, is_hip, is_pin_memory_available,
+                        supports_dynamo)
 from vllm.worker.model_runner_base import (
     ModelRunnerBase, ModelRunnerInputBase, ModelRunnerInputBuilderBase,
     _add_attn_metadata_broadcastable_dict,
@@ -946,7 +947,7 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
                 "provided. Defaulting to scaling factors of 1.0. "
                 "This may lead to less accurate results!")
 
-        if envs.VLLM_TEST_DYNAMO_GRAPH_CAPTURE:
+        if envs.VLLM_TEST_DYNAMO_GRAPH_CAPTURE and supports_dynamo():
             self.model = torch.compile(self.model,
                                        fullgraph=True,
                                        backend="eager")
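
A short usage sketch (not part of the commit) of the same gating pattern on a plain function: backend="eager" makes torch.compile exercise dynamo's graph capture without generating fused kernels, and supports_dynamo() keeps that path disabled on PyTorch builds older than 2.4.

import torch

from vllm.utils import supports_dynamo

def forward(x: torch.Tensor) -> torch.Tensor:
    return torch.relu(x) * 2.0

if supports_dynamo():
    # The "eager" backend runs the captured FX graph with ordinary eager
    # kernels, so this only checks that dynamo can trace the function.
    forward = torch.compile(forward, fullgraph=True, backend="eager")

print(forward(torch.randn(4)))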