Fix error message on TORCH_CUDA_ARCH_LIST (#1239)
Co-authored-by: Yunfeng Bai <yunfeng.bai@scale.com>
commit d0740dff1b
parent de89472897
 setup.py | 37 +++++++++++++-------

--- a/setup.py
+++ b/setup.py
@@ -13,7 +13,7 @@ from torch.utils.cpp_extension import BuildExtension, CUDAExtension, CUDA_HOME
 ROOT_DIR = os.path.dirname(__file__)
 
 # Supported NVIDIA GPU architectures.
-SUPPORTED_ARCHS = ["7.0", "7.5", "8.0", "8.6", "8.9", "9.0"]
+SUPPORTED_ARCHS = {"7.0", "7.5", "8.0", "8.6", "8.9", "9.0"}
 
 # Compiler flags.
 CXX_FLAGS = ["-g", "-O2", "-std=c++17"]
@@ -49,19 +49,32 @@ def get_torch_arch_list() -> Set[str]:
     # and executed on the 8.6 or newer architectures. While the PTX code will
     # not give the best performance on the newer architectures, it provides
     # forward compatibility.
-    valid_arch_strs = SUPPORTED_ARCHS + [s + "+PTX" for s in SUPPORTED_ARCHS]
-    arch_list = os.environ.get("TORCH_CUDA_ARCH_LIST", None)
-    if arch_list is None:
+    env_arch_list = os.environ.get("TORCH_CUDA_ARCH_LIST", None)
+    if env_arch_list is None:
         return set()
 
     # List are separated by ; or space.
-    arch_list = arch_list.replace(" ", ";").split(";")
-    for arch in arch_list:
-        if arch not in valid_arch_strs:
-            raise ValueError(
-                f"Unsupported CUDA arch ({arch}). "
-                f"Valid CUDA arch strings are: {valid_arch_strs}.")
-    return set(arch_list)
+    torch_arch_list = set(env_arch_list.replace(" ", ";").split(";"))
+    if not torch_arch_list:
+        return set()
+
+    # Filter out the invalid architectures and print a warning.
+    valid_archs = SUPPORTED_ARCHS.union({s + "+PTX" for s in SUPPORTED_ARCHS})
+    arch_list = torch_arch_list.intersection(valid_archs)
+    # If none of the specified architectures are valid, raise an error.
+    if not arch_list:
+        raise RuntimeError(
+            "None of the CUDA architectures in `TORCH_CUDA_ARCH_LIST` env "
+            f"variable ({env_arch_list}) is supported. "
+            f"Supported CUDA architectures are: {valid_archs}.")
+    invalid_arch_list = torch_arch_list - valid_archs
+    if invalid_arch_list:
+        warnings.warn(
+            f"Unsupported CUDA architectures ({invalid_arch_list}) are "
+            "excluded from the `TORCH_CUDA_ARCH_LIST` env variable "
+            f"({env_arch_list}). Supported CUDA architectures are: "
+            f"{valid_archs}.")
+    return arch_list
 
 
 # First, check the TORCH_CUDA_ARCH_LIST environment variable.
@@ -81,7 +94,7 @@ nvcc_cuda_version = get_nvcc_cuda_version(CUDA_HOME)
 if not compute_capabilities:
     # If no GPU is specified nor available, add all supported architectures
     # based on the NVCC CUDA version.
-    compute_capabilities = set(SUPPORTED_ARCHS)
+    compute_capabilities = SUPPORTED_ARCHS.copy()
     if nvcc_cuda_version < Version("11.1"):
         compute_capabilities.remove("8.6")
     if nvcc_cuda_version < Version("11.8"):
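
For reference, a minimal self-contained sketch of the filter-and-warn logic this commit introduces in get_torch_arch_list(), reusing the SUPPORTED_ARCHS set from the diff above; the TORCH_CUDA_ARCH_LIST value is an illustrative example, not one taken from the change.

import warnings

# Same supported set as in the patched setup.py.
SUPPORTED_ARCHS = {"7.0", "7.5", "8.0", "8.6", "8.9", "9.0"}
# Each bare arch is also accepted with a "+PTX" suffix.
valid_archs = SUPPORTED_ARCHS.union({s + "+PTX" for s in SUPPORTED_ARCHS})

# Hypothetical TORCH_CUDA_ARCH_LIST value; entries are split on space or ";".
env_arch_list = "8.0 9.0+PTX 5.2"
torch_arch_list = set(env_arch_list.replace(" ", ";").split(";"))

arch_list = torch_arch_list.intersection(valid_archs)  # {"8.0", "9.0+PTX"}
invalid_arch_list = torch_arch_list - valid_archs       # {"5.2"}
if not arch_list:
    # Nothing usable was requested: the patched code raises an error here.
    raise RuntimeError(f"No supported CUDA architectures in {env_arch_list!r}.")
if invalid_arch_list:
    # Partially valid input: the patched code only warns and keeps the rest,
    # instead of failing the whole build as before.
    warnings.warn(f"Excluding unsupported CUDA architectures: {invalid_arch_list}.")
print(sorted(arch_list))  # ['8.0', '9.0+PTX']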