[Setup] Enable TORCH_CUDA_ARCH_LIST for selecting target GPUs (#1074)

This commit is contained in:
Woosuk Kwon 2023-09-26 10:21:08 -07:00 committed by GitHub
parent bbbf86565f
commit a425bd9a9a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -3,6 +3,7 @@ import os
import re import re
import subprocess import subprocess
from typing import List, Set from typing import List, Set
import warnings
from packaging.version import parse, Version from packaging.version import parse, Version
import setuptools import setuptools
@ -11,6 +12,9 @@ from torch.utils.cpp_extension import BuildExtension, CUDAExtension, CUDA_HOME
ROOT_DIR = os.path.dirname(__file__) ROOT_DIR = os.path.dirname(__file__)
# Supported NVIDIA GPU architectures.
SUPPORTED_ARCHS = ["7.0", "7.5", "8.0", "8.6", "8.9", "9.0"]
# Compiler flags. # Compiler flags.
CXX_FLAGS = ["-g", "-O2", "-std=c++17"] CXX_FLAGS = ["-g", "-O2", "-std=c++17"]
# TODO(woosuk): Should we use -O3? # TODO(woosuk): Should we use -O3?
@ -38,51 +42,82 @@ def get_nvcc_cuda_version(cuda_dir: str) -> Version:
return nvcc_cuda_version return nvcc_cuda_version
# Collect the compute capabilities of all available GPUs. def get_torch_arch_list() -> Set[str]:
device_count = torch.cuda.device_count() # TORCH_CUDA_ARCH_LIST can have one or more architectures,
compute_capabilities: Set[int] = set() # e.g. "8.0" or "7.5 8.0 8.6+PTX". Here, the "8.6+PTX" option asks the
for i in range(device_count): # compiler to additionally include PTX code that can be runtime-compiled
# and executed on the 8.6 or newer architectures. While the PTX code will
# not give the best performance on the newer architectures, it provides
# forward compatibility.
valid_arch_strs = SUPPORTED_ARCHS + [s + "+PTX" for s in SUPPORTED_ARCHS]
arch_list = os.environ.get("TORCH_CUDA_ARCH_LIST", None)
if arch_list is None:
return set()
# Entries are separated by ';' or spaces.
arch_list = arch_list.replace(" ", ";").split(";")
for arch in arch_list:
if arch not in valid_arch_strs:
raise ValueError(
f"Unsupported CUDA arch ({arch}). "
f"Valid CUDA arch strings are: {valid_arch_strs}.")
return set(arch_list)
# First, check the TORCH_CUDA_ARCH_LIST environment variable.
compute_capabilities = get_torch_arch_list()
if not compute_capabilities:
# If TORCH_CUDA_ARCH_LIST is not defined or empty, target all available
# GPUs on the current machine.
device_count = torch.cuda.device_count()
for i in range(device_count):
major, minor = torch.cuda.get_device_capability(i) major, minor = torch.cuda.get_device_capability(i)
if major < 7: if major < 7:
raise RuntimeError( raise RuntimeError(
"GPUs with compute capability less than 7.0 are not supported.") "GPUs with compute capability below 7.0 are not supported.")
compute_capabilities.add(major * 10 + minor) compute_capabilities.add(f"{major}.{minor}")
nvcc_cuda_version = get_nvcc_cuda_version(CUDA_HOME)
if not compute_capabilities:
# If no GPU is specified nor available, add all supported architectures
# based on the NVCC CUDA version.
compute_capabilities = set(SUPPORTED_ARCHS)
if nvcc_cuda_version < Version("11.1"):
compute_capabilities.remove("8.6")
if nvcc_cuda_version < Version("11.8"):
compute_capabilities.remove("8.9")
compute_capabilities.remove("9.0")
# Validate the NVCC CUDA version. # Validate the NVCC CUDA version.
nvcc_cuda_version = get_nvcc_cuda_version(CUDA_HOME)
if nvcc_cuda_version < Version("11.0"): if nvcc_cuda_version < Version("11.0"):
raise RuntimeError("CUDA 11.0 or higher is required to build the package.") raise RuntimeError("CUDA 11.0 or higher is required to build the package.")
if 86 in compute_capabilities and nvcc_cuda_version < Version("11.1"): if nvcc_cuda_version < Version("11.1"):
if any(cc.startswith("8.6") for cc in compute_capabilities):
raise RuntimeError( raise RuntimeError(
"CUDA 11.1 or higher is required for GPUs with compute capability 8.6." "CUDA 11.1 or higher is required for compute capability 8.6.")
) if nvcc_cuda_version < Version("11.8"):
if 89 in compute_capabilities and nvcc_cuda_version < Version("11.8"): if any(cc.startswith("8.9") for cc in compute_capabilities):
# CUDA 11.8 is required to generate the code targeting compute capability 8.9. # CUDA 11.8 is required to generate the code targeting compute capability 8.9.
# However, GPUs with compute capability 8.9 can also run the code generated by # However, GPUs with compute capability 8.9 can also run the code generated by
# the previous versions of CUDA 11 and targeting compute capability 8.0. # the previous versions of CUDA 11 and targeting compute capability 8.0.
# Therefore, if CUDA 11.8 is not available, we target compute capability 8.0 # Therefore, if CUDA 11.8 is not available, we target compute capability 8.0
# instead of 8.9. # instead of 8.9.
compute_capabilities.remove(89) warnings.warn(
compute_capabilities.add(80) "CUDA 11.8 or higher is required for compute capability 8.9. "
if 90 in compute_capabilities and nvcc_cuda_version < Version("11.8"): "Targeting compute capability 8.0 instead.")
compute_capabilities = set(cc for cc in compute_capabilities
if not cc.startswith("8.9"))
compute_capabilities.add("8.0+PTX")
if any(cc.startswith("9.0") for cc in compute_capabilities):
raise RuntimeError( raise RuntimeError(
"CUDA 11.8 or higher is required for GPUs with compute capability 9.0." "CUDA 11.8 or higher is required for compute capability 9.0.")
)
# If no GPU is available, add all supported compute capabilities.
if not compute_capabilities:
compute_capabilities = {70, 75, 80}
if nvcc_cuda_version >= Version("11.1"):
compute_capabilities.add(86)
if nvcc_cuda_version >= Version("11.8"):
compute_capabilities.add(89)
compute_capabilities.add(90)
# Add target compute capabilities to NVCC flags. # Add target compute capabilities to NVCC flags.
for capability in compute_capabilities: for capability in compute_capabilities:
NVCC_FLAGS += [ num = capability[0] + capability[2]
"-gencode", f"arch=compute_{capability},code=sm_{capability}" NVCC_FLAGS += ["-gencode", f"arch=compute_{num},code=sm_{num}"]
] if capability.endswith("+PTX"):
NVCC_FLAGS += ["-gencode", f"arch=compute_{num},code=compute_{num}"]
# Use NVCC threads to parallelize the build. # Use NVCC threads to parallelize the build.
if nvcc_cuda_version >= Version("11.2"): if nvcc_cuda_version >= Version("11.2"):