setup correct nvcc version with CUDA_HOME (#15725)
Signed-off-by: Yang Chen <yangche@fb.com>
This commit is contained in:
parent
8dd41d6bcc
commit
f3aca1ee30
8
setup.py
8
setup.py
@ -201,6 +201,9 @@ class cmake_build_ext(build_ext):
|
||||
else:
|
||||
# Default build tool to whatever cmake picks.
|
||||
build_tool = []
|
||||
# Make sure we use the nvcc from CUDA_HOME
|
||||
if _is_cuda():
|
||||
cmake_args += [f'-DCMAKE_CUDA_COMPILER={CUDA_HOME}/bin/nvcc']
|
||||
subprocess.check_call(
|
||||
['cmake', ext.cmake_lists_dir, *build_tool, *cmake_args],
|
||||
cwd=self.build_temp)
|
||||
@ -639,11 +642,10 @@ if _is_hip():
|
||||
|
||||
if _is_cuda():
|
||||
ext_modules.append(CMakeExtension(name="vllm.vllm_flash_attn._vllm_fa2_C"))
|
||||
if envs.VLLM_USE_PRECOMPILED or get_nvcc_cuda_version() >= Version("12.0"):
|
||||
# FA3 requires CUDA 12.0 or later
|
||||
if envs.VLLM_USE_PRECOMPILED or get_nvcc_cuda_version() >= Version("12.3"):
|
||||
# FA3 requires CUDA 12.3 or later
|
||||
ext_modules.append(
|
||||
CMakeExtension(name="vllm.vllm_flash_attn._vllm_fa3_C"))
|
||||
if envs.VLLM_USE_PRECOMPILED or get_nvcc_cuda_version() >= Version("12.3"):
|
||||
# Optional since this doesn't get built (produce an .so file) when
|
||||
# not targeting a hopper system
|
||||
ext_modules.append(
|
||||
|
Loading…
x
Reference in New Issue
Block a user