diff --git a/setup.py b/setup.py index 3a92d5a2..cf2acb20 100755 --- a/setup.py +++ b/setup.py @@ -201,6 +201,9 @@ class cmake_build_ext(build_ext): else: # Default build tool to whatever cmake picks. build_tool = [] + # Make sure we use the nvcc from CUDA_HOME + if _is_cuda(): + cmake_args += [f'-DCMAKE_CUDA_COMPILER={CUDA_HOME}/bin/nvcc'] subprocess.check_call( ['cmake', ext.cmake_lists_dir, *build_tool, *cmake_args], cwd=self.build_temp) @@ -639,11 +642,10 @@ if _is_hip(): if _is_cuda(): ext_modules.append(CMakeExtension(name="vllm.vllm_flash_attn._vllm_fa2_C")) - if envs.VLLM_USE_PRECOMPILED or get_nvcc_cuda_version() >= Version("12.0"): - # FA3 requires CUDA 12.0 or later + if envs.VLLM_USE_PRECOMPILED or get_nvcc_cuda_version() >= Version("12.3"): + # FA3 requires CUDA 12.3 or later ext_modules.append( CMakeExtension(name="vllm.vllm_flash_attn._vllm_fa3_C")) - if envs.VLLM_USE_PRECOMPILED or get_nvcc_cuda_version() >= Version("12.3"): # Optional since this doesn't get built (produce an .so file) when # not targeting a hopper system ext_modules.append(