# SPDX-License-Identifier: Apache-2.0 from vllm import LLM, SamplingParams from vllm.utils import FlexibleArgumentParser def main(): parser = FlexibleArgumentParser(description='AQLM examples') parser.add_argument('--model', '-m', type=str, default=None, help='model path, as for HF') parser.add_argument('--choice', '-c', type=int, default=0, help='known good models by index, [0-4]') parser.add_argument('--tensor-parallel-size', '-t', type=int, default=1, help='tensor parallel size') args = parser.parse_args() models = [ "ISTA-DASLab/Llama-2-7b-AQLM-2Bit-1x16-hf", "ISTA-DASLab/Llama-2-7b-AQLM-2Bit-2x8-hf", "ISTA-DASLab/Llama-2-13b-AQLM-2Bit-1x16-hf", "ISTA-DASLab/Mixtral-8x7b-AQLM-2Bit-1x16-hf", "BlackSamorez/TinyLlama-1_1B-Chat-v1_0-AQLM-2Bit-1x16-hf", ] model = LLM(args.model if args.model is not None else models[args.choice], tensor_parallel_size=args.tensor_parallel_size) sampling_params = SamplingParams(max_tokens=100, temperature=0) outputs = model.generate("Hello my name is", sampling_params=sampling_params) print(outputs[0].outputs[0].text) if __name__ == '__main__': main()