Fix the torch version parsing logic (#15857)

This commit is contained in:
Lu Fang 2025-04-10 07:37:47 -07:00 committed by GitHub
parent 8661c0241d
commit 7678fcd5b6
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 26 additions and 11 deletions

View File

@ -2,7 +2,6 @@
import contextlib
import copy
import hashlib
import importlib.metadata
import os
from contextlib import ExitStack
from typing import Any, Callable, Dict, List, Optional, Tuple
@ -11,9 +10,9 @@ from unittest.mock import patch
import torch
import torch._inductor.compile_fx
import torch.fx as fx
from packaging.version import Version
from vllm.config import VllmConfig
from vllm.utils import is_torch_equal_or_newer
class CompilerInterface:
@ -379,7 +378,7 @@ class InductorAdaptor(CompilerInterface):
manually setting up internal contexts. But we also rely on non-public
APIs which might not provide these guarantees.
"""
if Version(importlib.metadata.version('torch')) >= Version("2.6"):
if is_torch_equal_or_newer("2.6"):
import torch._dynamo.utils
return torch._dynamo.utils.get_metrics_context()
else:

View File

@ -1,17 +1,17 @@
# SPDX-License-Identifier: Apache-2.0
import hashlib
import importlib.metadata
import inspect
import json
import types
from typing import Any, Callable, Dict, Optional, Union
import torch
from packaging.version import Version
from torch import fx
if Version(importlib.metadata.version('torch')) >= Version("2.6"):
from vllm.utils import is_torch_equal_or_newer
if is_torch_equal_or_newer("2.6"):
from torch._inductor.custom_graph_pass import CustomGraphPass
else:
# CustomGraphPass is not present in 2.5 or lower, import our version

View File

@ -4,7 +4,6 @@ import ast
import copy
import enum
import hashlib
import importlib.metadata
import json
import sys
import warnings
@ -18,7 +17,6 @@ from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Final, Literal,
Optional, Protocol, Union)
import torch
from packaging.version import Version
from pydantic import BaseModel, Field, PrivateAttr
from torch.distributed import ProcessGroup, ReduceOp
from transformers import PretrainedConfig
@ -40,8 +38,8 @@ from vllm.transformers_utils.config import (
from vllm.transformers_utils.s3_utils import S3Model
from vllm.transformers_utils.utils import is_s3, maybe_model_redirect
from vllm.utils import (GiB_bytes, LayerBlockType, cuda_device_count_stateless,
get_cpu_memory, get_open_port, random_uuid,
resolve_obj_by_qualname)
get_cpu_memory, get_open_port, is_torch_equal_or_newer,
random_uuid, resolve_obj_by_qualname)
if TYPE_CHECKING:
from ray.util.placement_group import PlacementGroup
@ -3285,7 +3283,7 @@ class CompilationConfig(BaseModel):
# and it is not yet a priority. RFC here:
# https://github.com/vllm-project/vllm/issues/14703
if Version(importlib.metadata.version('torch')) >= Version("2.6"):
if is_torch_equal_or_newer("2.6"):
KEY = 'enable_auto_functionalized_v2'
if KEY not in self.inductor_compile_config:
self.inductor_compile_config[KEY] = False

View File

@ -53,6 +53,7 @@ import torch.types
import yaml
import zmq
import zmq.asyncio
from packaging import version
from packaging.version import Version
from torch.library import Library
from typing_extensions import Never, ParamSpec, TypeIs, assert_never
@ -2580,3 +2581,20 @@ def sha256(input) -> int:
input_bytes = pickle.dumps(input, protocol=pickle.HIGHEST_PROTOCOL)
return int.from_bytes(hashlib.sha256(input_bytes).digest(),
byteorder="big")
def is_torch_equal_or_newer(target: str) -> bool:
"""Check if the installed torch version is >= the target version.
Args:
target: a version string, like "2.6.0".
Returns:
Whether the condition meets.
"""
try:
torch_version = version.parse(str(torch.__version__))
return torch_version >= version.parse(target)
except Exception:
# Fallback to PKG-INFO to load the package info, needed by the doc gen.
return Version(importlib.metadata.version('torch')) >= Version(target)