"""AQLM example: generate a short completion with an AQLM-quantized model via vLLM."""
import argparse
|
|
|
|
from vllm import LLM, SamplingParams
|
|
|
|
|
|
def main():
    """Generate a completion for a fixed prompt with an AQLM model via vLLM.

    Command-line options:
        --model / -m: model path (as for HF). When given, it takes precedence
            over ``--choice``.
        --choice / -c: index into the built-in list of known good models;
            only consulted when ``--model`` is not given.
        --tensor_parallel_size / -t: forwarded to ``LLM`` unchanged.

    The generated text for the prompt "Hello my name is" is printed to
    stdout. An out-of-range ``--choice`` terminates through ``parser.error``
    (usage message, exit status 2) instead of an uncaught ``IndexError``.
    """
    parser = argparse.ArgumentParser(description='AQLM examples')
    parser.add_argument('--model',
                        '-m',
                        type=str,
                        default=None,
                        help='model path, as for HF')
    parser.add_argument('--choice',
                        '-c',
                        type=int,
                        default=0,
                        help='known good models by index, [0-4]')
    parser.add_argument('--tensor_parallel_size',
                        '-t',
                        type=int,
                        default=1,
                        help='tensor parallel size')

    args = parser.parse_args()

    models = [
        "ISTA-DASLab/Llama-2-7b-AQLM-2Bit-1x16-hf",
        "ISTA-DASLab/Llama-2-7b-AQLM-2Bit-2x8-hf",
        "ISTA-DASLab/Llama-2-13b-AQLM-2Bit-1x16-hf",
        "ISTA-DASLab/Mixtral-8x7b-AQLM-2Bit-1x16-hf",
        "BlackSamorez/TinyLlama-1_1B-Chat-v1_0-AQLM-2Bit-1x16-hf",
    ]

    # Resolve the model name before building the engine so that a bad index
    # is reported as a CLI usage error rather than a traceback. The lookup
    # is only attempted when --model is absent, so an explicit model path
    # still works regardless of --choice.
    if args.model is not None:
        model_name = args.model
    else:
        try:
            model_name = models[args.choice]
        except IndexError:
            parser.error(
                f"--choice must be in [0-{len(models) - 1}], "
                f"got {args.choice}")

    model = LLM(model_name, tensor_parallel_size=args.tensor_parallel_size)

    sampling_params = SamplingParams(max_tokens=100, temperature=0)
    outputs = model.generate("Hello my name is",
                             sampling_params=sampling_params)
    # One prompt was submitted; print the text of its first completion.
    print(outputs[0].outputs[0].text)
|
|
|
|
|
|
# Script entry point: run the example only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()
|