# SPDX-License-Identifier: Apache-2.0
from typing import List
import pytest
import vllm
from vllm.distributed import cleanup_dist_env_and_memory
from vllm.lora.request import LoRARequest

MODEL_PATH = "baichuan-inc/Baichuan-7B"
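# Text-to-SQL prompt: it spells out the concert_singer schema and asks the
# model to return only the SQL command for the question substituted into
# {query}.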
PROMPT_TEMPLATE = """I want you to act as a SQL terminal in front of an example database, you need only to return the sql command to me.Below is an instruction that describes a task, Write a response that appropriately completes the request.\n"\n##Instruction:\nconcert_singer contains tables such as stadium, singer, concert, singer_in_concert. Table stadium has columns such as Stadium_ID, Location, Name, Capacity, Highest, Lowest, Average. Stadium_ID is the primary key.\nTable singer has columns such as Singer_ID, Name, Country, Song_Name, Song_release_year, Age, Is_male. Singer_ID is the primary key.\nTable concert has columns such as concert_ID, concert_Name, Theme, Stadium_ID, Year. concert_ID is the primary key.\nTable singer_in_concert has columns such as concert_ID, Singer_ID. concert_ID is the primary key.\nThe Stadium_ID of concert is the foreign key of Stadium_ID of stadium.\nThe Singer_ID of singer_in_concert is the foreign key of Singer_ID of singer.\nThe concert_ID of singer_in_concert is the foreign key of concert_ID of concert.\n\n###Input:\n{query}\n\n###Response:"""  # noqa: E501


def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> List[str]:
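    # Greedily generate completions for three fixed text-to-SQL prompts,
    # optionally attaching the given LoRA adapter, and return the stripped
    # output texts.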
    prompts = [
        PROMPT_TEMPLATE.format(query="How many singers do we have?"),
        PROMPT_TEMPLATE.format(
            query=
            "What is the average, minimum, and maximum age of all singers from France?"  # noqa: E501
        ),
        PROMPT_TEMPLATE.format(
            query=
            "Show name, country, age for all singers ordered by age from the oldest to the youngest."  # noqa: E501
        ),
    ]
    print(prompts)
    sampling_params = vllm.SamplingParams(temperature=0, max_tokens=256)
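    # A falsy lora_id means "no adapter"; otherwise the adapter at lora_path
    # is attached to every request in this batch.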
    outputs = llm.generate(
        prompts,
        sampling_params,
        lora_request=LoRARequest(str(lora_id), lora_id, lora_path)
        if lora_id else None)
    # Print the outputs.
    generated_texts: List[str] = []
    for output in outputs:
        prompt = output.prompt
        generated_text = output.outputs[0].text.strip()
        generated_texts.append(generated_text)
        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
    return generated_texts


@pytest.fixture(autouse=True)
def v1(run_with_both_engines_lora):
    # Simple autouse wrapper to run both engines for each test.
    # This can be promoted up to conftest.py to run for every
    # test in a package.
    pass


def test_baichuan_lora(baichuan_lora_files):
    llm = vllm.LLM(MODEL_PATH,
                   max_model_len=1024,
                   enable_lora=True,
                   max_loras=4,
                   max_lora_rank=64,
                   trust_remote_code=True)

    expected_lora_output = [
        "SELECT count(*) FROM singer",
        "SELECT avg(age) , min(age) , max(age) FROM singer WHERE Country = 'France'",  # noqa: E501
        "SELECT name , country , age FROM singer ORDER BY age ASC",
    ]
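    # The same adapter files are registered under two different LoRA IDs;
    # with greedy decoding both runs should reproduce the expected SQL exactly.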
    output1 = do_sample(llm, baichuan_lora_files, lora_id=1)
    for i in range(len(expected_lora_output)):
        assert output1[i] == expected_lora_output[i]
    output2 = do_sample(llm, baichuan_lora_files, lora_id=2)
    for i in range(len(expected_lora_output)):
        assert output2[i] == expected_lora_output[i]


@pytest.mark.parametrize("fully_sharded", [True, False])
def test_baichuan_tensor_parallel_equality(baichuan_lora_files,
                                           num_gpus_available, fully_sharded):
    if num_gpus_available < 4:
        pytest.skip(f"Not enough GPUs for tensor parallelism {4}")
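    # The same LoRA adapter is run at TP=1, TP=2, and TP=4 (optionally with
    # fully sharded LoRA layers); the generated texts must be identical.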

    llm_tp1 = vllm.LLM(MODEL_PATH,
                       enable_lora=True,
                       max_num_seqs=16,
                       max_loras=4,
                       max_lora_rank=64,
                       tensor_parallel_size=1,
                       trust_remote_code=True,
                       fully_sharded_loras=fully_sharded)
    output_tp1 = do_sample(llm_tp1, baichuan_lora_files, lora_id=1)

    del llm_tp1
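    # Tear down the distributed environment and free GPU memory before
    # constructing the next engine with a larger tensor-parallel size.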
    cleanup_dist_env_and_memory()

    llm_tp2 = vllm.LLM(MODEL_PATH,
                       enable_lora=True,
                       max_num_seqs=16,
                       max_loras=4,
                       max_lora_rank=64,
                       tensor_parallel_size=2,
                       trust_remote_code=True,
                       fully_sharded_loras=fully_sharded)
    output_tp2 = do_sample(llm_tp2, baichuan_lora_files, lora_id=2)

    del llm_tp2
    cleanup_dist_env_and_memory()

    assert output_tp1 == output_tp2
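    # TP=4 with the same adapter should likewise match the TP=1 baseline.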

    llm_tp4 = vllm.LLM(MODEL_PATH,
                       enable_lora=True,
                       max_num_seqs=16,
                       max_loras=4,
                       max_lora_rank=64,
                       tensor_parallel_size=4,
                       trust_remote_code=True,
                       fully_sharded_loras=fully_sharded)
    output_tp4 = do_sample(llm_tp4, baichuan_lora_files, lora_id=2)

    del llm_tp4
    cleanup_dist_env_and_memory()

    assert output_tp1 == output_tp4