Check for _is_cuda() in compute_num_jobs (#3481)

This commit is contained in:
bnellnm 2024-03-20 13:06:56 -04:00 committed by GitHub
parent 84eaa68425
commit ba8ae1d84f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -61,12 +61,12 @@ class cmake_build_ext(build_ext):
except AttributeError:
num_jobs = os.cpu_count()
nvcc_cuda_version = get_nvcc_cuda_version()
if nvcc_cuda_version >= Version("11.2"):
nvcc_threads = int(os.getenv("NVCC_THREADS", 8))
num_jobs = max(1, round(num_jobs / (nvcc_threads / 4)))
else:
nvcc_threads = None
nvcc_threads = None
if _is_cuda():
nvcc_cuda_version = get_nvcc_cuda_version()
if nvcc_cuda_version >= Version("11.2"):
nvcc_threads = int(os.getenv("NVCC_THREADS", 8))
num_jobs = max(1, round(num_jobs / (nvcc_threads / 4)))
return num_jobs, nvcc_threads