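# Multi-stage ROCm image build: compiles hipBLAS-common/hipBLASLt, RCCL,
# Triton, amdsmi, PyTorch, torchvision and flash-attention in separate stages,
# then installs them (plus aiter) into the final stage.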
# Build-time knobs. These are declared before the first FROM, so every stage
# that needs one re-declares it with a bare `ARG NAME` to pick up the value.

# Image that the `base` stage starts from.
ARG BASE_IMAGE="rocm/dev-ubuntu-22.04:6.3.1-complete"

# hipBLASLt / hipBLAS-common revisions and the optional extra install.sh flag.
ARG HIPBLASLT_BRANCH="4d40e36"
ARG HIPBLAS_COMMON_BRANCH="7c1566b"
ARG LEGACY_HIPBLASLT_OPTION=""

# RCCL source and revision.
ARG RCCL_BRANCH="648a58d"
ARG RCCL_REPO="https://github.com/ROCm/rccl"

# Triton source and revision.
ARG TRITON_BRANCH="e5be006"
ARG TRITON_REPO="https://github.com/triton-lang/triton.git"

# PyTorch and torchvision sources and revisions.
ARG PYTORCH_BRANCH="3a585126"
ARG PYTORCH_VISION_BRANCH="v0.19.1"
ARG PYTORCH_REPO="https://github.com/pytorch/pytorch.git"
ARG PYTORCH_VISION_REPO="https://github.com/pytorch/vision.git"

# flash-attention source and revision.
ARG FA_BRANCH="b7d29fb"
ARG FA_REPO="https://github.com/ROCm/flash-attention.git"

# aiter source and revision.
ARG AITER_BRANCH="21d47a9"
ARG AITER_REPO="https://github.com/ROCm/aiter.git"

FROM ${BASE_IMAGE} AS base

# ROCm toolchain and library locations.
ENV PATH=/opt/rocm/llvm/bin:$PATH
ENV ROCM_PATH=/opt/rocm
# No trailing ':' here: an empty LD_LIBRARY_PATH entry makes the dynamic
# loader search the current working directory.
ENV LD_LIBRARY_PATH=/opt/rocm/lib:/usr/local/lib

# GPU targets to build for. Exported as ENV because the derived build stages
# read ${PYTORCH_ROCM_ARCH} in their RUN steps without re-declaring the ARG.
ARG PYTORCH_ROCM_ARCH=gfx90a;gfx942
ENV PYTORCH_ROCM_ARCH=${PYTORCH_ROCM_ARCH}

ARG PYTHON_VERSION=3.12

# WORKDIR creates /app when it does not exist, so no separate mkdir is needed.
WORKDIR /app
# Kept as ENV (not ARG) so stages and images derived from `base` still see it.
ENV DEBIAN_FRONTEND=noninteractive

# Install Python and other dependencies.
# - `add-apt-repository -y` never waits for confirmation.
# - get-pip.py is downloaded with `curl -f` and run as a separate command, so
#   an HTTP error or failed download fails the build instead of being hidden
#   behind a pipe.
# - apt package lists are removed in the same layer that created them.
RUN apt-get update -y \
    && apt-get install -y software-properties-common git curl sudo vim less \
    && add-apt-repository -y ppa:deadsnakes/ppa \
    && apt-get update -y \
    && apt-get install -y \
        python${PYTHON_VERSION} \
        python${PYTHON_VERSION}-dev \
        python${PYTHON_VERSION}-venv \
        python${PYTHON_VERSION}-lib2to3 \
        python-is-python3 \
    && update-alternatives --install /usr/bin/python3 python3 /usr/bin/python${PYTHON_VERSION} 1 \
    && update-alternatives --set python3 /usr/bin/python${PYTHON_VERSION} \
    && ln -sf /usr/bin/python${PYTHON_VERSION}-config /usr/bin/python3-config \
    && curl -fsSL -o /tmp/get-pip.py https://bootstrap.pypa.io/get-pip.py \
    && python${PYTHON_VERSION} /tmp/get-pip.py \
    && rm -f /tmp/get-pip.py \
    && rm -rf /var/lib/apt/lists/* \
    && python3 --version && python3 -m pip --version

# Python build tooling shared by every derived stage; --no-cache-dir keeps
# pip's download cache out of the layer.
RUN pip install --no-cache-dir -U packaging cmake ninja wheel setuptools pybind11 Cython

FROM base AS build_hipblaslt

ARG HIPBLASLT_BRANCH
ARG HIPBLAS_COMMON_BRANCH
# Set to "--legacy_hipblas_direct" for ROCm<=6.2
ARG LEGACY_HIPBLASLT_OPTION

# hipBLAS-common: check out the requested revision, build its package, and
# install the resulting .deb into this stage before hipBLASLt is built.
RUN git clone https://github.com/ROCm/hipBLAS-common.git && \
    cd hipBLAS-common && \
    git checkout ${HIPBLAS_COMMON_BRANCH} && \
    mkdir build && \
    cd build && \
    cmake .. && \
    make package && \
    dpkg -i ./*.deb

# hipBLASLt: run its install script for the GPU targets in PYTORCH_ROCM_ARCH,
# then package the release build. LEGACY_HIPBLASLT_OPTION is deliberately left
# unquoted so that an empty value adds no argument.
RUN git clone https://github.com/ROCm/hipBLASLt && \
    cd hipBLASLt && \
    git checkout ${HIPBLASLT_BRANCH} && \
    ./install.sh -d --architecture ${PYTORCH_ROCM_ARCH} ${LEGACY_HIPBLASLT_OPTION} && \
    cd build/release && \
    make package

# Collect both sets of .deb files in /app/install, which the final stage
# bind-mounts.
RUN mkdir -p /app/install && \
    cp /app/hipBLASLt/build/release/*.deb /app/hipBLAS-common/build/*.deb /app/install

FROM base AS build_rccl

ARG RCCL_BRANCH
ARG RCCL_REPO

# Fetch RCCL, switch to the requested revision, and run its install script
# for the GPU targets in PYTORCH_ROCM_ARCH. The clone is expected to land in
# ./rccl (the directory name git derives from the default RCCL_REPO URL).
RUN git clone ${RCCL_REPO} && \
    cd rccl && \
    git checkout ${RCCL_BRANCH} && \
    ./install.sh -p --amdgpu_targets ${PYTORCH_ROCM_ARCH}

# Stage the .deb files in /app/install, which the final stage bind-mounts.
RUN mkdir -p /app/install && \
    cp /app/rccl/build/release/*.deb /app/install

FROM base AS build_triton

ARG TRITON_BRANCH
ARG TRITON_REPO

# Build a Triton wheel from the requested revision; setup.py is run from the
# repository's python/ subdirectory and writes the wheel to python/dist.
RUN git clone ${TRITON_REPO} && \
    cd triton && \
    git checkout ${TRITON_BRANCH} && \
    cd python && \
    python3 setup.py bdist_wheel --dist-dir=dist

# Stage the wheel in /app/install, which the final stage bind-mounts.
RUN mkdir -p /app/install && \
    cp /app/triton/python/dist/*.whl /app/install

FROM base AS build_amdsmi

# Build a wheel from the amd_smi sources found under /opt/rocm/share in the
# base image; no repository is cloned in this stage.
RUN cd /opt/rocm/share/amd_smi && \
    pip wheel . --wheel-dir=dist

# Stage the wheel(s) in /app/install, which the final stage bind-mounts.
RUN mkdir -p /app/install && \
    cp /opt/rocm/share/amd_smi/dist/*.whl /app/install

FROM base AS build_pytorch

ARG PYTORCH_BRANCH
ARG PYTORCH_VISION_BRANCH
ARG PYTORCH_REPO
ARG PYTORCH_VISION_REPO
ARG FA_BRANCH
ARG FA_REPO

# PyTorch: build a wheel from the requested revision and install it into this
# stage, because the torchvision and flash-attention builds below run after it.
RUN git clone "${PYTORCH_REPO}" pytorch
RUN cd pytorch \
    && git checkout "${PYTORCH_BRANCH}" \
    && pip install -r requirements.txt \
    && git submodule update --init --recursive \
    && python3 tools/amd_build/build_amd.py \
    && CMAKE_PREFIX_PATH="$(python3 -c 'import sys; print(sys.prefix)')" \
        python3 setup.py bdist_wheel --dist-dir=dist \
    && pip install dist/*.whl

# torchvision: built and installed the same way.
RUN git clone "${PYTORCH_VISION_REPO}" vision
RUN cd vision \
    && git checkout "${PYTORCH_VISION_BRANCH}" \
    && python3 setup.py bdist_wheel --dist-dir=dist \
    && pip install dist/*.whl

# flash-attention: wheel only, it is not installed in this stage.
# FA_MAX_JOBS sets the MAX_JOBS value passed to the build (default 64, the
# value that was previously hard-coded). It is declared here rather than at
# the top of the stage so it is not in scope for the PyTorch and torchvision
# builds above.
ARG FA_MAX_JOBS=64
RUN git clone "${FA_REPO}"
RUN cd flash-attention \
    && git checkout "${FA_BRANCH}" \
    && git submodule update --init \
    && MAX_JOBS="${FA_MAX_JOBS}" GPU_ARCHS="${PYTORCH_ROCM_ARCH}" \
        python3 setup.py bdist_wheel --dist-dir=dist

# Collect all three wheels in /app/install, which the final stage bind-mounts.
RUN mkdir -p /app/install \
    && cp /app/pytorch/dist/*.whl /app/install \
    && cp /app/vision/dist/*.whl /app/install \
    && cp /app/flash-attention/dist/*.whl /app/install

FROM base AS final

# Each step bind-mounts one build stage's /app/install directory at /install,
# so the artifacts are installed without being copied into a layer.

# hipBLAS-common + hipBLASLt packages. The sed edits rewrite
# /var/lib/dpkg/status, deleting the "hipblaslt-dev ..." / "hipblaslt ..."
# entry that sits immediately before "hipcub-dev" / "hipfft" in a dependency
# list.
# NOTE(review): presumably this stops other ROCm packages from pinning the
# stock hipBLASLt version - confirm.
RUN --mount=type=bind,from=build_hipblaslt,src=/app/install/,target=/install \
    dpkg -i /install/*deb \
    && sed -i 's/, hipblaslt-dev \(.*\), hipcub-dev/, hipcub-dev/g' /var/lib/dpkg/status \
    && sed -i 's/, hipblaslt \(.*\), hipfft/, hipfft/g' /var/lib/dpkg/status

# RCCL packages, followed by the same kind of dpkg status edit for the
# "rccl-dev ..." / "rccl ..." entries that precede "rocalution".
RUN --mount=type=bind,from=build_rccl,src=/app/install/,target=/install \
    dpkg -i /install/*deb \
    && sed -i 's/, rccl-dev \(.*\), rocalution/, rocalution/g' /var/lib/dpkg/status \
    && sed -i 's/, rccl \(.*\), rocalution/, rocalution/g' /var/lib/dpkg/status

# Python wheels. --no-cache-dir keeps pip's download cache (filled by any
# dependency pip has to fetch) out of the final image's layers.
RUN --mount=type=bind,from=build_triton,src=/app/install/,target=/install \
    pip install --no-cache-dir /install/*.whl

RUN --mount=type=bind,from=build_amdsmi,src=/app/install/,target=/install \
    pip install --no-cache-dir /install/*.whl

RUN --mount=type=bind,from=build_pytorch,src=/app/install/,target=/install \
    pip install --no-cache-dir /install/*.whl

ARG AITER_REPO
ARG AITER_BRANCH
# GPU targets passed to the aiter build as GPU_ARCHS. The default is the value
# that was previously hard-coded; override with --build-arg AITER_GPU_ARCHS=...
ARG AITER_GPU_ARCHS=gfx942

# aiter is built in place with `setup.py develop`, so the checkout under
# /app/aiter stays in the image. `pip show aiter` fails the build if the
# package did not end up installed.
RUN git clone --recursive "${AITER_REPO}"
RUN cd aiter \
    && git checkout "${AITER_BRANCH}" \
    && git submodule update --init --recursive \
    && pip install --no-cache-dir -r requirements.txt \
    && PREBUILD_KERNELS=1 GPU_ARCHS="${AITER_GPU_ARCHS}" python3 setup.py develop \
    && pip show aiter

# Bring the globally declared build arguments into this stage's scope so they
# can be written out below.
ARG BASE_IMAGE
ARG HIPBLASLT_BRANCH
ARG HIPBLAS_COMMON_BRANCH
ARG LEGACY_HIPBLASLT_OPTION
ARG RCCL_BRANCH
ARG RCCL_REPO
ARG TRITON_BRANCH
ARG TRITON_REPO
ARG PYTORCH_BRANCH
ARG PYTORCH_VISION_BRANCH
ARG PYTORCH_REPO
ARG PYTORCH_VISION_REPO
ARG FA_BRANCH
ARG FA_REPO

# Write one "NAME: value" line per build argument to /app/versions.txt.
# The command group shares a single truncating redirect; AITER_BRANCH and
# AITER_REPO are already in scope from their earlier ARG lines in this stage.
RUN { \
        echo "BASE_IMAGE: ${BASE_IMAGE}" && \
        echo "HIPBLAS_COMMON_BRANCH: ${HIPBLAS_COMMON_BRANCH}" && \
        echo "HIPBLASLT_BRANCH: ${HIPBLASLT_BRANCH}" && \
        echo "LEGACY_HIPBLASLT_OPTION: ${LEGACY_HIPBLASLT_OPTION}" && \
        echo "RCCL_BRANCH: ${RCCL_BRANCH}" && \
        echo "RCCL_REPO: ${RCCL_REPO}" && \
        echo "TRITON_BRANCH: ${TRITON_BRANCH}" && \
        echo "TRITON_REPO: ${TRITON_REPO}" && \
        echo "PYTORCH_BRANCH: ${PYTORCH_BRANCH}" && \
        echo "PYTORCH_VISION_BRANCH: ${PYTORCH_VISION_BRANCH}" && \
        echo "PYTORCH_REPO: ${PYTORCH_REPO}" && \
        echo "PYTORCH_VISION_REPO: ${PYTORCH_VISION_REPO}" && \
        echo "FA_BRANCH: ${FA_BRANCH}" && \
        echo "FA_REPO: ${FA_REPO}" && \
        echo "AITER_BRANCH: ${AITER_BRANCH}" && \
        echo "AITER_REPO: ${AITER_REPO}"; \
    } > /app/versions.txt