[ROCm][Bugfix] Fixed several bugs related to rccl path and attention selector logic (#3699)
This commit is contained in:
parent
430530fc18
commit
9765b5c406
@ -90,6 +90,6 @@ RUN cd /app \
|
|||||||
&& cd ..
|
&& cd ..
|
||||||
|
|
||||||
RUN python3 -m pip install --upgrade pip
|
RUN python3 -m pip install --upgrade pip
|
||||||
RUN python3 -m pip install --no-cache-dir ray[all]
|
RUN python3 -m pip install --no-cache-dir ray[all]==2.9.3
|
||||||
|
|
||||||
CMD ["/bin/bash"]
|
CMD ["/bin/bash"]
|
||||||
|
@ -5,7 +5,7 @@ starlette
|
|||||||
requests
|
requests
|
||||||
py-cpuinfo
|
py-cpuinfo
|
||||||
psutil
|
psutil
|
||||||
ray >= 2.9
|
ray == 2.9.3
|
||||||
sentencepiece # Required for LLaMA tokenizer.
|
sentencepiece # Required for LLaMA tokenizer.
|
||||||
numpy
|
numpy
|
||||||
tokenizers>=0.15.0
|
tokenizers>=0.15.0
|
||||||
|
@ -405,8 +405,8 @@ def _check_use_naive_attention() -> bool:
|
|||||||
if not is_hip():
|
if not is_hip():
|
||||||
return False
|
return False
|
||||||
# For ROCm, check whether flash attention is installed or not.
|
# For ROCm, check whether flash attention is installed or not.
|
||||||
has_flash_attn = importlib.util.find_spec("flash_attn") is None
|
use_naive_attention = importlib.util.find_spec("flash_attn") is None
|
||||||
if not has_flash_attn:
|
if use_naive_attention:
|
||||||
logger.warning("flash_attn is not installed. Using naive attention. "
|
logger.warning("flash_attn is not installed. Using naive attention. "
|
||||||
"This will take significantly more GPU memory.")
|
"This will take significantly more GPU memory.")
|
||||||
return True
|
return True
|
||||||
|
@ -41,7 +41,7 @@ else:
|
|||||||
if torch.version.cuda is not None:
|
if torch.version.cuda is not None:
|
||||||
so_file = "libnccl.so.2"
|
so_file = "libnccl.so.2"
|
||||||
elif torch.version.hip is not None:
|
elif torch.version.hip is not None:
|
||||||
so_file = "librccl.so.2"
|
so_file = "librccl.so.1"
|
||||||
else:
|
else:
|
||||||
raise ValueError("NCCL only supports CUDA and ROCm backends.")
|
raise ValueError("NCCL only supports CUDA and ROCm backends.")
|
||||||
logger.debug(f"Loading nccl from library {so_file}")
|
logger.debug(f"Loading nccl from library {so_file}")
|
||||||
|
Loading…
x
Reference in New Issue
Block a user