[ROCm] Fixup arch checks for ROCM (#2627)
This commit is contained in:
parent
b92adec8e8
commit
2ccee3def6
@ -10,9 +10,6 @@ RUN echo "Base image is $BASE_IMAGE"
|
||||
# BASE_IMAGE for ROCm_5.7: "rocm/pytorch:rocm5.7_ubuntu22.04_py3.10_pytorch_2.0.1"
|
||||
# BASE_IMAGE for ROCm_6.0: "rocm/pytorch:rocm6.0_ubuntu20.04_py3.9_pytorch_2.1.1"
|
||||
|
||||
# this does not always work for all rocm versions
|
||||
RUN LLVM_GFX_ARCH=$(/opt/rocm/llvm/bin/amdgpu-offload-arch) && \
|
||||
echo "LLVM_GFX_ARCH is $LLVM_GFX_ARCH"
|
||||
|
||||
ARG FA_GFX_ARCHS="gfx90a;gfx942"
|
||||
RUN echo "FA_GFX_ARCHS is $FA_GFX_ARCHS"
|
||||
|
86
setup.py
86
setup.py
@ -19,7 +19,7 @@ MAIN_CUDA_VERSION = "12.1"
|
||||
|
||||
# Supported NVIDIA GPU architectures.
|
||||
NVIDIA_SUPPORTED_ARCHS = {"7.0", "7.5", "8.0", "8.6", "8.9", "9.0"}
|
||||
ROCM_SUPPORTED_ARCHS = {"gfx90a", "gfx908", "gfx906", "gfx1030", "gfx1100"}
|
||||
ROCM_SUPPORTED_ARCHS = {"gfx90a", "gfx942"}
|
||||
# SUPPORTED_ARCHS = NVIDIA_SUPPORTED_ARCHS.union(ROCM_SUPPORTED_ARCHS)
|
||||
|
||||
|
||||
@ -63,22 +63,6 @@ CXX_FLAGS += [f"-D_GLIBCXX_USE_CXX11_ABI={ABI}"]
|
||||
NVCC_FLAGS += [f"-D_GLIBCXX_USE_CXX11_ABI={ABI}"]
|
||||
|
||||
|
||||
def get_amdgpu_offload_arch():
|
||||
command = "/opt/rocm/llvm/bin/amdgpu-offload-arch"
|
||||
try:
|
||||
output = subprocess.check_output([command])
|
||||
return output.decode('utf-8').strip()
|
||||
except subprocess.CalledProcessError as e:
|
||||
error_message = f"Error: {e}"
|
||||
raise RuntimeError(error_message) from e
|
||||
except FileNotFoundError as e:
|
||||
# If the command is not found, print an error message
|
||||
error_message = f"The command {command} was not found."
|
||||
raise RuntimeError(error_message) from e
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def get_hipcc_rocm_version():
|
||||
# Run the hipcc --version command
|
||||
result = subprocess.run(['hipcc', '--version'],
|
||||
@ -138,6 +122,50 @@ def get_nvcc_cuda_version(cuda_dir: str) -> Version:
|
||||
return nvcc_cuda_version
|
||||
|
||||
|
||||
def get_pytorch_rocm_arch() -> Set[str]:
|
||||
"""Get the cross section of Pytorch,and vllm supported gfx arches
|
||||
|
||||
ROCM can get the supported gfx architectures in one of two ways
|
||||
Either through the PYTORCH_ROCM_ARCH env var, or output from
|
||||
rocm_agent_enumerator.
|
||||
|
||||
In either case we can generate a list of supported arch's and
|
||||
cross reference with VLLM's own ROCM_SUPPORTED_ARCHs.
|
||||
"""
|
||||
env_arch_list = os.environ.get("PYTORCH_ROCM_ARCH", None)
|
||||
|
||||
# If we don't have PYTORCH_ROCM_ARCH specified pull the list from rocm_agent_enumerator
|
||||
if env_arch_list is None:
|
||||
command = "rocm_agent_enumerator"
|
||||
env_arch_list = subprocess.check_output([command]).decode('utf-8')\
|
||||
.strip().replace("\n", ";")
|
||||
arch_source_str = "rocm_agent_enumerator"
|
||||
else:
|
||||
arch_source_str = "PYTORCH_ROCM_ARCH env variable"
|
||||
|
||||
# List are separated by ; or space.
|
||||
pytorch_rocm_arch = set(env_arch_list.replace(" ", ";").split(";"))
|
||||
|
||||
# Filter out the invalid architectures and print a warning.
|
||||
arch_list = pytorch_rocm_arch.intersection(ROCM_SUPPORTED_ARCHS)
|
||||
|
||||
# If none of the specified architectures are valid, raise an error.
|
||||
if not arch_list:
|
||||
raise RuntimeError(
|
||||
f"None of the ROCM architectures in {arch_source_str} "
|
||||
f"({env_arch_list}) is supported. "
|
||||
f"Supported ROCM architectures are: {ROCM_SUPPORTED_ARCHS}.")
|
||||
invalid_arch_list = pytorch_rocm_arch - ROCM_SUPPORTED_ARCHS
|
||||
if invalid_arch_list:
|
||||
warnings.warn(
|
||||
f"Unsupported ROCM architectures ({invalid_arch_list}) are "
|
||||
f"excluded from the {arch_source_str} output "
|
||||
f"({env_arch_list}). Supported ROCM architectures are: "
|
||||
f"{ROCM_SUPPORTED_ARCHS}.",
|
||||
stacklevel=2)
|
||||
return arch_list
|
||||
|
||||
|
||||
def get_torch_arch_list() -> Set[str]:
|
||||
# TORCH_CUDA_ARCH_LIST can have one or more architectures,
|
||||
# e.g. "8.0" or "7.5,8.0,8.6+PTX". Here, the "8.6+PTX" option asks the
|
||||
@ -162,22 +190,27 @@ def get_torch_arch_list() -> Set[str]:
|
||||
# If none of the specified architectures are valid, raise an error.
|
||||
if not arch_list:
|
||||
raise RuntimeError(
|
||||
"None of the CUDA/ROCM architectures in `TORCH_CUDA_ARCH_LIST` env "
|
||||
"None of the CUDA architectures in `TORCH_CUDA_ARCH_LIST` env "
|
||||
f"variable ({env_arch_list}) is supported. "
|
||||
f"Supported CUDA/ROCM architectures are: {valid_archs}.")
|
||||
f"Supported CUDA architectures are: {valid_archs}.")
|
||||
invalid_arch_list = torch_arch_list - valid_archs
|
||||
if invalid_arch_list:
|
||||
warnings.warn(
|
||||
f"Unsupported CUDA/ROCM architectures ({invalid_arch_list}) are "
|
||||
f"Unsupported CUDA architectures ({invalid_arch_list}) are "
|
||||
"excluded from the `TORCH_CUDA_ARCH_LIST` env variable "
|
||||
f"({env_arch_list}). Supported CUDA/ROCM architectures are: "
|
||||
f"({env_arch_list}). Supported CUDA architectures are: "
|
||||
f"{valid_archs}.",
|
||||
stacklevel=2)
|
||||
return arch_list
|
||||
|
||||
|
||||
if _is_hip():
|
||||
rocm_arches = get_pytorch_rocm_arch()
|
||||
NVCC_FLAGS += ["--offload-arch=" + arch for arch in rocm_arches]
|
||||
else:
|
||||
# First, check the TORCH_CUDA_ARCH_LIST environment variable.
|
||||
compute_capabilities = get_torch_arch_list()
|
||||
|
||||
if _is_cuda() and not compute_capabilities:
|
||||
# If TORCH_CUDA_ARCH_LIST is not defined or empty, target all available
|
||||
# GPUs on the current machine.
|
||||
@ -286,17 +319,6 @@ if _is_cuda():
|
||||
"nvcc": NVCC_FLAGS_PUNICA,
|
||||
},
|
||||
))
|
||||
elif _is_hip():
|
||||
amd_archs = os.getenv("GPU_ARCHS")
|
||||
if amd_archs is None:
|
||||
amd_archs = get_amdgpu_offload_arch()
|
||||
for arch in amd_archs.split(";"):
|
||||
if arch not in ROCM_SUPPORTED_ARCHS:
|
||||
raise RuntimeError(
|
||||
f"Only the following arch is supported: {ROCM_SUPPORTED_ARCHS}"
|
||||
f"amdgpu_arch_found: {arch}")
|
||||
NVCC_FLAGS += [f"--offload-arch={arch}"]
|
||||
|
||||
elif _is_neuron():
|
||||
neuronxcc_version = get_neuronxcc_version()
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user