[Hardware][AMD][CI/Build][Doc] Upgrade to ROCm 6.1, Dockerfile improvements, test fixes (#5422)

commit dd793d1de5 (parent bc34937d68)
@@ -32,8 +32,7 @@ set(HIP_SUPPORTED_ARCHS "gfx906;gfx908;gfx90a;gfx940;gfx941;gfx942;gfx1030;gfx11
 # versions are derived from Dockerfile.rocm
 #
 set(TORCH_SUPPORTED_VERSION_CUDA "2.3.0")
-set(TORCH_SUPPORTED_VERSION_ROCM_5X "2.0.1")
-set(TORCH_SUPPORTED_VERSION_ROCM_6X "2.1.1")
+set(TORCH_SUPPORTED_VERSION_ROCM "2.4.0")
 
 #
 # Try to find python package with an executable that exactly matches
@@ -98,18 +97,11 @@ elseif(HIP_FOUND)
   # .hip extension automatically, HIP must be enabled explicitly.
   enable_language(HIP)
 
-  # ROCm 5.x
-  if (ROCM_VERSION_DEV_MAJOR EQUAL 5 AND
-      NOT Torch_VERSION VERSION_EQUAL ${TORCH_SUPPORTED_VERSION_ROCM_5X})
-    message(WARNING "Pytorch version ${TORCH_SUPPORTED_VERSION_ROCM_5X} "
-      "expected for ROCMm 5.x build, saw ${Torch_VERSION} instead.")
-  endif()
-
-  # ROCm 6.x
-  if (ROCM_VERSION_DEV_MAJOR EQUAL 6 AND
-      NOT Torch_VERSION VERSION_EQUAL ${TORCH_SUPPORTED_VERSION_ROCM_6X})
-    message(WARNING "Pytorch version ${TORCH_SUPPORTED_VERSION_ROCM_6X} "
-      "expected for ROCMm 6.x build, saw ${Torch_VERSION} instead.")
+  # ROCm 5.X and 6.X
+  if (ROCM_VERSION_DEV_MAJOR GREATER_EQUAL 5 AND
+      NOT Torch_VERSION VERSION_EQUAL ${TORCH_SUPPORTED_VERSION_ROCM})
+    message(WARNING "Pytorch version ${TORCH_SUPPORTED_VERSION_ROCM} "
+      "expected for ROCm build, saw ${Torch_VERSION} instead.")
   endif()
 else()
   message(FATAL_ERROR "Can't find CUDA or HIP installation.")
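For reference, the same guard can be reproduced at runtime. The sketch below is illustrative only and not part of this commit; it assumes torch and packaging are installed, and hard-codes "2.4.0" to mirror TORCH_SUPPORTED_VERSION_ROCM rather than reading it from CMake:

    # Illustrative only: mirrors the CMake warning above at runtime.
    import torch
    from packaging.version import Version

    EXPECTED_ROCM_TORCH = "2.4.0"  # copied from TORCH_SUPPORTED_VERSION_ROCM

    def check_rocm_torch() -> None:
        if torch.version.hip is None:
            return  # not a ROCm build of torch, nothing to check
        installed = Version(torch.__version__)
        if installed.base_version != EXPECTED_ROCM_TORCH:
            print(f"Warning: PyTorch {EXPECTED_ROCM_TORCH} expected for the "
                  f"ROCm build, saw {installed} instead.")

    check_rocm_torch()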
Dockerfile.rocm (205 lines changed)
@@ -1,34 +1,35 @@
-# default base image
-ARG BASE_IMAGE="rocm/pytorch:rocm6.0_ubuntu20.04_py3.9_pytorch_2.1.1"
+# Default ROCm 6.1 base image
+ARG BASE_IMAGE="rocm/pytorch:rocm6.1.2_ubuntu20.04_py3.9_pytorch_staging"
 
-FROM $BASE_IMAGE
-
-ARG BASE_IMAGE="rocm/pytorch:rocm6.0_ubuntu20.04_py3.9_pytorch_2.1.1"
-
-RUN echo "Base image is $BASE_IMAGE"
-
-ARG ROCm_5_7_BASE="rocm/pytorch:rocm5.7_ubuntu22.04_py3.10_pytorch_2.0.1" \
-    ROCm_6_0_BASE="rocm/pytorch:rocm6.0_ubuntu20.04_py3.9_pytorch_2.1.1"
-
-ARG FA_GFX_ARCHS="gfx90a;gfx942"
-RUN echo "FA_GFX_ARCHS is $FA_GFX_ARCHS"
-
-ARG FA_BRANCH="ae7928c"
-RUN echo "FA_BRANCH is $FA_BRANCH"
-
-# whether to build flash-attention
-# if 0, will not build flash attention
-# this is useful for gfx target where flash-attention is not supported
-# In that case, we need to use the python reference attention implementation in vllm
+# Tested and supported base rocm/pytorch images
+ARG ROCm_5_7_BASE="rocm/pytorch:rocm5.7_ubuntu20.04_py3.9_pytorch_2.0.1" \
+    ROCm_6_0_BASE="rocm/pytorch:rocm6.0_ubuntu20.04_py3.9_pytorch_2.1.1" \
+    ROCM_6_1_BASE="rocm/pytorch:rocm6.1.2_ubuntu20.04_py3.9_pytorch_staging"
+
+# Default ROCm ARCHes to build vLLM for.
+ARG PYTORCH_ROCM_ARCH="gfx908;gfx90a;gfx942;gfx1100"
+
+# Whether to build CK-based flash-attention
+# If 0, will not build flash attention
+# This is useful for gfx target where flash-attention is not supported
+# (i.e. those that do not appear in `FA_GFX_ARCHS`)
+# Triton FA is used by default on ROCm now so this is unnecessary.
 ARG BUILD_FA="1"
+ARG FA_GFX_ARCHS="gfx90a;gfx942"
+ARG FA_BRANCH="ae7928c"
 
-# whether to build triton on rocm
+# Whether to build triton on rocm
 ARG BUILD_TRITON="1"
+ARG TRITON_BRANCH="0ef1848"
+
+### Base image build stage
+FROM $BASE_IMAGE AS base
+
+# Import arg(s) defined before this build stage
+ARG PYTORCH_ROCM_ARCH
 
 # Install some basic utilities
 RUN apt-get update && apt-get install python3 python3-pip -y
 
-# Install some basic utilities
 RUN apt-get update && apt-get install -y \
     curl \
     ca-certificates \
@@ -39,79 +40,159 @@ RUN apt-get update && apt-get install -y \
     build-essential \
     wget \
     unzip \
-    nvidia-cuda-toolkit \
     tmux \
     ccache \
  && rm -rf /var/lib/apt/lists/*
 
-### Mount Point ###
-# When launching the container, mount the code directory to /app
+# When launching the container, mount the code directory to /vllm-workspace
 ARG APP_MOUNT=/vllm-workspace
-VOLUME [ ${APP_MOUNT} ]
 WORKDIR ${APP_MOUNT}
 
-RUN python3 -m pip install --upgrade pip
-RUN python3 -m pip install --no-cache-dir fastapi ninja tokenizers pandas
+RUN pip install --upgrade pip
+# Remove sccache so it doesn't interfere with ccache
+# TODO: implement sccache support across components
+RUN apt-get purge -y sccache; pip uninstall -y sccache; rm -f "$(which sccache)"
+# Install torch == 2.4.0 on ROCm
+RUN case "$(ls /opt | grep -Po 'rocm-[0-9]\.[0-9]')" in \
+        *"rocm-5.7"*) \
+            pip uninstall -y torch \
+            && pip install --no-cache-dir --pre torch==2.4.0.dev20240612 \
+            --index-url https://download.pytorch.org/whl/nightly/rocm5.7;; \
+        *"rocm-6.0"*) \
+            pip uninstall -y torch \
+            && pip install --no-cache-dir --pre torch==2.4.0.dev20240612 \
+            --index-url https://download.pytorch.org/whl/nightly/rocm6.0;; \
+        *"rocm-6.1"*) \
+            pip uninstall -y torch \
+            && pip install --no-cache-dir --pre torch==2.4.0.dev20240612 \
+            --index-url https://download.pytorch.org/whl/nightly/rocm6.1;; \
+        *) ;; esac
 
 ENV LLVM_SYMBOLIZER_PATH=/opt/rocm/llvm/bin/llvm-symbolizer
 ENV PATH=$PATH:/opt/rocm/bin:/libtorch/bin:
 ENV LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/opt/rocm/lib/:/libtorch/lib:
 ENV CPLUS_INCLUDE_PATH=$CPLUS_INCLUDE_PATH:/libtorch/include:/libtorch/include/torch/csrc/api/include/:/opt/rocm/include/:
 
-# Install ROCm flash-attention
-RUN if [ "$BUILD_FA" = "1" ]; then \
-    mkdir libs \
+ENV PYTORCH_ROCM_ARCH=${PYTORCH_ROCM_ARCH}
+ENV CCACHE_DIR=/root/.cache/ccache
+
+
+### AMD-SMI build stage
+FROM base AS build_amdsmi
+# Build amdsmi wheel always
+RUN cd /opt/rocm/share/amd_smi \
+    && pip wheel . --wheel-dir=/install
+
+
+### Flash-Attention wheel build stage
+FROM base AS build_fa
+ARG BUILD_FA
+ARG FA_GFX_ARCHS
+ARG FA_BRANCH
+# Build ROCm flash-attention wheel if `BUILD_FA = 1`
+RUN --mount=type=cache,target=${CCACHE_DIR} \
+    if [ "$BUILD_FA" = "1" ]; then \
+    mkdir -p libs \
     && cd libs \
     && git clone https://github.com/ROCm/flash-attention.git \
     && cd flash-attention \
-    && git checkout ${FA_BRANCH} \
+    && git checkout "${FA_BRANCH}" \
     && git submodule update --init \
-    && export GPU_ARCHS=${FA_GFX_ARCHS} \
-    && if [ "$BASE_IMAGE" = "$ROCm_5_7_BASE" ]; then \
-    patch /opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/utils/hipify/hipify_python.py hipify_patch.patch; fi \
-    && python3 setup.py install \
-    && cd ..; \
+    && case "$(ls /opt | grep -Po 'rocm-[0-9]\.[0-9]')" in \
+        *"rocm-5.7"*) \
+            export VLLM_TORCH_PATH="$(python3 -c 'import torch; print(torch.__path__[0])')" \
+            && patch "${VLLM_TORCH_PATH}"/utils/hipify/hipify_python.py hipify_patch.patch;; \
+        *) ;; esac \
+    && GPU_ARCHS="${FA_GFX_ARCHS}" python3 setup.py bdist_wheel --dist-dir=/install; \
+    # Create an empty directory otherwise as later build stages expect one
+    else mkdir -p /install; \
     fi
 
-# Error related to odd state for numpy 1.20.3 where there is no METADATA etc, but an extra LICENSES_bundled.txt.
-# Manually removed it so that later steps of numpy upgrade can continue
-RUN if [ "$BASE_IMAGE" = "$ROCm_6_0_BASE" ]; then \
-    rm -rf /opt/conda/envs/py_3.9/lib/python3.9/site-packages/numpy-1.20.3.dist-info/; fi
 
-# build triton
-RUN if [ "$BUILD_TRITON" = "1" ]; then \
+### Triton wheel build stage
+FROM base AS build_triton
+ARG BUILD_TRITON
+ARG TRITON_BRANCH
+# Build triton wheel if `BUILD_TRITON = 1`
+RUN --mount=type=cache,target=${CCACHE_DIR} \
+    if [ "$BUILD_TRITON" = "1" ]; then \
     mkdir -p libs \
     && cd libs \
-    && pip uninstall -y triton \
-    && git clone https://github.com/ROCm/triton.git \
-    && cd triton/python \
-    && pip3 install . \
-    && cd ../..; \
+    && git clone https://github.com/OpenAI/triton.git \
+    && cd triton \
+    && git checkout "${TRITON_BRANCH}" \
+    && cd python \
+    && python3 setup.py bdist_wheel --dist-dir=/install; \
+    # Create an empty directory otherwise as later build stages expect one
+    else mkdir -p /install; \
     fi
 
-WORKDIR /vllm-workspace
+### Final vLLM build stage
+FROM base AS final
+# Import the vLLM development directory from the build context
 COPY . .
 
-#RUN python3 -m pip install pynvml # to be removed eventually
-RUN python3 -m pip install --upgrade pip numba
+# Error related to odd state for numpy 1.20.3 where there is no METADATA etc, but an extra LICENSES_bundled.txt.
+# Manually remove it so that later steps of numpy upgrade can continue
+RUN case "$(which python3)" in \
+        *"/opt/conda/envs/py_3.9"*) \
+            rm -rf /opt/conda/envs/py_3.9/lib/python3.9/site-packages/numpy-1.20.3.dist-info/;; \
+        *) ;; esac
 
-# make sure punica kernels are built (for LoRA)
+# Package upgrades for useful functionality or to avoid dependency issues
+RUN --mount=type=cache,target=/root/.cache/pip \
+    pip install --upgrade numba scipy huggingface-hub[cli]
+
+# Make sure punica kernels are built (for LoRA)
 ENV VLLM_INSTALL_PUNICA_KERNELS=1
 # Workaround for ray >= 2.10.0
 ENV RAY_EXPERIMENTAL_NOSET_ROCR_VISIBLE_DEVICES=1
+# Silences the HF Tokenizers warning
+ENV TOKENIZERS_PARALLELISM=false
 
-ENV VLLM_NCCL_SO_PATH=/opt/rocm/lib/librccl.so
-
-ENV CCACHE_DIR=/root/.cache/ccache
-RUN --mount=type=cache,target=/root/.cache/ccache \
+RUN --mount=type=cache,target=${CCACHE_DIR} \
     --mount=type=cache,target=/root/.cache/pip \
     pip install -U -r requirements-rocm.txt \
-    && if [ "$BASE_IMAGE" = "$ROCm_6_0_BASE" ]; then \
-    patch /opt/rocm/include/hip/amd_detail/amd_hip_bf16.h ./rocm_patch/rocm_bf16.patch; fi \
-    && python3 setup.py install \
-    && export VLLM_PYTHON_VERSION=$(python -c "import sys; print(str(sys.version_info.major) + str(sys.version_info.minor))") \
-    && cp build/lib.linux-x86_64-cpython-${VLLM_PYTHON_VERSION}/vllm/*.so vllm/ \
-    && cd ..
+    && case "$(ls /opt | grep -Po 'rocm-[0-9]\.[0-9]')" in \
+        *"rocm-6.0"*) \
+            patch /opt/rocm/include/hip/amd_detail/amd_hip_bf16.h rocm_patch/rocm_bf16.patch;; \
+        *"rocm-6.1"*) \
+            # Bring in upgrades to HIP graph earlier than ROCm 6.2 for vLLM
+            wget -N https://github.com/ROCm/vllm/raw/fa78403/rocm_patch/libamdhip64.so.6 -P rocm_patch \
+            && cp rocm_patch/libamdhip64.so.6 /opt/rocm/lib/libamdhip64.so.6 \
+            # Prevent interference if torch bundles its own HIP runtime
+            && rm -f "$(python3 -c 'import torch; print(torch.__path__[0])')"/lib/libamdhip64.so* || true;; \
+        *) ;; esac \
+    && python3 setup.py clean --all \
+    && python3 setup.py develop
+
+# Copy amdsmi wheel into final image
+RUN --mount=type=bind,from=build_amdsmi,src=/install,target=/install \
+    mkdir -p libs \
+    && cp /install/*.whl libs \
+    # Preemptively uninstall to avoid same-version no-installs
+    && pip uninstall -y amdsmi;
+
+# Copy triton wheel(s) into final image if they were built
+RUN --mount=type=bind,from=build_triton,src=/install,target=/install \
+    mkdir -p libs \
+    && if ls /install/*.whl; then \
+    cp /install/*.whl libs \
+    # Preemptively uninstall to avoid same-version no-installs
+    && pip uninstall -y triton; fi
+
+# Copy flash-attn wheel(s) into final image if they were built
+RUN --mount=type=bind,from=build_fa,src=/install,target=/install \
+    mkdir -p libs \
+    && if ls /install/*.whl; then \
+    cp /install/*.whl libs \
+    # Preemptively uninstall to avoid same-version no-installs
+    && pip uninstall -y flash-attn; fi
+
+# Install wheels that were built to the final image
+RUN --mount=type=cache,target=/root/.cache/pip \
+    if ls libs/*.whl; then \
+    pip install libs/*.whl; fi
 
 CMD ["/bin/bash"]
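After building an image from Dockerfile.rocm, a quick sanity check can be run inside the container. This is an illustrative sketch, not part of the commit; it only relies on torch and the vLLM checkout installed by the final stage:

    # Illustrative smoke test for a container built from Dockerfile.rocm.
    import torch

    print("torch:", torch.__version__)         # expected: a 2.4.0 ROCm nightly
    print("HIP runtime:", torch.version.hip)   # None would mean a non-ROCm torch
    print("GPUs visible:", torch.cuda.device_count())

    import vllm  # installed via `python3 setup.py develop` in the final stage
    print("vllm:", vllm.__version__)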
@@ -147,19 +147,23 @@ macro(override_gpu_arches GPU_ARCHES GPU_LANG GPU_SUPPORTED_ARCHES)
   if (${GPU_LANG} STREQUAL "HIP")
     #
     # `GPU_ARCHES` controls the `--offload-arch` flags.
-    # `CMAKE_HIP_ARCHITECTURES` is set up by torch and can be controlled
-    # via the `PYTORCH_ROCM_ARCH` env variable.
     #
+    # If PYTORCH_ROCM_ARCH env variable exists, then we take it as a list,
+    # if not, then we use CMAKE_HIP_ARCHITECTURES which was generated by calling
+    # "rocm_agent_enumerator" in "enable_language(HIP)"
+    # (in file Modules/CMakeDetermineHIPCompiler.cmake)
+    #
+    if(DEFINED ENV{PYTORCH_ROCM_ARCH})
+      set(HIP_ARCHITECTURES $ENV{PYTORCH_ROCM_ARCH})
+    else()
+      set(HIP_ARCHITECTURES ${CMAKE_HIP_ARCHITECTURES})
+    endif()
     #
     # Find the intersection of the supported + detected architectures to
     # set the module architecture flags.
     #
-    set(VLLM_ROCM_SUPPORTED_ARCHS "gfx908;gfx90a;gfx942;gfx1100")
-
     set(${GPU_ARCHES})
-    foreach (_ARCH ${VLLM_ROCM_SUPPORTED_ARCHS})
+    foreach (_ARCH ${HIP_ARCHITECTURES})
       if (_ARCH IN_LIST _GPU_SUPPORTED_ARCHES_LIST)
         list(APPEND ${GPU_ARCHES} ${_ARCH})
       endif()
@@ -167,7 +171,7 @@ macro(override_gpu_arches GPU_ARCHES GPU_LANG GPU_SUPPORTED_ARCHES)
 
     if(NOT ${GPU_ARCHES})
       message(FATAL_ERROR
-        "None of the detected ROCm architectures: ${CMAKE_HIP_ARCHITECTURES} is"
+        "None of the detected ROCm architectures: ${HIP_ARCHITECTURES} is"
         " supported. Supported ROCm architectures are: ${_GPU_SUPPORTED_ARCHES_LIST}.")
     endif()
 
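The selection above amounts to a precedence rule (the PYTORCH_ROCM_ARCH environment variable wins over the CMake-detected architectures) followed by an intersection with the supported list. A Python restatement, illustrative only, using the default arch list from Dockerfile.rocm as a stand-in for the supported set:

    # Illustrative restatement of the HIP arch selection above (not vLLM code).
    import os
    from typing import List

    # Stand-in supported set, taken from the PYTORCH_ROCM_ARCH default above.
    SUPPORTED_ARCHS = {"gfx908", "gfx90a", "gfx942", "gfx1100"}

    def select_hip_archs(detected_archs: List[str]) -> List[str]:
        env = os.environ.get("PYTORCH_ROCM_ARCH")
        candidates = env.split(";") if env else detected_archs
        chosen = [arch for arch in candidates if arch in SUPPORTED_ARCHS]
        if not chosen:
            raise RuntimeError(
                f"None of the detected ROCm architectures {candidates} is "
                f"supported. Supported: {sorted(SUPPORTED_ARCHS)}")
        return chosen

    # e.g. on a machine where rocm_agent_enumerator reports gfx942:
    print(select_hip_archs(["gfx942"]))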
@@ -88,7 +88,7 @@ Option 2: Build from source
 - `Pytorch <https://pytorch.org/>`_
 - `hipBLAS <https://rocm.docs.amd.com/projects/hipBLAS/en/latest/install.html>`_
 
-For installing PyTorch, you can start from a fresh docker image, e.g, `rocm6.0.2_ubuntu22.04_py3.10_pytorch_2.1.2`, `rocm/pytorch:rocm6.0_ubuntu20.04_py3.9_pytorch_2.1.1`, `rocm/pytorch-nightly`.
+For installing PyTorch, you can start from a fresh docker image, e.g, `rocm/pytorch:rocm6.1.2_ubuntu20.04_py3.9_pytorch_staging`, `rocm/pytorch:rocm6.0_ubuntu20.04_py3.9_pytorch_2.1.1`, `rocm/pytorch-nightly`.
 
 Alternatively, you can install pytorch using pytorch wheels. You can check Pytorch installation guild in Pytorch `Getting Started <https://pytorch.org/get-started/locally/>`_
 
@@ -126,12 +126,12 @@ Install ROCm's flash attention (v2.0.4) following the instructions from `ROCm/fl
 
     $ cd vllm
     $ pip install -U -r requirements-rocm.txt
-    $ python setup.py install # This may take 5-10 minutes. Currently, `pip install .`` does not work for ROCm installation
+    $ python setup.py develop # This may take 5-10 minutes. Currently, `pip install .`` does not work for ROCm installation
 
 
 .. tip::
 
     - You may need to turn on the ``--enforce-eager`` flag if you experience process hang when running the `benchmark_thoughput.py` script to test your installation.
     - Triton flash attention is used by default. For benchmarking purposes, it is recommended to run a warm up step before collecting perf numbers.
-    - To use CK flash-attention, please use this flag ``export VLLM_USE_FLASH_ATTN_TRITON=0`` to turn off triton flash attention.
+    - To use CK flash-attention, please use this flag ``export VLLM_USE_TRITON_FLASH_ATTN=0`` to turn off triton flash attention.
     - The ROCm version of pytorch, ideally, should match the ROCm driver version.
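Since the flag is a plain environment variable, it can also be set from Python; exporting it in the shell as documented above is the canonical route, and the snippet below is only an illustrative sketch that assumes the variable is read when vLLM is initialized:

    # Illustrative only: select CK flash-attention before creating the engine.
    import os

    os.environ["VLLM_USE_TRITON_FLASH_ATTN"] = "0"  # fall back to CK flash-attention

    from vllm import LLM

    llm = LLM(model="facebook/opt-125m", enforce_eager=True)
    print(llm.generate("Hello, my name is")[0].outputs[0].text)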
@@ -4,7 +4,7 @@ import pytest
 # and debugging.
 import ray
 
-from ..utils import VLLM_PATH, RemoteOpenAIServer
+from ..utils import RemoteOpenAIServer
 
 # any model with a chat template should work here
 MODEL_NAME = "facebook/opt-125m"
@@ -12,7 +12,7 @@ MODEL_NAME = "facebook/opt-125m"
 
 @pytest.fixture(scope="module")
 def ray_ctx():
-    ray.init(runtime_env={"working_dir": VLLM_PATH})
+    ray.init()
     yield
     ray.shutdown()
 
@@ -1,8 +1,8 @@
-import os
-
 import ray
 
-from vllm.utils import cuda_device_count_stateless
+import vllm.envs as envs
+from vllm.utils import (cuda_device_count_stateless, is_hip,
+                        update_environment_variables)
 
 
 @ray.remote
@@ -12,16 +12,21 @@ class _CUDADeviceCountStatelessTestActor:
         return cuda_device_count_stateless()
 
     def set_cuda_visible_devices(self, cuda_visible_devices: str):
-        os.environ["CUDA_VISIBLE_DEVICES"] = cuda_visible_devices
+        update_environment_variables(
+            {"CUDA_VISIBLE_DEVICES": cuda_visible_devices})
 
     def get_cuda_visible_devices(self):
-        return os.environ["CUDA_VISIBLE_DEVICES"]
+        return envs.CUDA_VISIBLE_DEVICES
 
 
 def test_cuda_device_count_stateless():
     """Test that cuda_device_count_stateless changes return value if
     CUDA_VISIBLE_DEVICES is changed."""
+    if is_hip():
+        # Set HIP_VISIBLE_DEVICES == CUDA_VISIBLE_DEVICES. Conversion
+        # is handled by `update_environment_variables`
+        update_environment_variables(
+            {"CUDA_VISIBLE_DEVICES": envs.CUDA_VISIBLE_DEVICES})
     actor = _CUDADeviceCountStatelessTestActor.options(  # type: ignore
         num_gpus=2).remote()
     assert sorted(ray.get(
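The point of cuda_device_count_stateless is that it re-reads CUDA_VISIBLE_DEVICES on every call instead of relying on a count cached earlier in the process. An illustrative usage sketch, not part of the test, assuming at least two GPUs are present:

    # Illustrative only: the stateless count tracks CUDA_VISIBLE_DEVICES changes.
    # On ROCm, routing updates through update_environment_variables also keeps
    # HIP_VISIBLE_DEVICES in sync (see the vllm.utils change later in this diff).
    from vllm.utils import cuda_device_count_stateless, update_environment_variables

    update_environment_variables({"CUDA_VISIBLE_DEVICES": "0,1"})
    print(cuda_device_count_stateless())  # 2, assuming two GPUs are available

    update_environment_variables({"CUDA_VISIBLE_DEVICES": "0"})
    print(cuda_device_count_stateless())  # 1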
@@ -2,7 +2,7 @@ import openai
 import pytest
 import ray
 
-from ..utils import VLLM_PATH, RemoteOpenAIServer
+from ..utils import RemoteOpenAIServer
 
 EMBEDDING_MODEL_NAME = "intfloat/e5-mistral-7b-instruct"
 
@@ -11,7 +11,7 @@ pytestmark = pytest.mark.openai
 
 @pytest.fixture(scope="module")
 def ray_ctx():
-    ray.init(runtime_env={"working_dir": VLLM_PATH})
+    ray.init()
     yield
     ray.shutdown()
 
@@ -16,7 +16,7 @@ from openai import BadRequestError
 
 from vllm.transformers_utils.tokenizer import get_tokenizer
 
-from ..utils import VLLM_PATH, RemoteOpenAIServer
+from ..utils import RemoteOpenAIServer
 
 # any model with a chat template should work here
 MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta"
@@ -81,7 +81,7 @@ def zephyr_lora_files():
 
 @pytest.fixture(scope="module")
 def ray_ctx():
-    ray.init(runtime_env={"working_dir": VLLM_PATH})
+    ray.init()
     yield
     ray.shutdown()
 
@@ -8,7 +8,7 @@ import ray
 
 from vllm.multimodal.utils import ImageFetchAiohttp, encode_image_base64
 
-from ..utils import VLLM_PATH, RemoteOpenAIServer
+from ..utils import RemoteOpenAIServer
 
 MODEL_NAME = "llava-hf/llava-1.5-7b-hf"
 LLAVA_CHAT_TEMPLATE = (Path(__file__).parent.parent.parent /
@@ -27,7 +27,7 @@ pytestmark = pytest.mark.openai
 
 @pytest.fixture(scope="module")
 def ray_ctx():
-    ray.init(runtime_env={"working_dir": VLLM_PATH})
+    ray.init()
     yield
     ray.shutdown()
 
@@ -15,9 +15,30 @@ from vllm.distributed import (ensure_model_parallel_initialized,
 from vllm.entrypoints.openai.cli_args import make_arg_parser
 from vllm.utils import get_open_port, is_hip
 
-if (not is_hip()):
+if is_hip():
+    from amdsmi import (amdsmi_get_gpu_vram_usage,
+                        amdsmi_get_processor_handles, amdsmi_init,
+                        amdsmi_shut_down)
+
+    @contextmanager
+    def _nvml():
+        try:
+            amdsmi_init()
+            yield
+        finally:
+            amdsmi_shut_down()
+else:
     from pynvml import (nvmlDeviceGetHandleByIndex, nvmlDeviceGetMemoryInfo,
-                        nvmlInit)
+                        nvmlInit, nvmlShutdown)
+
+    @contextmanager
+    def _nvml():
+        try:
+            nvmlInit()
+            yield
+        finally:
+            nvmlShutdown()
+
 
 # Path to root of repository so that utilities can be imported by ray workers
 VLLM_PATH = os.path.abspath(os.path.join(__file__, os.pardir, os.pardir))
@@ -160,20 +181,25 @@ def error_on_warning():
     yield
 
 
+@_nvml()
 def wait_for_gpu_memory_to_clear(devices: List[int],
                                  threshold_bytes: int,
                                  timeout_s: float = 120) -> None:
     # Use nvml instead of pytorch to reduce measurement error from torch cuda
     # context.
-    nvmlInit()
     start_time = time.time()
     while True:
         output: Dict[int, str] = {}
         output_raw: Dict[int, float] = {}
         for device in devices:
-            dev_handle = nvmlDeviceGetHandleByIndex(device)
-            mem_info = nvmlDeviceGetMemoryInfo(dev_handle)
-            gb_used = mem_info.used / 2**30
+            if is_hip():
+                dev_handle = amdsmi_get_processor_handles()[device]
+                mem_info = amdsmi_get_gpu_vram_usage(dev_handle)
+                gb_used = mem_info["vram_used"] / 2**10
+            else:
+                dev_handle = nvmlDeviceGetHandleByIndex(device)
+                mem_info = nvmlDeviceGetMemoryInfo(dev_handle)
+                gb_used = mem_info.used / 2**30
             output_raw[device] = gb_used
             output[device] = f'{gb_used:.02f}'
 
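For context, the ROCm branch above can be exercised on its own. The sketch below is illustrative, assumes the amdsmi package (built as a wheel in Dockerfile.rocm) is installed, and follows the test utility in treating vram_used as MiB:

    # Standalone sketch of the ROCm VRAM query used above (not test code).
    from amdsmi import (amdsmi_get_gpu_vram_usage, amdsmi_get_processor_handles,
                        amdsmi_init, amdsmi_shut_down)

    def vram_used_gb(device_index: int) -> float:
        amdsmi_init()
        try:
            handle = amdsmi_get_processor_handles()[device_index]
            mem_info = amdsmi_get_gpu_vram_usage(handle)
            return mem_info["vram_used"] / 2**10  # MiB -> GiB, as in the test util
        finally:
            amdsmi_shut_down()

    print(f"GPU 0 VRAM used: {vram_used_gb(0):.02f} GiB")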
@@ -7,13 +7,15 @@ from typing import (TYPE_CHECKING, Any, ClassVar, Dict, List, Optional, Tuple,
 import torch
 from transformers import PretrainedConfig, PreTrainedTokenizerBase
 
+import vllm.envs as envs
 from vllm.logger import init_logger
 from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
 from vllm.model_executor.models import ModelRegistry
 from vllm.tracing import is_otel_installed
 from vllm.transformers_utils.config import get_config, get_hf_text_config
 from vllm.utils import (cuda_device_count_stateless, get_cpu_memory, is_cpu,
-                        is_hip, is_neuron, is_tpu, is_xpu)
+                        is_hip, is_neuron, is_tpu, is_xpu,
+                        update_environment_variables)
 
 if TYPE_CHECKING:
     from ray.util.placement_group import PlacementGroup
@@ -634,6 +636,12 @@ class ParallelConfig:
             self.distributed_executor_backend = backend
             logger.info("Defaulting to use %s for distributed inference",
                         backend)
+        # If CUDA_VISIBLE_DEVICES is set on ROCm prior to vLLM init,
+        # propagate changes to HIP_VISIBLE_DEVICES (conversion handled by
+        # the update_environment_variables function)
+        if is_hip() and envs.CUDA_VISIBLE_DEVICES:
+            update_environment_variables(
+                {"CUDA_VISIBLE_DEVICES": envs.CUDA_VISIBLE_DEVICES})
 
         self._verify_args()
 
@@ -13,7 +13,8 @@ import torch.multiprocessing as mp
 import vllm.envs as envs
 from vllm.distributed.device_communicators.cuda_wrapper import CudaRTLibrary
 from vllm.logger import init_logger
-from vllm.utils import cuda_device_count_stateless
+from vllm.utils import (cuda_device_count_stateless,
+                        update_environment_variables)
 
 logger = init_logger(__name__)
 
@@ -24,7 +25,8 @@ def producer(batch_src: Sequence[int],
              result_queue,
              cuda_visible_devices: Optional[str] = None):
     if cuda_visible_devices is not None:
-        os.environ["CUDA_VISIBLE_DEVICES"] = cuda_visible_devices
+        update_environment_variables(
+            {"CUDA_VISIBLE_DEVICES": cuda_visible_devices})
 
     lib = CudaRTLibrary()
     for i in batch_src:
@@ -56,7 +58,8 @@ def consumer(batch_tgt: Sequence[int],
              result_queue,
              cuda_visible_devices: Optional[str] = None):
     if cuda_visible_devices is not None:
-        os.environ["CUDA_VISIBLE_DEVICES"] = cuda_visible_devices
+        update_environment_variables(
+            {"CUDA_VISIBLE_DEVICES": cuda_visible_devices})
 
     lib = CudaRTLibrary()
     for j in batch_tgt:
@@ -123,7 +126,7 @@ def can_actually_p2p(
     processes for testing all pairs of GPUs in batch. The trick is to reset
     the device after each test (which is not available in PyTorch).
     """  # noqa
-    cuda_visible_devices = os.getenv('CUDA_VISIBLE_DEVICES', None)
+    cuda_visible_devices = envs.CUDA_VISIBLE_DEVICES
     # pass the CUDA_VISIBLE_DEVICES to the child process
     # to make sure they see the same set of GPUs
 
@@ -11,7 +11,8 @@ from vllm.logger import init_logger
 from vllm.sequence import ExecuteModelRequest, SamplerOutput
 from vllm.utils import (cuda_device_count_stateless,
                         get_distributed_init_method, get_open_port,
-                        get_vllm_instance_id, make_async)
+                        get_vllm_instance_id, make_async,
+                        update_environment_variables)
 
 logger = init_logger(__name__)
 
@@ -25,8 +26,9 @@ class MultiprocessingGPUExecutor(DistributedGPUExecutor):
 
         # Set CUDA_VISIBLE_DEVICES for the driver, inherited by workers
         if "CUDA_VISIBLE_DEVICES" not in os.environ:
-            os.environ["CUDA_VISIBLE_DEVICES"] = (",".join(
-                map(str, range(world_size))))
+            update_environment_variables({
+                "CUDA_VISIBLE_DEVICES": (",".join(map(str, range(world_size))))
+            })
 
         # Ensure that VLLM_INSTANCE_ID is set, to be inherited by workers
         os.environ["VLLM_INSTANCE_ID"] = get_vllm_instance_id()
@@ -376,6 +376,10 @@ def get_open_port() -> int:
 
 
 def update_environment_variables(envs: Dict[str, str]):
+    if is_hip() and "CUDA_VISIBLE_DEVICES" in envs:
+        # Propagate changes to CUDA_VISIBLE_DEVICES to
+        # ROCm's HIP_VISIBLE_DEVICES as well
+        envs["HIP_VISIBLE_DEVICES"] = envs["CUDA_VISIBLE_DEVICES"]
     for k, v in envs.items():
         if k in os.environ and os.environ[k] != v:
             logger.warning(
@@ -779,9 +783,14 @@ def _cuda_device_count_stateless(
 
     if not torch.cuda._is_compiled():
         return 0
-    # bypass _device_count_nvml() if rocm (not supported)
-    nvml_count = -1 if torch.version.hip else torch.cuda._device_count_nvml()
-    r = torch._C._cuda_getDeviceCount() if nvml_count < 0 else nvml_count
+    if is_hip():
+        # ROCm uses amdsmi instead of nvml for stateless device count
+        # This requires a sufficiently modern version of Torch 2.4.0
+        raw_count = torch.cuda._device_count_amdsmi() if (hasattr(
+            torch.cuda, "_device_count_amdsmi")) else -1
+    else:
+        raw_count = torch.cuda._device_count_nvml()
+    r = torch._C._cuda_getDeviceCount() if raw_count < 0 else raw_count
     return r
 
 
@@ -795,7 +804,6 @@ def cuda_device_count_stateless() -> int:
 
     # This can be removed and simply replaced with torch.cuda.get_device_count
     # after https://github.com/pytorch/pytorch/pull/122815 is released.
-
     return _cuda_device_count_stateless(envs.CUDA_VISIBLE_DEVICES)
 
 
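A usage sketch of the helper above (illustrative, assuming a ROCm machine): updating CUDA_VISIBLE_DEVICES through it also rewrites HIP_VISIBLE_DEVICES, so the HIP-side device masking stays consistent with what vLLM set:

    # Illustrative only: CUDA_VISIBLE_DEVICES updates are mirrored to
    # HIP_VISIBLE_DEVICES when running on a ROCm build.
    import os

    from vllm.utils import is_hip, update_environment_variables

    update_environment_variables({"CUDA_VISIBLE_DEVICES": "0,1"})
    print(os.environ["CUDA_VISIBLE_DEVICES"])         # "0,1"
    if is_hip():
        print(os.environ.get("HIP_VISIBLE_DEVICES"))  # "0,1" as well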
@@ -6,7 +6,7 @@ from typing import Dict, List, Optional, Set, Tuple
 from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
 from vllm.sequence import ExecuteModelRequest, SamplerOutput
-from vllm.utils import (enable_trace_function_call_for_thread,
+from vllm.utils import (enable_trace_function_call_for_thread, is_hip,
                         update_environment_variables)
 
 logger = init_logger(__name__)
@@ -125,6 +125,14 @@ class WorkerWrapperBase:
             # overwriting CUDA_VISIBLE_DEVICES is desired behavior
             # suppress the warning in `update_environment_variables`
             del os.environ[key]
+        if is_hip():
+            hip_env_var = "HIP_VISIBLE_DEVICES"
+            if hip_env_var in os.environ:
+                logger.warning(
+                    "Ignoring pre-set environment variable `%s=%s` as "
+                    "%s has also been set, which takes precedence.",
+                    hip_env_var, os.environ[hip_env_var], key)
+            os.environ.pop(hip_env_var, None)
         update_environment_variables(envs)
 
     def init_worker(self, *args, **kwargs):