# Multi-stage build of a ROCm base image.
#
# Layout:
#   base            - ROCm image plus Python ${PYTHON_VERSION} and common build tools
#   build_*         - one stage per component; each leaves its artifacts
#                     (.deb packages or .whl wheels) in /app/install
#   final           - `base` with every artifact from the build_* stages installed,
#                     plus /app/versions.txt recording the source ARGs used
#
# Global build arguments. They are declared before the first FROM, so every
# stage that needs one has to redeclare it with a bare `ARG <name>`.
ARG BASE_IMAGE=rocm/dev-ubuntu-22.04:6.3.1-complete
ARG HIPBLASLT_BRANCH="db8e93b4"
ARG HIPBLAS_COMMON_BRANCH="7c1566b"
ARG LEGACY_HIPBLASLT_OPTION=
ARG RCCL_BRANCH="648a58d"
ARG RCCL_REPO="https://github.com/ROCm/rccl"
ARG TRITON_BRANCH="e5be006"
ARG TRITON_REPO="https://github.com/triton-lang/triton.git"
ARG PYTORCH_BRANCH="295f2ed4"
ARG PYTORCH_VISION_BRANCH="v0.21.0"
ARG PYTORCH_REPO="https://github.com/pytorch/pytorch.git"
ARG PYTORCH_VISION_REPO="https://github.com/pytorch/vision.git"
ARG FA_BRANCH="1a7f4dfa"
ARG FA_REPO="https://github.com/Dao-AILab/flash-attention.git"
ARG AITER_BRANCH="8970b25b"
ARG AITER_REPO="https://github.com/ROCm/aiter.git"

# ---------------------------------------------------------------------------
# base: shared toolchain for every other stage
# ---------------------------------------------------------------------------
FROM ${BASE_IMAGE} AS base

ENV PATH=/opt/rocm/llvm/bin:$PATH
ENV ROCM_PATH=/opt/rocm
# NOTE(review): the trailing ':' leaves an empty entry in LD_LIBRARY_PATH, which
# the dynamic loader treats as the current working directory - confirm this is
# intended before removing it.
ENV LD_LIBRARY_PATH=/opt/rocm/lib:/usr/local/lib:
# Semicolon-separated GPU targets; exported as ENV so later stages' build
# commands can read it.
ARG PYTORCH_ROCM_ARCH=gfx90a;gfx942;gfx1100;gfx1101;gfx1200;gfx1201
ENV PYTORCH_ROCM_ARCH=${PYTORCH_ROCM_ARCH}

ARG PYTHON_VERSION=3.12

RUN mkdir -p /app
WORKDIR /app
ENV DEBIAN_FRONTEND=noninteractive

# Install Python and other dependencies
# Python ${PYTHON_VERSION} comes from the deadsnakes PPA and is registered as the
# `python3` alternative; pip is bootstrapped with get-pip.py. The trailing
# version checks fail the step if either python3 or pip is not usable.
RUN apt-get update -y \
    && apt-get install -y software-properties-common git curl sudo vim less libgfortran5 \
    && add-apt-repository ppa:deadsnakes/ppa \
    && apt-get update -y \
    && apt-get install -y python${PYTHON_VERSION} python${PYTHON_VERSION}-dev python${PYTHON_VERSION}-venv \
       python${PYTHON_VERSION}-lib2to3 python-is-python3 \
    && update-alternatives --install /usr/bin/python3 python3 /usr/bin/python${PYTHON_VERSION} 1 \
    && update-alternatives --set python3 /usr/bin/python${PYTHON_VERSION} \
    && ln -sf /usr/bin/python${PYTHON_VERSION}-config /usr/bin/python3-config \
    && curl -sS https://bootstrap.pypa.io/get-pip.py | python${PYTHON_VERSION} \
    && python3 --version && python3 -m pip --version

# Python-side build tooling shared by the wheel-building stages (cmake is held below 4).
RUN pip install -U packaging 'cmake<4' ninja wheel setuptools pybind11 Cython

# ---------------------------------------------------------------------------
# build_hipblaslt: hipBLAS-common and hipBLASLt .deb packages
# ---------------------------------------------------------------------------
FROM base AS build_hipblaslt
ARG HIPBLASLT_BRANCH
ARG HIPBLAS_COMMON_BRANCH
# Set to "--legacy_hipblas_direct" for ROCm<=6.2
ARG LEGACY_HIPBLASLT_OPTION
# hipBLAS-common is packaged and installed in this stage before hipBLASLt is built.
RUN git clone https://github.com/ROCm/hipBLAS-common.git
RUN cd hipBLAS-common \
    && git checkout ${HIPBLAS_COMMON_BRANCH} \
    && mkdir build \
    && cd build \
    && cmake .. \
    && make package \
    && dpkg -i ./*.deb
RUN git clone https://github.com/ROCm/hipBLASLt
RUN cd hipBLASLt \
    && git checkout ${HIPBLASLT_BRANCH} \
    && apt-get install -y llvm-dev \
    && ./install.sh -dc --architecture ${PYTORCH_ROCM_ARCH} ${LEGACY_HIPBLASLT_OPTION} \
    && cd build/release \
    && make package
# Stage both sets of packages for the final image.
RUN mkdir -p /app/install && cp /app/hipBLASLt/build/release/*.deb /app/hipBLAS-common/build/*.deb /app/install

# ---------------------------------------------------------------------------
# build_rccl: RCCL .deb packages
# ---------------------------------------------------------------------------
FROM base AS build_rccl
ARG RCCL_BRANCH
ARG RCCL_REPO
RUN git clone ${RCCL_REPO}
RUN cd rccl \
    && git checkout ${RCCL_BRANCH} \
    && ./install.sh -p --amdgpu_targets ${PYTORCH_ROCM_ARCH}
RUN mkdir -p /app/install && cp /app/rccl/build/release/*.deb /app/install

# ---------------------------------------------------------------------------
# build_triton: Triton wheel (built from the repository's python/ directory)
# ---------------------------------------------------------------------------
FROM base AS build_triton
ARG TRITON_BRANCH
ARG TRITON_REPO
RUN git clone ${TRITON_REPO}
RUN cd triton \
    && git checkout ${TRITON_BRANCH} \
    && cd python \
    && python3 setup.py bdist_wheel --dist-dir=dist
RUN mkdir -p /app/install && cp /app/triton/python/dist/*.whl /app/install

# ---------------------------------------------------------------------------
# build_amdsmi: wheel built from the sources under /opt/rocm/share/amd_smi
# ---------------------------------------------------------------------------
FROM base AS build_amdsmi
RUN cd /opt/rocm/share/amd_smi \
    && pip wheel . --wheel-dir=dist
RUN mkdir -p /app/install && cp /opt/rocm/share/amd_smi/dist/*.whl /app/install

# ---------------------------------------------------------------------------
# build_pytorch: PyTorch, torchvision and flash-attention wheels
# ---------------------------------------------------------------------------
FROM base AS build_pytorch
ARG PYTORCH_BRANCH
ARG PYTORCH_VISION_BRANCH
ARG PYTORCH_REPO
ARG PYTORCH_VISION_REPO
ARG FA_BRANCH
ARG FA_REPO
# The freshly built torch wheel is installed in this stage so that the
# torchvision and flash-attention builds below run against it.
RUN git clone ${PYTORCH_REPO} pytorch
RUN cd pytorch && git checkout ${PYTORCH_BRANCH} && \
    pip install -r requirements.txt && git submodule update --init --recursive \
    && python3 tools/amd_build/build_amd.py \
    && CMAKE_PREFIX_PATH=$(python3 -c 'import sys; print(sys.prefix)') python3 setup.py bdist_wheel --dist-dir=dist \
    && pip install dist/*.whl
RUN git clone ${PYTORCH_VISION_REPO} vision
RUN cd vision && git checkout ${PYTORCH_VISION_BRANCH} \
    && python3 setup.py bdist_wheel --dist-dir=dist \
    && pip install dist/*.whl
# GPU_ARCHS is PYTORCH_ROCM_ARCH with every ";gfx1" + three-digit entry removed
# by the sed expression (with the default list that leaves gfx90a;gfx942).
RUN git clone ${FA_REPO}
RUN cd flash-attention \
    && git checkout ${FA_BRANCH} \
    && git submodule update --init \
    && GPU_ARCHS=$(echo ${PYTORCH_ROCM_ARCH} | sed -e 's/;gfx1[0-9]\{3\}//g') python3 setup.py bdist_wheel --dist-dir=dist
RUN mkdir -p /app/install && cp /app/pytorch/dist/*.whl /app/install \
    && cp /app/vision/dist/*.whl /app/install \
    && cp /app/flash-attention/dist/*.whl /app/install

# ---------------------------------------------------------------------------
# build_aiter: aiter wheel, built against the wheels from build_pytorch
# ---------------------------------------------------------------------------
FROM base AS build_aiter
ARG AITER_BRANCH
ARG AITER_REPO
# Bind-mount the build_pytorch artifacts instead of copying them into a layer.
RUN --mount=type=bind,from=build_pytorch,src=/app/install/,target=/install \
    pip install /install/*.whl
RUN git clone --recursive ${AITER_REPO}
RUN cd aiter \
    && git checkout ${AITER_BRANCH} \
    && git submodule update --init --recursive \
    && pip install -r requirements.txt
# NOTE(review): GPU_ARCHS is hard-coded to gfx942 here and does not follow
# PYTORCH_ROCM_ARCH - confirm that is intended.
# The trailing `ls` fails the step when no wheel was produced.
RUN pip install pyyaml && cd aiter && PREBUILD_KERNELS=1 GPU_ARCHS=gfx942 python3 setup.py bdist_wheel --dist-dir=dist && ls /app/aiter/dist/*.whl
RUN mkdir -p /app/install && cp /app/aiter/dist/*.whl /app/install

# ---------------------------------------------------------------------------
# final: base image with all built artifacts installed
# ---------------------------------------------------------------------------
FROM base AS final
# After installing the locally built packages, the sed commands rewrite
# /var/lib/dpkg/status to drop the "hipblaslt-dev ..." / "hipblaslt ..." entries
# that precede "hipcub-dev" / "hipfft" in dependency lists.
# NOTE(review): presumably this stops other ROCm packages from pinning the stock
# hipblaslt version - confirm.
RUN --mount=type=bind,from=build_hipblaslt,src=/app/install/,target=/install \
    dpkg -i /install/*deb \
    && sed -i 's/, hipblaslt-dev \(.*\), hipcub-dev/, hipcub-dev/g' /var/lib/dpkg/status \
    && sed -i 's/, hipblaslt \(.*\), hipfft/, hipfft/g' /var/lib/dpkg/status
# Same treatment for the "rccl-dev ..." / "rccl ..." entries preceding "rocalution".
RUN --mount=type=bind,from=build_rccl,src=/app/install/,target=/install \
    dpkg -i /install/*deb \
    && sed -i 's/, rccl-dev \(.*\), rocalution/, rocalution/g' /var/lib/dpkg/status \
    && sed -i 's/, rccl \(.*\), rocalution/, rocalution/g' /var/lib/dpkg/status
RUN --mount=type=bind,from=build_triton,src=/app/install/,target=/install \
    pip install /install/*.whl
RUN --mount=type=bind,from=build_amdsmi,src=/app/install/,target=/install \
    pip install /install/*.whl
RUN --mount=type=bind,from=build_pytorch,src=/app/install/,target=/install \
    pip install /install/*.whl
RUN --mount=type=bind,from=build_aiter,src=/app/install/,target=/install \
    pip install /install/*.whl

# Redeclare the global ARGs so their values are visible to the RUN below.
ARG BASE_IMAGE
ARG HIPBLAS_COMMON_BRANCH
ARG HIPBLASLT_BRANCH
ARG LEGACY_HIPBLASLT_OPTION
ARG RCCL_BRANCH
ARG RCCL_REPO
ARG TRITON_BRANCH
ARG TRITON_REPO
ARG PYTORCH_BRANCH
ARG PYTORCH_VISION_BRANCH
ARG PYTORCH_REPO
ARG PYTORCH_VISION_REPO
ARG FA_BRANCH
ARG FA_REPO
ARG AITER_BRANCH
ARG AITER_REPO
# Record the source of every component in /app/versions.txt, one "NAME: value"
# line per ARG. FA_REPO is written alongside FA_BRANCH so the flash-attention
# source is recorded like every other repo/branch pair.
RUN echo "BASE_IMAGE: ${BASE_IMAGE}" > /app/versions.txt \
    && echo "HIPBLAS_COMMON_BRANCH: ${HIPBLAS_COMMON_BRANCH}" >> /app/versions.txt \
    && echo "HIPBLASLT_BRANCH: ${HIPBLASLT_BRANCH}" >> /app/versions.txt \
    && echo "LEGACY_HIPBLASLT_OPTION: ${LEGACY_HIPBLASLT_OPTION}" >> /app/versions.txt \
    && echo "RCCL_BRANCH: ${RCCL_BRANCH}" >> /app/versions.txt \
    && echo "RCCL_REPO: ${RCCL_REPO}" >> /app/versions.txt \
    && echo "TRITON_BRANCH: ${TRITON_BRANCH}" >> /app/versions.txt \
    && echo "TRITON_REPO: ${TRITON_REPO}" >> /app/versions.txt \
    && echo "PYTORCH_BRANCH: ${PYTORCH_BRANCH}" >> /app/versions.txt \
    && echo "PYTORCH_VISION_BRANCH: ${PYTORCH_VISION_BRANCH}" >> /app/versions.txt \
    && echo "PYTORCH_REPO: ${PYTORCH_REPO}" >> /app/versions.txt \
    && echo "PYTORCH_VISION_REPO: ${PYTORCH_VISION_REPO}" >> /app/versions.txt \
    && echo "FA_BRANCH: ${FA_BRANCH}" >> /app/versions.txt \
    && echo "FA_REPO: ${FA_REPO}" >> /app/versions.txt \
    && echo "AITER_BRANCH: ${AITER_BRANCH}" >> /app/versions.txt \
    && echo "AITER_REPO: ${AITER_REPO}" >> /app/versions.txt