setup correct nvcc version with CUDA_HOME (#15725)
Signed-off-by: Yang Chen <yangche@fb.com>
This commit is contained in:
parent
8dd41d6bcc
commit
f3aca1ee30
8
setup.py
8
setup.py
@ -201,6 +201,9 @@ class cmake_build_ext(build_ext):
|
|||||||
else:
|
else:
|
||||||
# Default build tool to whatever cmake picks.
|
# Default build tool to whatever cmake picks.
|
||||||
build_tool = []
|
build_tool = []
|
||||||
|
# Make sure we use the nvcc from CUDA_HOME
|
||||||
|
if _is_cuda():
|
||||||
|
cmake_args += [f'-DCMAKE_CUDA_COMPILER={CUDA_HOME}/bin/nvcc']
|
||||||
subprocess.check_call(
|
subprocess.check_call(
|
||||||
['cmake', ext.cmake_lists_dir, *build_tool, *cmake_args],
|
['cmake', ext.cmake_lists_dir, *build_tool, *cmake_args],
|
||||||
cwd=self.build_temp)
|
cwd=self.build_temp)
|
||||||
@ -639,11 +642,10 @@ if _is_hip():
|
|||||||
|
|
||||||
if _is_cuda():
|
if _is_cuda():
|
||||||
ext_modules.append(CMakeExtension(name="vllm.vllm_flash_attn._vllm_fa2_C"))
|
ext_modules.append(CMakeExtension(name="vllm.vllm_flash_attn._vllm_fa2_C"))
|
||||||
if envs.VLLM_USE_PRECOMPILED or get_nvcc_cuda_version() >= Version("12.0"):
|
if envs.VLLM_USE_PRECOMPILED or get_nvcc_cuda_version() >= Version("12.3"):
|
||||||
# FA3 requires CUDA 12.0 or later
|
# FA3 requires CUDA 12.3 or later
|
||||||
ext_modules.append(
|
ext_modules.append(
|
||||||
CMakeExtension(name="vllm.vllm_flash_attn._vllm_fa3_C"))
|
CMakeExtension(name="vllm.vllm_flash_attn._vllm_fa3_C"))
|
||||||
if envs.VLLM_USE_PRECOMPILED or get_nvcc_cuda_version() >= Version("12.3"):
|
|
||||||
# Optional since this doesn't get built (produce an .so file) when
|
# Optional since this doesn't get built (produce an .so file) when
|
||||||
# not targeting a hopper system
|
# not targeting a hopper system
|
||||||
ext_modules.append(
|
ext_modules.append(
|
||||||
|
Loading…
x
Reference in New Issue
Block a user