[CI/Build][ROCm] Enabling LoRA tests on ROCm (#7369)
Co-authored-by: Simon Mo <simon.mo@hey.com>
parent 2ad2e5608e
commit d1dec64243

.buildkite/run-amd-test.sh (41 lines changed; mode: Normal file → Executable file)
@@ -1,5 +1,5 @@
 # This script runs test inside the corresponding ROCm docker container.
-set -ex
+set -o pipefail
 
 # Print ROCm version
 echo "--- Confirming Clean Initial State"
@@ -70,6 +70,41 @@ HF_CACHE="$(realpath ~)/huggingface"
 mkdir -p ${HF_CACHE}
 HF_MOUNT="/root/.cache/huggingface"
 
+commands=$@
+PARALLEL_JOB_COUNT=8
+# Check if the command contains a shard flag; if so, run all shards in parallel because the host has 8 GPUs.
+if [[ $commands == *"--shard-id="* ]]; then
+  for GPU in $(seq 0 $(($PARALLEL_JOB_COUNT-1))); do
+    # replace the shard arguments
+    commands=${@//"--shard-id= "/"--shard-id=${GPU} "}
+    commands=${commands//"--num-shards= "/"--num-shards=${PARALLEL_JOB_COUNT} "}
+    docker run \
+        --device /dev/kfd --device /dev/dri \
+        --network host \
+        --shm-size=16gb \
+        --rm \
+        -e HIP_VISIBLE_DEVICES=${GPU} \
+        -e HF_TOKEN \
+        -v ${HF_CACHE}:${HF_MOUNT} \
+        -e HF_HOME=${HF_MOUNT} \
+        --name ${container_name}_${GPU} \
+        ${image_name} \
+        /bin/bash -c "${commands}" \
+        |& while read -r line; do echo ">>Shard $GPU: $line"; done &
+    PIDS+=($!)
+  done
+  # wait for all processes to finish and collect exit codes
+  for pid in ${PIDS[@]}; do
+    wait ${pid}
+    STATUS+=($?)
+  done
+  for st in ${STATUS[@]}; do
+    if [[ ${st} -ne 0 ]]; then
+      echo "One of the processes failed with $st"
+      exit ${st}
+    fi
+  done
+else
 docker run \
         --device /dev/kfd --device /dev/dri \
         --network host \
@@ -81,5 +116,5 @@ docker run \
         -e HF_HOME=${HF_MOUNT} \
         --name ${container_name} \
         ${image_name} \
-        /bin/bash -c "${@}"
-
+        /bin/bash -c "${commands}"
+fi
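The added branch fans a single test command out across all eight GPUs: it fills in the empty --shard-id=/--num-shards= arguments per GPU, starts one container per shard pinned through HIP_VISIBLE_DEVICES, prefixes every output line with its shard number, and exits with the first non-zero status once all shards finish. As a minimal sketch (not part of the PR; run_shard and run_all_shards are hypothetical names), the same fan-out/fan-in pattern in Python:

import os
import subprocess
import sys
from concurrent.futures import ThreadPoolExecutor

PARALLEL_JOB_COUNT = 8  # the CI host has 8 GPUs

def run_shard(gpu: int, command: str) -> int:
    """Run one shard pinned to a single GPU, prefixing each output line."""
    env = dict(os.environ, HIP_VISIBLE_DEVICES=str(gpu))
    proc = subprocess.Popen(["/bin/bash", "-c", command], env=env,
                            stdout=subprocess.PIPE,
                            stderr=subprocess.STDOUT, text=True)
    for line in proc.stdout:
        print(f">>Shard {gpu}: {line}", end="")
    return proc.wait()

def run_all_shards(command: str) -> None:
    # Fill in the shard arguments per GPU, mirroring the bash substitutions.
    cmds = [command.replace("--shard-id= ", f"--shard-id={gpu} ")
                   .replace("--num-shards= ", f"--num-shards={PARALLEL_JOB_COUNT} ")
            for gpu in range(PARALLEL_JOB_COUNT)]
    # Fan out one worker per GPU, then fail with the first non-zero status.
    with ThreadPoolExecutor(max_workers=PARALLEL_JOB_COUNT) as pool:
        statuses = list(pool.map(run_shard, range(PARALLEL_JOB_COUNT), cmds))
    for status in statuses:
        if status != 0:
            sys.exit(status)

if __name__ == "__main__":
    run_all_shards(" ".join(sys.argv[1:]))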
.buildkite/test-pipeline.yaml

@@ -218,9 +218,9 @@ steps:
   - pytest -v -s spec_decode
 
 - label: LoRA Test %N # 30min each
+  mirror_hardwares: [amd]
   source_file_dependencies:
   - vllm/lora
-  - csrc/punica
   - tests/lora
   command: pytest -v -s lora --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT --ignore=lora/test_long_context.py
   parallelism: 4
@@ -360,7 +360,6 @@ steps:
   num_gpus: 4
   source_file_dependencies:
   - vllm/lora
-  - csrc/punica
   - tests/lora/test_long_context
   commands:
   # FIXIT: find out which code initialize cuda before running the test
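With parallelism: 4, Buildkite starts four copies of this step; each copy receives its own $BUILDKITE_PARALLEL_JOB (0-3) and $BUILDKITE_PARALLEL_JOB_COUNT, which the --shard-id/--num-shards flags turn into a disjoint slice of the LoRA suite. A rough sketch of that partitioning, assuming simple round-robin selection (the actual pytest shard plugin may distribute tests differently):

from typing import List, Sequence

def select_shard(tests: Sequence[str], shard_id: int,
                 num_shards: int) -> List[str]:
    """Return the disjoint subset of tests this shard should run."""
    if not 0 <= shard_id < num_shards:
        raise ValueError("shard_id must be in [0, num_shards)")
    return [t for i, t in enumerate(tests) if i % num_shards == shard_id]

# The four shards together cover every test exactly once:
suite = [f"tests/lora/test_{i}.py" for i in range(10)]
covered = sorted(sum((select_shard(suite, s, 4) for s in range(4)), []))
assert covered == sorted(suite)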
tests/lora/test_gemma.py

@@ -1,7 +1,10 @@
 from typing import List
 
+import pytest
+
 import vllm
 from vllm.lora.request import LoRARequest
+from vllm.utils import is_hip
 
 MODEL_PATH = "google/gemma-7b"
 
@@ -28,6 +31,7 @@ def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> List[str]:
     return generated_texts
 
 
+@pytest.mark.xfail(is_hip(), reason="There can be output mismatch on ROCm")
 def test_gemma_lora(gemma_lora_files):
     llm = vllm.LLM(MODEL_PATH,
                    max_model_len=1024,
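The new marker keeps test_gemma_lora running on every platform but downgrades a ROCm failure to XFAIL instead of breaking the build; when the condition is false the marker is inert. A self-contained illustration of conditional xfail (platform_is_rocm is a hypothetical stand-in for vllm.utils.is_hip()):

import pytest

platform_is_rocm = False  # stand-in for is_hip()

@pytest.mark.xfail(platform_is_rocm,
                   reason="There can be output mismatch on ROCm")
def test_numerics():
    # Runs everywhere. On ROCm a failure would be reported as XFAIL
    # (and a pass as XPASS) rather than failing the build.
    assert 2 + 2 == 4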
tests/lora/test_quant_model.py

@@ -7,6 +7,7 @@ import pytest
 
 import vllm
 from vllm.lora.request import LoRARequest
+from vllm.utils import is_hip
 
 from .conftest import cleanup
 
@@ -17,10 +18,21 @@ class ModelWithQuantization:
     quantization: str
 
 
-MODELS: List[ModelWithQuantization] = [
-    ModelWithQuantization(model_path="TheBloke/TinyLlama-1.1B-Chat-v0.3-AWQ",
-                          quantization="AWQ"),
-    ModelWithQuantization(model_path="TheBloke/TinyLlama-1.1B-Chat-v0.3-GPTQ",
-                          quantization="GPTQ"),
-]
+MODELS: List[ModelWithQuantization]
+# AWQ quantization is currently not supported in ROCm.
+if is_hip():
+    MODELS = [
+        ModelWithQuantization(
+            model_path="TheBloke/TinyLlama-1.1B-Chat-v0.3-GPTQ",
+            quantization="GPTQ"),
+    ]
+else:
+    MODELS = [
+        ModelWithQuantization(
+            model_path="TheBloke/TinyLlama-1.1B-Chat-v0.3-AWQ",
+            quantization="AWQ"),
+        ModelWithQuantization(
+            model_path="TheBloke/TinyLlama-1.1B-Chat-v0.3-GPTQ",
+            quantization="GPTQ"),
+    ]
 
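Because MODELS feeds parametrization at import time, the platform check has to happen at module level, hence the is_hip() branch around the list. An alternative way to express the same gating, sketched here with pytest.param skip marks (illustration only, not what the PR does):

import pytest

def is_hip() -> bool:
    return False  # hypothetical stand-in for vllm.utils.is_hip()

MODELS = [
    pytest.param("TheBloke/TinyLlama-1.1B-Chat-v0.3-AWQ", "AWQ",
                 marks=pytest.mark.skipif(
                     is_hip(),
                     reason="AWQ quantization is currently not supported in ROCm")),
    pytest.param("TheBloke/TinyLlama-1.1B-Chat-v0.3-GPTQ", "GPTQ"),
]

@pytest.mark.parametrize("model_path,quantization", MODELS)
def test_quantized_model_loads(model_path: str, quantization: str):
    # Placeholder body; the real tests exercise LoRA on the quantized model.
    assert quantization in ("AWQ", "GPTQ")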