#!/bin/bash # Currently FP8 benchmark is NOT enabled. set -x server_params=$1 common_params=$2 json2args() { # transforms the JSON string to command line args, and '_' is replaced to '-' # example: # input: { "model": "meta-llama/Llama-2-7b-chat-hf", "tensor_parallel_size": 1 } # output: --model meta-llama/Llama-2-7b-chat-hf --tensor-parallel-size 1 local json_string=$1 local args=$( echo "$json_string" | jq -r ' to_entries | map("--" + (.key | gsub("_"; "-")) + " " + (.value | tostring)) | join(" ") ' ) echo "$args" } launch_trt_server() { model_path=$(echo "$common_params" | jq -r '.model') model_name="${model_path#*/}" model_type=$(echo "$server_params" | jq -r '.model_type') model_dtype=$(echo "$server_params" | jq -r '.model_dtype') model_tp_size=$(echo "$common_params" | jq -r '.tp') max_batch_size=$(echo "$server_params" | jq -r '.max_batch_size') max_input_len=$(echo "$server_params" | jq -r '.max_input_len') max_seq_len=$(echo "$server_params" | jq -r '.max_seq_len') max_num_tokens=$(echo "$server_params" | jq -r '.max_num_tokens') trt_llm_version=$(echo "$server_params" | jq -r '.trt_llm_version') # create model caching directory cd ~ rm -rf models mkdir -p models cd models models_dir=$(pwd) trt_model_path=${models_dir}/${model_name}-trt-ckpt trt_engine_path=${models_dir}/${model_name}-trt-engine # clone tensorrt backend cd / rm -rf tensorrtllm_backend git clone https://github.com/triton-inference-server/tensorrtllm_backend.git git lfs install cd tensorrtllm_backend git checkout $trt_llm_version tensorrtllm_backend_dir=$(pwd) git submodule update --init --recursive # build trtllm engine cd /tensorrtllm_backend cd ./tensorrt_llm/examples/${model_type} python3 convert_checkpoint.py \ --model_dir ${model_path} \ --dtype ${model_dtype} \ --tp_size ${model_tp_size} \ --output_dir ${trt_model_path} trtllm-build \ --checkpoint_dir ${trt_model_path} \ --use_fused_mlp \ --reduce_fusion disable \ --workers 8 \ --gpt_attention_plugin ${model_dtype} \ --gemm_plugin ${model_dtype} \ --tp_size ${model_tp_size} \ --max_batch_size ${max_batch_size} \ --max_input_len ${max_input_len} \ --max_seq_len ${max_seq_len} \ --max_num_tokens ${max_num_tokens} \ --output_dir ${trt_engine_path} # handle triton protobuf files and launch triton server cd /tensorrtllm_backend mkdir triton_model_repo cp -r all_models/inflight_batcher_llm/* triton_model_repo/ cd triton_model_repo rm -rf ./tensorrt_llm/1/* cp -r ${trt_engine_path}/* ./tensorrt_llm/1 python3 ../tools/fill_template.py -i tensorrt_llm/config.pbtxt triton_backend:tensorrtllm,engine_dir:/tensorrtllm_backend/triton_model_repo/tensorrt_llm/1,decoupled_mode:true,batching_strategy:inflight_fused_batching,batch_scheduler_policy:guaranteed_no_evict,exclude_input_in_output:true,triton_max_batch_size:2048,max_queue_delay_microseconds:0,max_beam_width:1,max_queue_size:2048,enable_kv_cache_reuse:false python3 ../tools/fill_template.py -i preprocessing/config.pbtxt triton_max_batch_size:2048,tokenizer_dir:$model_path,preprocessing_instance_count:5 python3 ../tools/fill_template.py -i postprocessing/config.pbtxt triton_max_batch_size:2048,tokenizer_dir:$model_path,postprocessing_instance_count:5,skip_special_tokens:false python3 ../tools/fill_template.py -i ensemble/config.pbtxt triton_max_batch_size:$max_batch_size python3 ../tools/fill_template.py -i tensorrt_llm_bls/config.pbtxt triton_max_batch_size:$max_batch_size,decoupled_mode:true,accumulate_tokens:"False",bls_instance_count:1 cd /tensorrtllm_backend python3 scripts/launch_triton_server.py \ --world_size=${model_tp_size} \ --model_repo=/tensorrtllm_backend/triton_model_repo & } launch_tgi_server() { model=$(echo "$common_params" | jq -r '.model') tp=$(echo "$common_params" | jq -r '.tp') dataset_name=$(echo "$common_params" | jq -r '.dataset_name') dataset_path=$(echo "$common_params" | jq -r '.dataset_path') port=$(echo "$common_params" | jq -r '.port') num_prompts=$(echo "$common_params" | jq -r '.num_prompts') server_args=$(json2args "$server_params") if echo "$common_params" | jq -e 'has("fp8")' >/dev/null; then echo "Key 'fp8' exists in common params." server_command="/tgi-entrypoint.sh \ --model-id $model \ --num-shard $tp \ --port $port \ --quantize fp8 \ $server_args" else echo "Key 'fp8' does not exist in common params." server_command="/tgi-entrypoint.sh \ --model-id $model \ --num-shard $tp \ --port $port \ $server_args" fi echo "Server command: $server_command" eval "$server_command" & } launch_lmdeploy_server() { model=$(echo "$common_params" | jq -r '.model') tp=$(echo "$common_params" | jq -r '.tp') dataset_name=$(echo "$common_params" | jq -r '.dataset_name') dataset_path=$(echo "$common_params" | jq -r '.dataset_path') port=$(echo "$common_params" | jq -r '.port') num_prompts=$(echo "$common_params" | jq -r '.num_prompts') server_args=$(json2args "$server_params") server_command="lmdeploy serve api_server $model \ --tp $tp \ --server-port $port \ $server_args" # run the server echo "Server command: $server_command" bash -c "$server_command" & } launch_sglang_server() { model=$(echo "$common_params" | jq -r '.model') tp=$(echo "$common_params" | jq -r '.tp') dataset_name=$(echo "$common_params" | jq -r '.dataset_name') dataset_path=$(echo "$common_params" | jq -r '.dataset_path') port=$(echo "$common_params" | jq -r '.port') num_prompts=$(echo "$common_params" | jq -r '.num_prompts') server_args=$(json2args "$server_params") if echo "$common_params" | jq -e 'has("fp8")' >/dev/null; then echo "Key 'fp8' exists in common params. Use neuralmagic fp8 model for convenience." model=$(echo "$common_params" | jq -r '.neuralmagic_quantized_model') server_command="python3 \ -m sglang.launch_server \ --tp $tp \ --model-path $model \ --port $port \ $server_args" else echo "Key 'fp8' does not exist in common params." server_command="python3 \ -m sglang.launch_server \ --tp $tp \ --model-path $model \ --port $port \ $server_args" fi # run the server echo "Server command: $server_command" eval "$server_command" & } launch_vllm_server() { export VLLM_HOST_IP=$(hostname -I | awk '{print $1}') model=$(echo "$common_params" | jq -r '.model') tp=$(echo "$common_params" | jq -r '.tp') dataset_name=$(echo "$common_params" | jq -r '.dataset_name') dataset_path=$(echo "$common_params" | jq -r '.dataset_path') port=$(echo "$common_params" | jq -r '.port') num_prompts=$(echo "$common_params" | jq -r '.num_prompts') server_args=$(json2args "$server_params") if echo "$common_params" | jq -e 'has("fp8")' >/dev/null; then echo "Key 'fp8' exists in common params. Use neuralmagic fp8 model for convenience." model=$(echo "$common_params" | jq -r '.neuralmagic_quantized_model') server_command="python3 \ -m vllm.entrypoints.openai.api_server \ -tp $tp \ --model $model \ --port $port \ $server_args" else echo "Key 'fp8' does not exist in common params." server_command="python3 \ -m vllm.entrypoints.openai.api_server \ -tp $tp \ --model $model \ --port $port \ $server_args" fi # run the server echo "Server command: $server_command" eval "$server_command" & } main() { if [[ $CURRENT_LLM_SERVING_ENGINE == "trt" ]]; then launch_trt_server fi if [[ $CURRENT_LLM_SERVING_ENGINE == "tgi" ]]; then launch_tgi_server fi if [[ $CURRENT_LLM_SERVING_ENGINE == "lmdeploy" ]]; then launch_lmdeploy_server fi if [[ $CURRENT_LLM_SERVING_ENGINE == "sglang" ]]; then launch_sglang_server fi if [[ "$CURRENT_LLM_SERVING_ENGINE" == *"vllm"* ]]; then launch_vllm_server fi } main