[CI/Build] Add shell script linting using shellcheck (#7925)

Signed-off-by: Russell Bryant <rbryant@redhat.com>
Authored by Russell Bryant on 2024-11-07 13:17:29 -05:00; committed by GitHub
parent de0e61a323
commit 3be5b26a76
28 changed files with 204 additions and 129 deletions
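Most of the per-script changes below quote variable and command-substitution expansions that shellcheck flags (SC2086/SC2046: unquoted expansions undergo word splitting and globbing). A minimal sketch of the failure mode the quoting prevents, using a made-up value:

#!/bin/bash
# Hypothetical value containing a space, as a downloaded model path might.
MODEL="/models/llama 3"
printf '<%s>\n' $MODEL     # unquoted: word-splits into </models/llama> and <3> (what SC2086 warns about)
printf '<%s>\n' "$MODEL"   # quoted: stays one argument, </models/llama 3> (the pattern applied throughout this diff)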


@@ -41,6 +41,6 @@ while getopts "m:b:l:f:" OPT; do
 done
 lm_eval --model hf \
-    --model_args pretrained=$MODEL,parallelize=True \
-    --tasks gsm8k --num_fewshot $FEWSHOT --limit $LIMIT \
-    --batch_size $BATCH_SIZE
+    --model_args "pretrained=$MODEL,parallelize=True" \
+    --tasks gsm8k --num_fewshot "$FEWSHOT" --limit "$LIMIT" \
+    --batch_size "$BATCH_SIZE"


@@ -46,6 +46,6 @@ while getopts "m:b:l:f:t:" OPT; do
 done
 lm_eval --model vllm \
-    --model_args pretrained=$MODEL,tensor_parallel_size=$TP_SIZE,distributed_executor_backend="ray",trust_remote_code=true,max_model_len=4096 \
-    --tasks gsm8k --num_fewshot $FEWSHOT --limit $LIMIT \
-    --batch_size $BATCH_SIZE
+    --model_args "pretrained=$MODEL,tensor_parallel_size=$TP_SIZE,distributed_executor_backend=ray,trust_remote_code=true,max_model_len=4096" \
+    --tasks gsm8k --num_fewshot "$FEWSHOT" --limit "$LIMIT" \
+    --batch_size "$BATCH_SIZE"


@@ -30,7 +30,7 @@ while getopts "c:t:" OPT; do
 done
 # Parse list of configs.
-IFS=$'\n' read -d '' -r -a MODEL_CONFIGS < $CONFIG
+IFS=$'\n' read -d '' -r -a MODEL_CONFIGS < "$CONFIG"
 for MODEL_CONFIG in "${MODEL_CONFIGS[@]}"
 do


@@ -50,31 +50,30 @@ launch_trt_server() {
   git clone https://github.com/triton-inference-server/tensorrtllm_backend.git
   git lfs install
   cd tensorrtllm_backend
-  git checkout $trt_llm_version
-  tensorrtllm_backend_dir=$(pwd)
+  git checkout "$trt_llm_version"
   git submodule update --init --recursive
   # build trtllm engine
   cd /tensorrtllm_backend
-  cd ./tensorrt_llm/examples/${model_type}
+  cd "./tensorrt_llm/examples/${model_type}"
   python3 convert_checkpoint.py \
-    --model_dir ${model_path} \
-    --dtype ${model_dtype} \
-    --tp_size ${model_tp_size} \
-    --output_dir ${trt_model_path}
+    --model_dir "${model_path}" \
+    --dtype "${model_dtype}" \
+    --tp_size "${model_tp_size}" \
+    --output_dir "${trt_model_path}"
   trtllm-build \
-    --checkpoint_dir ${trt_model_path} \
+    --checkpoint_dir "${trt_model_path}" \
     --use_fused_mlp \
     --reduce_fusion disable \
     --workers 8 \
-    --gpt_attention_plugin ${model_dtype} \
-    --gemm_plugin ${model_dtype} \
-    --tp_size ${model_tp_size} \
-    --max_batch_size ${max_batch_size} \
-    --max_input_len ${max_input_len} \
-    --max_seq_len ${max_seq_len} \
-    --max_num_tokens ${max_num_tokens} \
-    --output_dir ${trt_engine_path}
+    --gpt_attention_plugin "${model_dtype}" \
+    --gemm_plugin "${model_dtype}" \
+    --tp_size "${model_tp_size}" \
+    --max_batch_size "${max_batch_size}" \
+    --max_input_len "${max_input_len}" \
+    --max_seq_len "${max_seq_len}" \
+    --max_num_tokens "${max_num_tokens}" \
+    --output_dir "${trt_engine_path}"
   # handle triton protobuf files and launch triton server
   cd /tensorrtllm_backend
@@ -82,15 +81,15 @@ launch_trt_server() {
   cp -r all_models/inflight_batcher_llm/* triton_model_repo/
   cd triton_model_repo
   rm -rf ./tensorrt_llm/1/*
-  cp -r ${trt_engine_path}/* ./tensorrt_llm/1
+  cp -r "${trt_engine_path}"/* ./tensorrt_llm/1
   python3 ../tools/fill_template.py -i tensorrt_llm/config.pbtxt triton_backend:tensorrtllm,engine_dir:/tensorrtllm_backend/triton_model_repo/tensorrt_llm/1,decoupled_mode:true,batching_strategy:inflight_fused_batching,batch_scheduler_policy:guaranteed_no_evict,exclude_input_in_output:true,triton_max_batch_size:2048,max_queue_delay_microseconds:0,max_beam_width:1,max_queue_size:2048,enable_kv_cache_reuse:false
-  python3 ../tools/fill_template.py -i preprocessing/config.pbtxt triton_max_batch_size:2048,tokenizer_dir:$model_path,preprocessing_instance_count:5
-  python3 ../tools/fill_template.py -i postprocessing/config.pbtxt triton_max_batch_size:2048,tokenizer_dir:$model_path,postprocessing_instance_count:5,skip_special_tokens:false
-  python3 ../tools/fill_template.py -i ensemble/config.pbtxt triton_max_batch_size:$max_batch_size
-  python3 ../tools/fill_template.py -i tensorrt_llm_bls/config.pbtxt triton_max_batch_size:$max_batch_size,decoupled_mode:true,accumulate_tokens:"False",bls_instance_count:1
+  python3 ../tools/fill_template.py -i preprocessing/config.pbtxt "triton_max_batch_size:2048,tokenizer_dir:$model_path,preprocessing_instance_count:5"
+  python3 ../tools/fill_template.py -i postprocessing/config.pbtxt "triton_max_batch_size:2048,tokenizer_dir:$model_path,postprocessing_instance_count:5,skip_special_tokens:false"
+  python3 ../tools/fill_template.py -i ensemble/config.pbtxt triton_max_batch_size:"$max_batch_size"
+  python3 ../tools/fill_template.py -i tensorrt_llm_bls/config.pbtxt "triton_max_batch_size:$max_batch_size,decoupled_mode:true,accumulate_tokens:False,bls_instance_count:1"
   cd /tensorrtllm_backend
   python3 scripts/launch_triton_server.py \
-    --world_size=${model_tp_size} \
+    --world_size="${model_tp_size}" \
     --model_repo=/tensorrtllm_backend/triton_model_repo &
 }
@@ -98,10 +97,7 @@ launch_trt_server() {
 launch_tgi_server() {
   model=$(echo "$common_params" | jq -r '.model')
   tp=$(echo "$common_params" | jq -r '.tp')
-  dataset_name=$(echo "$common_params" | jq -r '.dataset_name')
-  dataset_path=$(echo "$common_params" | jq -r '.dataset_path')
   port=$(echo "$common_params" | jq -r '.port')
-  num_prompts=$(echo "$common_params" | jq -r '.num_prompts')
   server_args=$(json2args "$server_params")
   if echo "$common_params" | jq -e 'has("fp8")' >/dev/null; then
@@ -129,10 +125,7 @@ launch_tgi_server() {
 launch_lmdeploy_server() {
   model=$(echo "$common_params" | jq -r '.model')
   tp=$(echo "$common_params" | jq -r '.tp')
-  dataset_name=$(echo "$common_params" | jq -r '.dataset_name')
-  dataset_path=$(echo "$common_params" | jq -r '.dataset_path')
   port=$(echo "$common_params" | jq -r '.port')
-  num_prompts=$(echo "$common_params" | jq -r '.num_prompts')
   server_args=$(json2args "$server_params")
   server_command="lmdeploy serve api_server $model \
@@ -149,10 +142,7 @@ launch_sglang_server() {
   model=$(echo "$common_params" | jq -r '.model')
   tp=$(echo "$common_params" | jq -r '.tp')
-  dataset_name=$(echo "$common_params" | jq -r '.dataset_name')
-  dataset_path=$(echo "$common_params" | jq -r '.dataset_path')
   port=$(echo "$common_params" | jq -r '.port')
-  num_prompts=$(echo "$common_params" | jq -r '.num_prompts')
   server_args=$(json2args "$server_params")
   if echo "$common_params" | jq -e 'has("fp8")' >/dev/null; then
@@ -185,10 +175,7 @@ launch_vllm_server() {
   model=$(echo "$common_params" | jq -r '.model')
   tp=$(echo "$common_params" | jq -r '.tp')
-  dataset_name=$(echo "$common_params" | jq -r '.dataset_name')
-  dataset_path=$(echo "$common_params" | jq -r '.dataset_path')
   port=$(echo "$common_params" | jq -r '.port')
-  num_prompts=$(echo "$common_params" | jq -r '.num_prompts')
   server_args=$(json2args "$server_params")
   if echo "$common_params" | jq -e 'has("fp8")' >/dev/null; then
@@ -217,19 +204,19 @@ launch_vllm_server() {
 main() {
-  if [[ $CURRENT_LLM_SERVING_ENGINE == "trt" ]]; then
+  if [[ "$CURRENT_LLM_SERVING_ENGINE" == "trt" ]]; then
     launch_trt_server
   fi
-  if [[ $CURRENT_LLM_SERVING_ENGINE == "tgi" ]]; then
+  if [[ "$CURRENT_LLM_SERVING_ENGINE" == "tgi" ]]; then
     launch_tgi_server
   fi
-  if [[ $CURRENT_LLM_SERVING_ENGINE == "lmdeploy" ]]; then
+  if [[ "$CURRENT_LLM_SERVING_ENGINE" == "lmdeploy" ]]; then
     launch_lmdeploy_server
   fi
-  if [[ $CURRENT_LLM_SERVING_ENGINE == "sglang" ]]; then
+  if [[ "$CURRENT_LLM_SERVING_ENGINE" == "sglang" ]]; then
     launch_sglang_server
   fi


@@ -16,10 +16,10 @@ main() {
   fi
   # initial annotation
-  description="$VLLM_SOURCE_CODE_LOC/.buildkite/nightly-benchmarks/nightly-descriptions.md"
+  #description="$VLLM_SOURCE_CODE_LOC/.buildkite/nightly-benchmarks/nightly-descriptions.md"
   # download results
-  cd $VLLM_SOURCE_CODE_LOC/benchmarks
+  cd "$VLLM_SOURCE_CODE_LOC/benchmarks"
   mkdir -p results/
   /workspace/buildkite-agent artifact download 'results/*nightly_results.json' results/
   ls
@@ -30,15 +30,15 @@ main() {
   /workspace/buildkite-agent artifact upload "results.zip"
   # upload benchmarking scripts
-  cd $VLLM_SOURCE_CODE_LOC/
+  cd "$VLLM_SOURCE_CODE_LOC/"
   zip -r nightly-benchmarks.zip .buildkite/ benchmarks/
   /workspace/buildkite-agent artifact upload "nightly-benchmarks.zip"
-  cd $VLLM_SOURCE_CODE_LOC/.buildkite/nightly-benchmarks/
+  cd "$VLLM_SOURCE_CODE_LOC/.buildkite/nightly-benchmarks/"
   # upload benchmarking pipeline
   /workspace/buildkite-agent artifact upload "nightly-pipeline.yaml"
-  cd $VLLM_SOURCE_CODE_LOC/.buildkite/nightly-benchmarks/
+  cd "$VLLM_SOURCE_CODE_LOC/.buildkite/nightly-benchmarks/"
   /workspace/buildkite-agent annotate --style "success" --context "nightly-benchmarks-results" --append < nightly-annotation.md


@@ -12,7 +12,7 @@ check_gpus() {
     echo "Need at least 1 GPU to run benchmarking."
     exit 1
   fi
-  declare -g gpu_type=$(echo $(nvidia-smi --query-gpu=name --format=csv,noheader) | awk '{print $2}')
+  declare -g gpu_type="$(nvidia-smi --query-gpu=name --format=csv,noheader | awk '{print $2}')"
   echo "GPU type is $gpu_type"
 }
@@ -102,7 +102,7 @@ kill_gpu_processes() {
   pkill -f text-generation
   pkill -f lmdeploy
-  while [ $(nvidia-smi --query-gpu=memory.used --format=csv,noheader,nounits | head -n 1) -ge 1000 ]; do
+  while [ "$(nvidia-smi --query-gpu=memory.used --format=csv,noheader,nounits | head -n 1)" -ge 1000 ]; do
     sleep 1
   done
 }
@@ -119,8 +119,8 @@ wait_for_server() {
 ensure_installed() {
   # Ensure that the given command is installed by apt-get
   local cmd=$1
-  if ! which $cmd >/dev/null; then
-    apt-get update && apt-get install -y $cmd
+  if ! which "$cmd" >/dev/null; then
+    apt-get update && apt-get install -y "$cmd"
   fi
 }
@@ -173,13 +173,11 @@ run_serving_tests() {
       echo "Reuse previous server for test case $test_name"
     else
       kill_gpu_processes
-      bash $VLLM_SOURCE_CODE_LOC/.buildkite/nightly-benchmarks/scripts/launch-server.sh \
+      bash "$VLLM_SOURCE_CODE_LOC/.buildkite/nightly-benchmarks/scripts/launch-server.sh" \
         "$server_params" "$common_params"
     fi
-    wait_for_server
-    if [ $? -eq 0 ]; then
+    if wait_for_server; then
       echo ""
       echo "$CURRENT_LLM_SERVING_ENGINE server is up and running."
     else
@@ -190,13 +188,13 @@ run_serving_tests() {
     # prepare tokenizer
     # this is required for lmdeploy.
-    cd $VLLM_SOURCE_CODE_LOC/benchmarks
+    cd "$VLLM_SOURCE_CODE_LOC/benchmarks"
     rm -rf /tokenizer_cache
     mkdir /tokenizer_cache
     python3 ../.buildkite/nightly-benchmarks/scripts/download-tokenizer.py \
       --model "$model" \
      --cachedir /tokenizer_cache
-    cd $VLLM_SOURCE_CODE_LOC/benchmarks
+    cd "$VLLM_SOURCE_CODE_LOC/benchmarks"
     # change model name for lmdeploy (it will not follow standard hf name)
@@ -307,11 +305,11 @@ run_serving_tests() {
 prepare_dataset() {
   # download sharegpt dataset
-  cd $VLLM_SOURCE_CODE_LOC/benchmarks
+  cd "$VLLM_SOURCE_CODE_LOC/benchmarks"
   wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json
   # duplicate sonnet by 4x, to allow benchmarking with input length 2048
-  cd $VLLM_SOURCE_CODE_LOC/benchmarks
+  cd "$VLLM_SOURCE_CODE_LOC/benchmarks"
   echo "" > sonnet_4x.txt
   for _ in {1..4}
   do
@@ -339,17 +337,17 @@ main() {
   prepare_dataset
-  cd $VLLM_SOURCE_CODE_LOC/benchmarks
+  cd "$VLLM_SOURCE_CODE_LOC/benchmarks"
   declare -g RESULTS_FOLDER=results/
   mkdir -p $RESULTS_FOLDER
-  BENCHMARK_ROOT=$VLLM_SOURCE_CODE_LOC/.buildkite/nightly-benchmarks/
+  BENCHMARK_ROOT="$VLLM_SOURCE_CODE_LOC/.buildkite/nightly-benchmarks/"
   # run the test
-  run_serving_tests $BENCHMARK_ROOT/tests/nightly-tests.json
+  run_serving_tests "$BENCHMARK_ROOT/tests/nightly-tests.json"
   # upload benchmark results to buildkite
   python3 -m pip install tabulate pandas
-  python3 $BENCHMARK_ROOT/scripts/summary-nightly-results.py
+  python3 "$BENCHMARK_ROOT/scripts/summary-nightly-results.py"
   upload_to_buildkite
 }


@@ -17,7 +17,7 @@ check_gpus() {
     echo "Need at least 1 GPU to run benchmarking."
     exit 1
   fi
-  declare -g gpu_type=$(echo $(nvidia-smi --query-gpu=name --format=csv,noheader) | awk '{print $2}')
+  declare -g gpu_type=$(nvidia-smi --query-gpu=name --format=csv,noheader | awk '{print $2}')
   echo "GPU type is $gpu_type"
 }
@@ -93,7 +93,7 @@ kill_gpu_processes() {
   # wait until GPU memory usage smaller than 1GB
-  while [ $(nvidia-smi --query-gpu=memory.used --format=csv,noheader,nounits | head -n 1) -ge 1000 ]; do
+  while [ "$(nvidia-smi --query-gpu=memory.used --format=csv,noheader,nounits | head -n 1)" -ge 1000 ]; do
     sleep 1
   done
@@ -117,7 +117,7 @@ upload_to_buildkite() {
   fi
   # Use the determined command to annotate and upload artifacts
-  $BUILDKITE_AGENT_COMMAND annotate --style "info" --context "$BUILDKITE_LABEL-benchmark-results" <$RESULTS_FOLDER/benchmark_results.md
+  $BUILDKITE_AGENT_COMMAND annotate --style "info" --context "$BUILDKITE_LABEL-benchmark-results" < "$RESULTS_FOLDER/benchmark_results.md"
   $BUILDKITE_AGENT_COMMAND artifact upload "$RESULTS_FOLDER/*"
 }
@@ -150,7 +150,7 @@ run_latency_tests() {
     # check if there is enough GPU to run the test
     tp=$(echo "$latency_params" | jq -r '.tensor_parallel_size')
     if [[ $gpu_count -lt $tp ]]; then
-      echo "Required tensor-parallel-size $tp but only $gpu_count GPU found. Skip testcase $testname."
+      echo "Required tensor-parallel-size $tp but only $gpu_count GPU found. Skip testcase $test_name."
       continue
     fi
@@ -206,9 +206,9 @@ run_throughput_tests() {
     throughput_args=$(json2args "$throughput_params")
     # check if there is enough GPU to run the test
-    tp=$(echo $throughput_params | jq -r '.tensor_parallel_size')
+    tp=$(echo "$throughput_params" | jq -r '.tensor_parallel_size')
     if [[ $gpu_count -lt $tp ]]; then
-      echo "Required tensor-parallel-size $tp but only $gpu_count GPU found. Skip testcase $testname."
+      echo "Required tensor-parallel-size $tp but only $gpu_count GPU found. Skip testcase $test_name."
       continue
     fi
@@ -270,7 +270,7 @@ run_serving_tests() {
     # check if there is enough GPU to run the test
     tp=$(echo "$server_params" | jq -r '.tensor_parallel_size')
     if [[ $gpu_count -lt $tp ]]; then
-      echo "Required tensor-parallel-size $tp but only $gpu_count GPU found. Skip testcase $testname."
+      echo "Required tensor-parallel-size $tp but only $gpu_count GPU found. Skip testcase $test_name."
      continue
     fi
@@ -278,7 +278,7 @@ run_serving_tests() {
     server_model=$(echo "$server_params" | jq -r '.model')
     client_model=$(echo "$client_params" | jq -r '.model')
     if [[ $server_model != "$client_model" ]]; then
-      echo "Server model and client model must be the same. Skip testcase $testname."
+      echo "Server model and client model must be the same. Skip testcase $test_name."
      continue
     fi
@@ -293,8 +293,7 @@ run_serving_tests() {
     server_pid=$!
     # wait until the server is alive
-    wait_for_server
-    if [ $? -eq 0 ]; then
+    if wait_for_server; then
       echo ""
       echo "vllm server is up and running."
     else
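The hunks above that replace a separate "$?" test with "if wait_for_server; then" follow shellcheck's SC2181 advice: test a command's exit status directly, because any command run in between silently overwrites "$?". A small sketch of the same pattern with a hypothetical helper:

#!/bin/bash
# Hypothetical stand-in for wait_for_server: succeeds once a health endpoint answers.
server_ready() {
    curl -sf localhost:8000/health > /dev/null
}

# The exit status is checked directly in the if, instead of running the command
# first and inspecting $? afterwards, where it could already have been clobbered.
if server_ready; then
    echo "server is up"
else
    echo "server failed to start"
fi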


@@ -6,7 +6,7 @@ TIMEOUT_SECONDS=10
 retries=0
 while [ $retries -lt 1000 ]; do
-  if [ $(curl -s --max-time $TIMEOUT_SECONDS -L -H "Authorization: Bearer $TOKEN" -o /dev/null -w "%{http_code}" $URL) -eq 200 ]; then
+  if [ "$(curl -s --max-time "$TIMEOUT_SECONDS" -L -H "Authorization: Bearer $TOKEN" -o /dev/null -w "%{http_code}" "$URL")" -eq 200 ]; then
     exit 0
   fi


@@ -1,3 +1,5 @@
+#!/bin/bash
+
 # This script runs test inside the corresponding ROCm docker container.
 set -o pipefail
@@ -57,17 +59,17 @@ done
 echo "--- Pulling container"
 image_name="rocm/vllm-ci:${BUILDKITE_COMMIT}"
 container_name="rocm_${BUILDKITE_COMMIT}_$(tr -dc A-Za-z0-9 < /dev/urandom | head -c 10; echo)"
-docker pull ${image_name}
+docker pull "${image_name}"
 remove_docker_container() {
-  docker rm -f ${container_name} || docker image rm -f ${image_name} || true
+  docker rm -f "${container_name}" || docker image rm -f "${image_name}" || true
 }
 trap remove_docker_container EXIT
 echo "--- Running container"
 HF_CACHE="$(realpath ~)/huggingface"
-mkdir -p ${HF_CACHE}
+mkdir -p "${HF_CACHE}"
 HF_MOUNT="/root/.cache/huggingface"
 commands=$@
@@ -118,25 +120,25 @@ if [[ $commands == *"--shard-id="* ]]; then
       --network host \
       --shm-size=16gb \
       --rm \
-      -e HIP_VISIBLE_DEVICES=${GPU} \
+      -e HIP_VISIBLE_DEVICES="${GPU}" \
       -e HF_TOKEN \
-      -v ${HF_CACHE}:${HF_MOUNT} \
-      -e HF_HOME=${HF_MOUNT} \
-      --name ${container_name}_${GPU} \
-      ${image_name} \
+      -v "${HF_CACHE}:${HF_MOUNT}" \
+      -e "HF_HOME=${HF_MOUNT}" \
+      --name "${container_name}_${GPU}" \
+      "${image_name}" \
       /bin/bash -c "${commands_gpu}" \
       |& while read -r line; do echo ">>Shard $GPU: $line"; done &
     PIDS+=($!)
   done
   #wait for all processes to finish and collect exit codes
-  for pid in ${PIDS[@]}; do
-    wait ${pid}
+  for pid in "${PIDS[@]}"; do
+    wait "${pid}"
     STATUS+=($?)
   done
-  for st in ${STATUS[@]}; do
+  for st in "${STATUS[@]}"; do
     if [[ ${st} -ne 0 ]]; then
       echo "One of the processes failed with $st"
-      exit ${st}
+      exit "${st}"
     fi
   done
 else
@@ -147,9 +149,9 @@ else
     --rm \
     -e HIP_VISIBLE_DEVICES=0 \
     -e HF_TOKEN \
-    -v ${HF_CACHE}:${HF_MOUNT} \
-    -e HF_HOME=${HF_MOUNT} \
-    --name ${container_name} \
-    ${image_name} \
+    -v "${HF_CACHE}:${HF_MOUNT}" \
+    -e "HF_HOME=${HF_MOUNT}" \
+    --name "${container_name}" \
+    "${image_name}" \
     /bin/bash -c "${commands}"
 fi


@@ -1,3 +1,5 @@
+#!/bin/bash
+
 # This script is run by buildkite to run the benchmarks and upload the results to buildkite
 set -ex


@@ -1,3 +1,5 @@
+#!/bin/bash
+
 # This script build the CPU docker image and run the offline inference inside the container.
 # It serves a sanity check for compilation and basic model usage.
 set -ex
@@ -13,7 +15,7 @@ remove_docker_container
 # Run the image, setting --shm-size=4g for tensor parallel.
 source /etc/environment
 #docker run -itd --entrypoint /bin/bash -v ~/.cache/huggingface:/root/.cache/huggingface --privileged=true --network host -e HF_TOKEN --env VLLM_CPU_KVCACHE_SPACE=4 --shm-size=4g --name cpu-test cpu-test
-docker run -itd --entrypoint /bin/bash -v ~/.cache/huggingface:/root/.cache/huggingface --privileged=true --network host -e HF_TOKEN=$HF_TOKEN --name cpu-test cpu-test
+docker run -itd --entrypoint /bin/bash -v ~/.cache/huggingface:/root/.cache/huggingface --privileged=true --network host -e HF_TOKEN="$HF_TOKEN" --name cpu-test cpu-test
 # Run basic model test
 docker exec cpu-test bash -c "


@@ -1,3 +1,5 @@
+#!/bin/bash
+
 # This script build the CPU docker image and run the offline inference inside the container.
 # It serves a sanity check for compilation and basic model usage.
 set -ex


@@ -14,7 +14,7 @@ DOCKER_IMAGE=$4
 shift 4
 COMMANDS=("$@")
-if [ ${#COMMANDS[@]} -ne $NUM_NODES ]; then
+if [ ${#COMMANDS[@]} -ne "$NUM_NODES" ]; then
   echo "The number of commands must be equal to the number of nodes."
   echo "Number of nodes: $NUM_NODES"
   echo "Number of commands: ${#COMMANDS[@]}"
@@ -23,7 +23,7 @@ fi
 echo "List of commands"
 for command in "${COMMANDS[@]}"; do
-  echo $command
+  echo "$command"
 done
 start_network() {
@@ -36,7 +36,7 @@ start_nodes() {
     for node_gpu in $(seq 0 $(($NUM_GPUS - 1))); do
       DEVICE_NUM=$(($node * $NUM_GPUS + $node_gpu))
       GPU_DEVICES+=$(($DEVICE_NUM))
-      if [ $node_gpu -lt $(($NUM_GPUS - 1)) ]; then
+      if [ "$node_gpu" -lt $(($NUM_GPUS - 1)) ]; then
        GPU_DEVICES+=','
       fi
     done
@@ -49,17 +49,20 @@ start_nodes() {
     # 3. map the huggingface cache directory to the container
     # 3. assign ip addresses to the containers (head node: 192.168.10.10, worker nodes:
     # starting from 192.168.10.11)
-    docker run -d --gpus "$GPU_DEVICES" --shm-size=10.24gb -e HF_TOKEN -v ~/.cache/huggingface:/root/.cache/huggingface --name node$node --network docker-net --ip 192.168.10.$((10 + $node)) --rm $DOCKER_IMAGE /bin/bash -c "tail -f /dev/null"
+    docker run -d --gpus "$GPU_DEVICES" --shm-size=10.24gb -e HF_TOKEN \
+      -v ~/.cache/huggingface:/root/.cache/huggingface --name "node$node" \
+      --network docker-net --ip 192.168.10.$((10 + $node)) --rm "$DOCKER_IMAGE" \
+      /bin/bash -c "tail -f /dev/null"
     # organize containers into a ray cluster
-    if [ $node -eq 0 ]; then
+    if [ "$node" -eq 0 ]; then
       # start the ray head node
-      docker exec -d node$node /bin/bash -c "ray start --head --port=6379 --block"
+      docker exec -d "node$node" /bin/bash -c "ray start --head --port=6379 --block"
       # wait for the head node to be ready
       sleep 10
     else
       # start the ray worker nodes, and connect them to the head node
-      docker exec -d node$node /bin/bash -c "ray start --address=192.168.10.10:6379 --block"
+      docker exec -d "node$node" /bin/bash -c "ray start --address=192.168.10.10:6379 --block"
     fi
   done
@@ -79,22 +82,22 @@ run_nodes() {
     for node_gpu in $(seq 0 $(($NUM_GPUS - 1))); do
       DEVICE_NUM=$(($node * $NUM_GPUS + $node_gpu))
       GPU_DEVICES+=$(($DEVICE_NUM))
-      if [ $node_gpu -lt $(($NUM_GPUS - 1)) ]; then
+      if [ "$node_gpu" -lt $(($NUM_GPUS - 1)) ]; then
        GPU_DEVICES+=','
       fi
     done
     GPU_DEVICES+='"'
     echo "Running node$node with GPU devices: $GPU_DEVICES"
-    if [ $node -ne 0 ]; then
-      docker exec -d node$node /bin/bash -c "cd $WORKING_DIR ; ${COMMANDS[$node]}"
+    if [ "$node" -ne 0 ]; then
+      docker exec -d "node$node" /bin/bash -c "cd $WORKING_DIR ; ${COMMANDS[$node]}"
     else
-      docker exec node$node /bin/bash -c "cd $WORKING_DIR ; ${COMMANDS[$node]}"
+      docker exec "node$node" /bin/bash -c "cd $WORKING_DIR ; ${COMMANDS[$node]}"
     fi
   done
 }
 cleanup() {
   for node in $(seq 0 $(($NUM_NODES-1))); do
-    docker stop node$node
+    docker stop "node$node"
   done
   docker network rm docker-net
 }


@@ -1,3 +1,5 @@
+#!/bin/bash
+
 # This script build the Neuron docker image and run the API server inside the container.
 # It serves a sanity check for compilation and basic model usage.
 set -e
@@ -12,10 +14,10 @@ if [ -f /tmp/neuron-docker-build-timestamp ]; then
   current_time=$(date +%s)
   if [ $((current_time - last_build)) -gt 86400 ]; then
     docker system prune -f
-    echo $current_time > /tmp/neuron-docker-build-timestamp
+    echo "$current_time" > /tmp/neuron-docker-build-timestamp
   fi
 else
-  echo $(date +%s) > /tmp/neuron-docker-build-timestamp
+  date "+%s" > /tmp/neuron-docker-build-timestamp
 fi
 docker build -t neuron -f Dockerfile.neuron .
@@ -34,7 +36,7 @@ wait_for_server_to_start() {
   timeout=300
   counter=0
-  while [ "$(curl -s -o /dev/null -w ''%{http_code}'' localhost:8000/health)" != "200" ]; do
+  while [ "$(curl -s -o /dev/null -w '%{http_code}' localhost:8000/health)" != "200" ]; do
     sleep 1
     counter=$((counter + 1))
     if [ $counter -ge $timeout ]; then


@@ -1,3 +1,5 @@
+#!/bin/bash
+
 # This script build the OpenVINO docker image and run the offline inference inside the container.
 # It serves a sanity check for compilation and basic model usage.
 set -ex


@@ -1,3 +1,5 @@
+#!/bin/bash
+
 set -e
 # Build the docker image.
@@ -12,4 +14,4 @@ remove_docker_container
 # For HF_TOKEN.
 source /etc/environment
 # Run a simple end-to-end example.
-docker run --privileged --net host --shm-size=16G -it -e HF_TOKEN=$HF_TOKEN --name tpu-test vllm-tpu /bin/bash -c "python3 -m pip install git+https://github.com/thuml/depyf.git && python3 -m pip install pytest && python3 -m pip install lm_eval[api]==0.4.4 && pytest -v -s /workspace/vllm/tests/entrypoints/openai/test_accuracy.py && pytest -v -s /workspace/vllm/tests/tpu/test_custom_dispatcher.py && python3 /workspace/vllm/tests/tpu/test_compilation.py && python3 /workspace/vllm/examples/offline_inference_tpu.py"
+docker run --privileged --net host --shm-size=16G -it -e "HF_TOKEN=$HF_TOKEN" --name tpu-test vllm-tpu /bin/bash -c "python3 -m pip install git+https://github.com/thuml/depyf.git && python3 -m pip install pytest && python3 -m pip install lm_eval[api]==0.4.4 && pytest -v -s /workspace/vllm/tests/entrypoints/openai/test_accuracy.py && pytest -v -s /workspace/vllm/tests/tpu/test_custom_dispatcher.py && python3 /workspace/vllm/tests/tpu/test_compilation.py && python3 /workspace/vllm/examples/offline_inference_tpu.py"


@@ -1,3 +1,5 @@
+#!/bin/bash
+
 # This script build the CPU docker image and run the offline inference inside the container.
 # It serves a sanity check for compilation and basic model usage.
 set -ex


@@ -1,16 +1,16 @@
 #!/bin/bash
 # Replace '.' with '-' ex: 11.8 -> 11-8
-cuda_version=$(echo $1 | tr "." "-")
+cuda_version=$(echo "$1" | tr "." "-")
 # Removes '-' and '.' ex: ubuntu-20.04 -> ubuntu2004
-OS=$(echo $2 | tr -d ".\-")
+OS=$(echo "$2" | tr -d ".\-")
 # Installs CUDA
-wget -nv https://developer.download.nvidia.com/compute/cuda/repos/${OS}/x86_64/cuda-keyring_1.1-1_all.deb
+wget -nv "https://developer.download.nvidia.com/compute/cuda/repos/${OS}/x86_64/cuda-keyring_1.1-1_all.deb"
 sudo dpkg -i cuda-keyring_1.1-1_all.deb
 rm cuda-keyring_1.1-1_all.deb
 sudo apt -qq update
-sudo apt -y install cuda-${cuda_version} cuda-nvcc-${cuda_version} cuda-libraries-dev-${cuda_version}
+sudo apt -y install "cuda-${cuda_version}" "cuda-nvcc-${cuda_version}" "cuda-libraries-dev-${cuda_version}"
 sudo apt clean
 # Test nvcc


@@ -6,7 +6,7 @@ cuda_version=$3
 # Install torch
 $python_executable -m pip install numpy pyyaml scipy ipython mkl mkl-include ninja cython typing pandas typing-extensions dataclasses setuptools && conda clean -ya
-$python_executable -m pip install torch==${pytorch_version}+cu${cuda_version//./} --extra-index-url https://download.pytorch.org/whl/cu${cuda_version//./}
+$python_executable -m pip install torch=="${pytorch_version}+cu${cuda_version//./}" --extra-index-url "https://download.pytorch.org/whl/cu${cuda_version//./}"
 # Print version information
 $python_executable --version

.github/workflows/shellcheck.yml (new file)

@@ -0,0 +1,37 @@
name: Lint shell scripts
on:
  push:
    branches:
      - "main"
    paths:
      - '**/*.sh'
      - '.github/workflows/shellcheck.yml'
  pull_request:
    branches:
      - "main"
    paths:
      - '**/*.sh'
      - '.github/workflows/shellcheck.yml'

env:
  LC_ALL: en_US.UTF-8

defaults:
  run:
    shell: bash

permissions:
  contents: read

jobs:
  shellcheck:
    runs-on: ubuntu-latest
    steps:
      - name: "Checkout"
        uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # v4.1.7
        with:
          fetch-depth: 0

      - name: "Check shell scripts"
        run: |
          tools/shellcheck.sh

.gitignore

@@ -202,3 +202,4 @@ benchmarks/*.json
 # Linting
 actionlint
+shellcheck*/

.shellcheckrc (new file)

@@ -0,0 +1,9 @@
# rules currently disabled:
#
# SC1091 (info): Not following: <sourced file> was not specified as input (see shellcheck -x)
# SC2004 (style): $/${} is unnecessary on arithmetic variables.
# SC2129 (style): Consider using { cmd1; cmd2; } >> file instead of individual redirects.
# SC2155 (warning): Declare and assign separately to avoid masking return values.
# SC2164 (warning): Use 'cd ... || exit' or 'cd ... || return' in case cd fails.
#
disable=SC1091,SC2004,SC2129,SC2155,SC2164
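Besides this repository-wide configuration, shellcheck also honors per-line directives, which can be handier when only a single statement needs an exception rather than a global disable. A sketch reusing the gpu_type line from the benchmark scripts above:

#!/bin/bash
# The directive on the next line silences SC2155 for this one declaration only.
# shellcheck disable=SC2155
declare -g gpu_type="$(nvidia-smi --query-gpu=name --format=csv,noheader | awk '{print $2}')"
echo "GPU type is $gpu_type"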


@@ -4,13 +4,13 @@ PORT=8000
 MODEL=$1
 TOKENS=$2
-docker run -e HF_TOKEN=$HF_TOKEN --gpus all --shm-size 1g -p $PORT:80 \
-    -v $PWD/data:/data \
+docker run -e "HF_TOKEN=$HF_TOKEN" --gpus all --shm-size 1g -p $PORT:80 \
+    -v "$PWD/data:/data" \
     ghcr.io/huggingface/text-generation-inference:2.2.0 \
-    --model-id $MODEL \
+    --model-id "$MODEL" \
     --sharded false \
     --max-input-length 1024 \
     --max-total-tokens 2048 \
     --max-best-of 5 \
     --max-concurrent-requests 5000 \
-    --max-batch-total-tokens $TOKENS
+    --max-batch-total-tokens "$TOKENS"


@@ -14,7 +14,7 @@ PATH_TO_HF_HOME="$4"
 shift 4
 # Additional arguments are passed directly to the Docker command
-ADDITIONAL_ARGS="$@"
+ADDITIONAL_ARGS=("$@")
 # Validate node type
 if [ "${NODE_TYPE}" != "--head" ] && [ "${NODE_TYPE}" != "--worker" ]; then
@@ -45,5 +45,5 @@ docker run \
     --shm-size 10.24g \
     --gpus all \
     -v "${PATH_TO_HF_HOME}:/root/.cache/huggingface" \
-    ${ADDITIONAL_ARGS} \
+    "${ADDITIONAL_ARGS[@]}" \
     "${DOCKER_IMAGE}" -c "${RAY_START_CMD}"

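The ADDITIONAL_ARGS change above switches from a flat string to a bash array, which is what lets "${ADDITIONAL_ARGS[@]}" forward each extra docker argument as its own word even when an argument contains spaces. A small self-contained sketch of the difference, with made-up arguments:

#!/bin/bash
demo() {
    args_string="$*"      # flattened into one string: "-e FOO=a b"
    args_array=("$@")     # kept as separate words: "-e" and "FOO=a b"

    # Unquoted string expansion re-splits on whitespace (exactly what shellcheck warns about):
    printf 'string form: <%s>\n' $args_string
    # The quoted array expansion forwards each original argument intact:
    printf 'array form:  <%s>\n' "${args_array[@]}"
}

demo -e "FOO=a b"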

@@ -44,14 +44,14 @@ CLANGFORMAT_VERSION=$(clang-format --version | awk '{print $3}')
 # # params: tool name, tool version, required version
 tool_version_check() {
-  if [[ $2 != $3 ]]; then
+  if [[ "$2" != "$3" ]]; then
     echo "❓❓Wrong $1 version installed: $3 is required, not $2."
     exit 1
   fi
 }
-tool_version_check "yapf" $YAPF_VERSION "$(grep yapf requirements-lint.txt | cut -d'=' -f3)"
-tool_version_check "ruff" $RUFF_VERSION "$(grep "ruff==" requirements-lint.txt | cut -d'=' -f3)"
+tool_version_check "yapf" "$YAPF_VERSION" "$(grep yapf requirements-lint.txt | cut -d'=' -f3)"
+tool_version_check "ruff" "$RUFF_VERSION" "$(grep "ruff==" requirements-lint.txt | cut -d'=' -f3)"
 tool_version_check "mypy" "$MYPY_VERSION" "$(grep mypy requirements-lint.txt | cut -d'=' -f3)"
 tool_version_check "isort" "$ISORT_VERSION" "$(grep isort requirements-lint.txt | cut -d'=' -f3)"
 tool_version_check "codespell" "$CODESPELL_VERSION" "$(grep codespell requirements-lint.txt | cut -d'=' -f3)"
@@ -294,6 +294,10 @@ echo 'vLLM actionlint:'
 tools/actionlint.sh -color
 echo 'vLLM actionlint: Done'
+echo 'vLLM shellcheck:'
+tools/shellcheck.sh
+echo 'vLLM shellcheck: Done'
+
 if ! git diff --quiet &>/dev/null; then
   echo
   echo "🔍🔍There are files changed by the format checker or by you that are not added and committed:"


@@ -14,7 +14,7 @@ while getopts "c:" OPT; do
 done
-IFS=$'\n' read -d '' -r -a MODEL_CONFIGS < $CONFIG
+IFS=$'\n' read -d '' -r -a MODEL_CONFIGS < "$CONFIG"
 for MODEL_CONFIG in "${MODEL_CONFIGS[@]}"
 do


@@ -3,13 +3,13 @@
 CI=${1:-0}
 PYTHON_VERSION=${2:-3.9}
-if [ $CI -eq 1 ]; then
+if [ "$CI" -eq 1 ]; then
   set -e
 fi
 run_mypy() {
   echo "Running mypy on $1"
-  if [ $CI -eq 1 ] && [ -z "$1" ]; then
+  if [ "$CI" -eq 1 ] && [ -z "$1" ]; then
     mypy --python-version "${PYTHON_VERSION}" "$@"
     return
   fi

tools/shellcheck.sh (new executable file)

@@ -0,0 +1,21 @@
#!/bin/bash

scversion="stable"

if [ -d "shellcheck-${scversion}" ]; then
  export PATH="$PATH:$(pwd)/shellcheck-${scversion}"
fi

if ! [ -x "$(command -v shellcheck)" ]; then
  if [ "$(uname -s)" != "Linux" ] || [ "$(uname -m)" != "x86_64" ]; then
    echo "Please install shellcheck: https://github.com/koalaman/shellcheck?tab=readme-ov-file#installing"
    exit 1
  fi

  # automatic local install if linux x86_64
  wget -qO- "https://github.com/koalaman/shellcheck/releases/download/${scversion?}/shellcheck-${scversion?}.linux.x86_64.tar.xz" | tar -xJv
  export PATH="$PATH:$(pwd)/shellcheck-${scversion}"
fi

# TODO - fix warnings in .buildkite/run-amd-test.sh
find . -name "*.sh" -not -path "./.deps/*" -not -path "./.buildkite/run-amd-test.sh" -exec shellcheck {} +
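The new helper is wired into CI by the workflow above and is also called from the lint script patched earlier (right after actionlint). A sketch of running it by hand, assuming a Linux x86_64 machine so the automatic download path applies:

# From the repository root: reuses a shellcheck already on PATH, otherwise downloads
# the pinned "stable" release into a local shellcheck-stable/ directory (covered by
# the new .gitignore entry), then lints every *.sh file in the tree except the
# excluded AMD test script.
tools/shellcheck.sh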