[ROCm] Fixup arch checks for ROCM (#2627)
This commit is contained in:
parent
b92adec8e8
commit
2ccee3def6
@ -10,9 +10,6 @@ RUN echo "Base image is $BASE_IMAGE"
|
|||||||
# BASE_IMAGE for ROCm_5.7: "rocm/pytorch:rocm5.7_ubuntu22.04_py3.10_pytorch_2.0.1"
|
# BASE_IMAGE for ROCm_5.7: "rocm/pytorch:rocm5.7_ubuntu22.04_py3.10_pytorch_2.0.1"
|
||||||
# BASE_IMAGE for ROCm_6.0: "rocm/pytorch:rocm6.0_ubuntu20.04_py3.9_pytorch_2.1.1"
|
# BASE_IMAGE for ROCm_6.0: "rocm/pytorch:rocm6.0_ubuntu20.04_py3.9_pytorch_2.1.1"
|
||||||
|
|
||||||
# this does not always work for all rocm versions
|
|
||||||
RUN LLVM_GFX_ARCH=$(/opt/rocm/llvm/bin/amdgpu-offload-arch) && \
|
|
||||||
echo "LLVM_GFX_ARCH is $LLVM_GFX_ARCH"
|
|
||||||
|
|
||||||
ARG FA_GFX_ARCHS="gfx90a;gfx942"
|
ARG FA_GFX_ARCHS="gfx90a;gfx942"
|
||||||
RUN echo "FA_GFX_ARCHS is $FA_GFX_ARCHS"
|
RUN echo "FA_GFX_ARCHS is $FA_GFX_ARCHS"
|
||||||
|
90
setup.py
90
setup.py
@ -19,7 +19,7 @@ MAIN_CUDA_VERSION = "12.1"
|
|||||||
|
|
||||||
# Supported NVIDIA GPU architectures.
|
# Supported NVIDIA GPU architectures.
|
||||||
NVIDIA_SUPPORTED_ARCHS = {"7.0", "7.5", "8.0", "8.6", "8.9", "9.0"}
|
NVIDIA_SUPPORTED_ARCHS = {"7.0", "7.5", "8.0", "8.6", "8.9", "9.0"}
|
||||||
ROCM_SUPPORTED_ARCHS = {"gfx90a", "gfx908", "gfx906", "gfx1030", "gfx1100"}
|
ROCM_SUPPORTED_ARCHS = {"gfx90a", "gfx942"}
|
||||||
# SUPPORTED_ARCHS = NVIDIA_SUPPORTED_ARCHS.union(ROCM_SUPPORTED_ARCHS)
|
# SUPPORTED_ARCHS = NVIDIA_SUPPORTED_ARCHS.union(ROCM_SUPPORTED_ARCHS)
|
||||||
|
|
||||||
|
|
||||||
@ -63,22 +63,6 @@ CXX_FLAGS += [f"-D_GLIBCXX_USE_CXX11_ABI={ABI}"]
|
|||||||
NVCC_FLAGS += [f"-D_GLIBCXX_USE_CXX11_ABI={ABI}"]
|
NVCC_FLAGS += [f"-D_GLIBCXX_USE_CXX11_ABI={ABI}"]
|
||||||
|
|
||||||
|
|
||||||
def get_amdgpu_offload_arch():
|
|
||||||
command = "/opt/rocm/llvm/bin/amdgpu-offload-arch"
|
|
||||||
try:
|
|
||||||
output = subprocess.check_output([command])
|
|
||||||
return output.decode('utf-8').strip()
|
|
||||||
except subprocess.CalledProcessError as e:
|
|
||||||
error_message = f"Error: {e}"
|
|
||||||
raise RuntimeError(error_message) from e
|
|
||||||
except FileNotFoundError as e:
|
|
||||||
# If the command is not found, print an error message
|
|
||||||
error_message = f"The command {command} was not found."
|
|
||||||
raise RuntimeError(error_message) from e
|
|
||||||
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
def get_hipcc_rocm_version():
|
def get_hipcc_rocm_version():
|
||||||
# Run the hipcc --version command
|
# Run the hipcc --version command
|
||||||
result = subprocess.run(['hipcc', '--version'],
|
result = subprocess.run(['hipcc', '--version'],
|
||||||
@ -138,6 +122,50 @@ def get_nvcc_cuda_version(cuda_dir: str) -> Version:
|
|||||||
return nvcc_cuda_version
|
return nvcc_cuda_version
|
||||||
|
|
||||||
|
|
||||||
|
def get_pytorch_rocm_arch() -> Set[str]:
|
||||||
|
"""Get the cross section of Pytorch,and vllm supported gfx arches
|
||||||
|
|
||||||
|
ROCM can get the supported gfx architectures in one of two ways
|
||||||
|
Either through the PYTORCH_ROCM_ARCH env var, or output from
|
||||||
|
rocm_agent_enumerator.
|
||||||
|
|
||||||
|
In either case we can generate a list of supported arch's and
|
||||||
|
cross reference with VLLM's own ROCM_SUPPORTED_ARCHs.
|
||||||
|
"""
|
||||||
|
env_arch_list = os.environ.get("PYTORCH_ROCM_ARCH", None)
|
||||||
|
|
||||||
|
# If we don't have PYTORCH_ROCM_ARCH specified pull the list from rocm_agent_enumerator
|
||||||
|
if env_arch_list is None:
|
||||||
|
command = "rocm_agent_enumerator"
|
||||||
|
env_arch_list = subprocess.check_output([command]).decode('utf-8')\
|
||||||
|
.strip().replace("\n", ";")
|
||||||
|
arch_source_str = "rocm_agent_enumerator"
|
||||||
|
else:
|
||||||
|
arch_source_str = "PYTORCH_ROCM_ARCH env variable"
|
||||||
|
|
||||||
|
# List are separated by ; or space.
|
||||||
|
pytorch_rocm_arch = set(env_arch_list.replace(" ", ";").split(";"))
|
||||||
|
|
||||||
|
# Filter out the invalid architectures and print a warning.
|
||||||
|
arch_list = pytorch_rocm_arch.intersection(ROCM_SUPPORTED_ARCHS)
|
||||||
|
|
||||||
|
# If none of the specified architectures are valid, raise an error.
|
||||||
|
if not arch_list:
|
||||||
|
raise RuntimeError(
|
||||||
|
f"None of the ROCM architectures in {arch_source_str} "
|
||||||
|
f"({env_arch_list}) is supported. "
|
||||||
|
f"Supported ROCM architectures are: {ROCM_SUPPORTED_ARCHS}.")
|
||||||
|
invalid_arch_list = pytorch_rocm_arch - ROCM_SUPPORTED_ARCHS
|
||||||
|
if invalid_arch_list:
|
||||||
|
warnings.warn(
|
||||||
|
f"Unsupported ROCM architectures ({invalid_arch_list}) are "
|
||||||
|
f"excluded from the {arch_source_str} output "
|
||||||
|
f"({env_arch_list}). Supported ROCM architectures are: "
|
||||||
|
f"{ROCM_SUPPORTED_ARCHS}.",
|
||||||
|
stacklevel=2)
|
||||||
|
return arch_list
|
||||||
|
|
||||||
|
|
||||||
def get_torch_arch_list() -> Set[str]:
|
def get_torch_arch_list() -> Set[str]:
|
||||||
# TORCH_CUDA_ARCH_LIST can have one or more architectures,
|
# TORCH_CUDA_ARCH_LIST can have one or more architectures,
|
||||||
# e.g. "8.0" or "7.5,8.0,8.6+PTX". Here, the "8.6+PTX" option asks the
|
# e.g. "8.0" or "7.5,8.0,8.6+PTX". Here, the "8.6+PTX" option asks the
|
||||||
@ -162,22 +190,27 @@ def get_torch_arch_list() -> Set[str]:
|
|||||||
# If none of the specified architectures are valid, raise an error.
|
# If none of the specified architectures are valid, raise an error.
|
||||||
if not arch_list:
|
if not arch_list:
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
"None of the CUDA/ROCM architectures in `TORCH_CUDA_ARCH_LIST` env "
|
"None of the CUDA architectures in `TORCH_CUDA_ARCH_LIST` env "
|
||||||
f"variable ({env_arch_list}) is supported. "
|
f"variable ({env_arch_list}) is supported. "
|
||||||
f"Supported CUDA/ROCM architectures are: {valid_archs}.")
|
f"Supported CUDA architectures are: {valid_archs}.")
|
||||||
invalid_arch_list = torch_arch_list - valid_archs
|
invalid_arch_list = torch_arch_list - valid_archs
|
||||||
if invalid_arch_list:
|
if invalid_arch_list:
|
||||||
warnings.warn(
|
warnings.warn(
|
||||||
f"Unsupported CUDA/ROCM architectures ({invalid_arch_list}) are "
|
f"Unsupported CUDA architectures ({invalid_arch_list}) are "
|
||||||
"excluded from the `TORCH_CUDA_ARCH_LIST` env variable "
|
"excluded from the `TORCH_CUDA_ARCH_LIST` env variable "
|
||||||
f"({env_arch_list}). Supported CUDA/ROCM architectures are: "
|
f"({env_arch_list}). Supported CUDA architectures are: "
|
||||||
f"{valid_archs}.",
|
f"{valid_archs}.",
|
||||||
stacklevel=2)
|
stacklevel=2)
|
||||||
return arch_list
|
return arch_list
|
||||||
|
|
||||||
|
|
||||||
# First, check the TORCH_CUDA_ARCH_LIST environment variable.
|
if _is_hip():
|
||||||
compute_capabilities = get_torch_arch_list()
|
rocm_arches = get_pytorch_rocm_arch()
|
||||||
|
NVCC_FLAGS += ["--offload-arch=" + arch for arch in rocm_arches]
|
||||||
|
else:
|
||||||
|
# First, check the TORCH_CUDA_ARCH_LIST environment variable.
|
||||||
|
compute_capabilities = get_torch_arch_list()
|
||||||
|
|
||||||
if _is_cuda() and not compute_capabilities:
|
if _is_cuda() and not compute_capabilities:
|
||||||
# If TORCH_CUDA_ARCH_LIST is not defined or empty, target all available
|
# If TORCH_CUDA_ARCH_LIST is not defined or empty, target all available
|
||||||
# GPUs on the current machine.
|
# GPUs on the current machine.
|
||||||
@ -286,17 +319,6 @@ if _is_cuda():
|
|||||||
"nvcc": NVCC_FLAGS_PUNICA,
|
"nvcc": NVCC_FLAGS_PUNICA,
|
||||||
},
|
},
|
||||||
))
|
))
|
||||||
elif _is_hip():
|
|
||||||
amd_archs = os.getenv("GPU_ARCHS")
|
|
||||||
if amd_archs is None:
|
|
||||||
amd_archs = get_amdgpu_offload_arch()
|
|
||||||
for arch in amd_archs.split(";"):
|
|
||||||
if arch not in ROCM_SUPPORTED_ARCHS:
|
|
||||||
raise RuntimeError(
|
|
||||||
f"Only the following arch is supported: {ROCM_SUPPORTED_ARCHS}"
|
|
||||||
f"amdgpu_arch_found: {arch}")
|
|
||||||
NVCC_FLAGS += [f"--offload-arch={arch}"]
|
|
||||||
|
|
||||||
elif _is_neuron():
|
elif _is_neuron():
|
||||||
neuronxcc_version = get_neuronxcc_version()
|
neuronxcc_version = get_neuronxcc_version()
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user