# SPDX-License-Identifier: Apache-2.0
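"""LoRA correctness tests for THUDM/chatglm3-6b.

The tests below load a ChatGLM3 text-to-SQL LoRA adapter and check that
greedy decoding reproduces the expected SQL strings on a single GPU and
under 4-way tensor parallelism, with and without fully sharded LoRA layers.
"""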

from typing import List

import pytest

import vllm
from tests.utils import fork_new_process_for_each_test
from vllm.lora.request import LoRARequest

from ..utils import multi_gpu_test

MODEL_PATH = "THUDM/chatglm3-6b"
PROMPT_TEMPLATE = """ I want you to act as a SQL terminal in front of an example database, you need only to return the sql command to me.Below is an instruction that describes a task, Write a response that appropriately completes the request. \n " \n ##Instruction: \n concert_singer contains tables such as stadium, singer, concert, singer_in_concert. Table stadium has columns such as Stadium_ID, Location, Name, Capacity, Highest, Lowest, Average. Stadium_ID is the primary key. \n Table singer has columns such as Singer_ID, Name, Country, Song_Name, Song_release_year, Age, Is_male. Singer_ID is the primary key. \n Table concert has columns such as concert_ID, concert_Name, Theme, Stadium_ID, Year. concert_ID is the primary key. \n Table singer_in_concert has columns such as concert_ID, Singer_ID. concert_ID is the primary key. \n The Stadium_ID of concert is the foreign key of Stadium_ID of stadium. \n The Singer_ID of singer_in_concert is the foreign key of Singer_ID of singer. \n The concert_ID of singer_in_concert is the foreign key of concert_ID of concert. \n \n ###Input: \n {query} \n \n ###Response: """ # noqa: E501

EXPECTED_LORA_OUTPUT = [
    "SELECT count(*) FROM singer",
    "SELECT avg(age) , min(age) , max(age) FROM singer WHERE country = 'France'",  # noqa: E501
    "SELECT name , country , age FROM singer ORDER BY age",
]


def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> List[str]:
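    """Generate greedy completions for three text-to-SQL prompts.

    When ``lora_id`` is non-zero, the adapter at ``lora_path`` is attached via
    a ``LoRARequest``; otherwise the base model is used. Returns the stripped
    generated texts in prompt order.
    """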
    prompts = [
        PROMPT_TEMPLATE.format(query="How many singers do we have?"),
        PROMPT_TEMPLATE.format(
            query=
            "What is the average, minimum, and maximum age of all singers from France?"  # noqa: E501
        ),
        PROMPT_TEMPLATE.format(
            query=
            "Show name, country, age for all singers ordered by age from the oldest to the youngest."  # noqa: E501
        ),
    ]
    sampling_params = vllm.SamplingParams(temperature=0, max_tokens=32)
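    # Attach the LoRA adapter only for a non-zero lora_id; lora_id=0 runs the
    # base model without any LoRA weights applied.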
    outputs = llm.generate(
        prompts,
        sampling_params,
        lora_request=LoRARequest(str(lora_id), lora_id, lora_path)
        if lora_id else None)
    # Print the outputs.
    generated_texts: List[str] = []
    for output in outputs:
        prompt = output.prompt
        generated_text = output.outputs[0].text.strip()
        generated_texts.append(generated_text)
        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
    return generated_texts


@pytest.fixture(autouse=True)
def v1(run_with_both_engines_lora):
    # Simple autouse wrapper to run both engines for each test
    # This can be promoted up to conftest.py to run for every
    # test in a package
    pass


@pytest.mark.skip_v1
@fork_new_process_for_each_test
def test_chatglm3_lora(chatglm3_lora_files):
    llm = vllm.LLM(MODEL_PATH,
                   max_model_len=1024,
                   enable_lora=True,
                   max_loras=4,
                   max_lora_rank=64,
                   tensor_parallel_size=1,
                   trust_remote_code=True,
                   enable_chunked_prefill=True)
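
    # The same adapter is registered under two different LoRA IDs; both runs
    # must reproduce the expected SQL exactly under greedy decoding.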
    output1 = do_sample(llm, chatglm3_lora_files, lora_id=1)
    for i in range(len(EXPECTED_LORA_OUTPUT)):
        assert output1[i] == EXPECTED_LORA_OUTPUT[i]
    output2 = do_sample(llm, chatglm3_lora_files, lora_id=2)
    for i in range(len(EXPECTED_LORA_OUTPUT)):
        assert output2[i] == EXPECTED_LORA_OUTPUT[i]


@pytest.mark.skip_v1
@multi_gpu_test(num_gpus=4)
@fork_new_process_for_each_test
def test_chatglm3_lora_tp4(chatglm3_lora_files):
    llm = vllm.LLM(MODEL_PATH,
                   max_model_len=1024,
                   enable_lora=True,
                   max_loras=4,
                   max_lora_rank=64,
                   tensor_parallel_size=4,
                   trust_remote_code=True,
                   fully_sharded_loras=False,
                   enable_chunked_prefill=True)
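
    # fully_sharded_loras is left False here, so this run exercises tensor
    # parallelism without fully sharding the LoRA computation.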
    output1 = do_sample(llm, chatglm3_lora_files, lora_id=1)
    for i in range(len(EXPECTED_LORA_OUTPUT)):
        assert output1[i] == EXPECTED_LORA_OUTPUT[i]
    output2 = do_sample(llm, chatglm3_lora_files, lora_id=2)
    for i in range(len(EXPECTED_LORA_OUTPUT)):
        assert output2[i] == EXPECTED_LORA_OUTPUT[i]


@pytest.mark.skip_v1
@multi_gpu_test(num_gpus=4)
@fork_new_process_for_each_test
def test_chatglm3_lora_tp4_fully_sharded_loras(chatglm3_lora_files):
    llm = vllm.LLM(MODEL_PATH,
                   max_model_len=1024,
                   enable_lora=True,
                   max_loras=4,
                   max_lora_rank=64,
                   tensor_parallel_size=4,
                   trust_remote_code=True,
                   fully_sharded_loras=True,
                   enable_chunked_prefill=True)
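
    # fully_sharded_loras=True is expected to shard the LoRA computation across
    # the tensor-parallel ranks as well; outputs should still match the
    # single-GPU reference exactly.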
    output1 = do_sample(llm, chatglm3_lora_files, lora_id=1)
    for i in range(len(EXPECTED_LORA_OUTPUT)):
        assert output1[i] == EXPECTED_LORA_OUTPUT[i]
    output2 = do_sample(llm, chatglm3_lora_files, lora_id=2)
    for i in range(len(EXPECTED_LORA_OUTPUT)):
        assert output2[i] == EXPECTED_LORA_OUTPUT[i]