From 7678fcd5b6d64084aeabeb17251d388e251ba4c9 Mon Sep 17 00:00:00 2001 From: Lu Fang <30275821+houseroad@users.noreply.github.com> Date: Thu, 10 Apr 2025 07:37:47 -0700 Subject: [PATCH] Fix the torch version parsing logic (#15857) --- vllm/compilation/compiler_interface.py | 5 ++--- vllm/compilation/inductor_pass.py | 6 +++--- vllm/config.py | 8 +++----- vllm/utils.py | 18 ++++++++++++++++++ 4 files changed, 26 insertions(+), 11 deletions(-) diff --git a/vllm/compilation/compiler_interface.py b/vllm/compilation/compiler_interface.py index 5a22cf70..6c887591 100644 --- a/vllm/compilation/compiler_interface.py +++ b/vllm/compilation/compiler_interface.py @@ -2,7 +2,6 @@ import contextlib import copy import hashlib -import importlib.metadata import os from contextlib import ExitStack from typing import Any, Callable, Dict, List, Optional, Tuple @@ -11,9 +10,9 @@ from unittest.mock import patch import torch import torch._inductor.compile_fx import torch.fx as fx -from packaging.version import Version from vllm.config import VllmConfig +from vllm.utils import is_torch_equal_or_newer class CompilerInterface: @@ -379,7 +378,7 @@ class InductorAdaptor(CompilerInterface): manually setting up internal contexts. But we also rely on non-public APIs which might not provide these guarantees. """ - if Version(importlib.metadata.version('torch')) >= Version("2.6"): + if is_torch_equal_or_newer("2.6"): import torch._dynamo.utils return torch._dynamo.utils.get_metrics_context() else: diff --git a/vllm/compilation/inductor_pass.py b/vllm/compilation/inductor_pass.py index 08dd8c8e..00a2e89f 100644 --- a/vllm/compilation/inductor_pass.py +++ b/vllm/compilation/inductor_pass.py @@ -1,17 +1,17 @@ # SPDX-License-Identifier: Apache-2.0 import hashlib -import importlib.metadata import inspect import json import types from typing import Any, Callable, Dict, Optional, Union import torch -from packaging.version import Version from torch import fx -if Version(importlib.metadata.version('torch')) >= Version("2.6"): +from vllm.utils import is_torch_equal_or_newer + +if is_torch_equal_or_newer("2.6"): from torch._inductor.custom_graph_pass import CustomGraphPass else: # CustomGraphPass is not present in 2.5 or lower, import our version diff --git a/vllm/config.py b/vllm/config.py index 2662c6a8..5fcc5f46 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -4,7 +4,6 @@ import ast import copy import enum import hashlib -import importlib.metadata import json import sys import warnings @@ -18,7 +17,6 @@ from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Final, Literal, Optional, Protocol, Union) import torch -from packaging.version import Version from pydantic import BaseModel, Field, PrivateAttr from torch.distributed import ProcessGroup, ReduceOp from transformers import PretrainedConfig @@ -40,8 +38,8 @@ from vllm.transformers_utils.config import ( from vllm.transformers_utils.s3_utils import S3Model from vllm.transformers_utils.utils import is_s3, maybe_model_redirect from vllm.utils import (GiB_bytes, LayerBlockType, cuda_device_count_stateless, - get_cpu_memory, get_open_port, random_uuid, - resolve_obj_by_qualname) + get_cpu_memory, get_open_port, is_torch_equal_or_newer, + random_uuid, resolve_obj_by_qualname) if TYPE_CHECKING: from ray.util.placement_group import PlacementGroup @@ -3285,7 +3283,7 @@ class CompilationConfig(BaseModel): # and it is not yet a priority. RFC here: # https://github.com/vllm-project/vllm/issues/14703 - if Version(importlib.metadata.version('torch')) >= Version("2.6"): + if is_torch_equal_or_newer("2.6"): KEY = 'enable_auto_functionalized_v2' if KEY not in self.inductor_compile_config: self.inductor_compile_config[KEY] = False diff --git a/vllm/utils.py b/vllm/utils.py index 1645565a..551f1a4c 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -53,6 +53,7 @@ import torch.types import yaml import zmq import zmq.asyncio +from packaging import version from packaging.version import Version from torch.library import Library from typing_extensions import Never, ParamSpec, TypeIs, assert_never @@ -2580,3 +2581,20 @@ def sha256(input) -> int: input_bytes = pickle.dumps(input, protocol=pickle.HIGHEST_PROTOCOL) return int.from_bytes(hashlib.sha256(input_bytes).digest(), byteorder="big") + + +def is_torch_equal_or_newer(target: str) -> bool: + """Check if the installed torch version is >= the target version. + + Args: + target: a version string, like "2.6.0". + + Returns: + Whether the condition meets. + """ + try: + torch_version = version.parse(str(torch.__version__)) + return torch_version >= version.parse(target) + except Exception: + # Fallback to PKG-INFO to load the package info, needed by the doc gen. + return Version(importlib.metadata.version('torch')) >= Version(target)