support trust_remote_code in benchmark (#518)

This commit is contained in:
WRH 2023-07-20 08:02:40 +08:00 committed by GitHub
parent 16c3e295a8
commit cf21a9bd5c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -67,12 +67,14 @@ def run_vllm(
seed: int,
n: int,
use_beam_search: bool,
trust_remote_code: bool,
) -> float:
llm = LLM(
model=model,
tokenizer=tokenizer,
tensor_parallel_size=tensor_parallel_size,
seed=seed,
trust_remote_code=trust_remote_code
)
# Add the requests to the engine.
@ -106,9 +108,10 @@ def run_hf(
n: int,
use_beam_search: bool,
max_batch_size: int,
trust_remote_code: bool,
) -> float:
assert not use_beam_search
llm = AutoModelForCausalLM.from_pretrained(model, torch_dtype=torch.float16)
llm = AutoModelForCausalLM.from_pretrained(model, torch_dtype=torch.float16, trust_remote_code=trust_remote_code)
if llm.config.model_type == "llama":
# To enable padding in the HF backend.
tokenizer.pad_token = tokenizer.eos_token
@ -161,13 +164,13 @@ def main(args: argparse.Namespace):
random.seed(args.seed)
# Sample the requests.
tokenizer = get_tokenizer(args.tokenizer)
tokenizer = get_tokenizer(args.tokenizer, trust_remote_code=args.trust_remote_code)
requests = sample_requests(args.dataset, args.num_prompts, tokenizer)
if args.backend == "vllm":
elapsed_time = run_vllm(
requests, args.model, args.tokenizer, args.tensor_parallel_size,
args.seed, args.n, args.use_beam_search)
args.seed, args.n, args.use_beam_search, args.trust_remote_code)
elif args.backend == "hf":
assert args.tensor_parallel_size == 1
elapsed_time = run_hf(requests, args.model, tokenizer, args.n,
@ -199,6 +202,9 @@ if __name__ == "__main__":
parser.add_argument("--seed", type=int, default=0)
parser.add_argument("--hf-max-batch-size", type=int, default=None,
help="Maximum batch size for HF backend.")
parser.add_argument('--trust-remote-code',
action='store_true',
help='trust remote code from huggingface')
args = parser.parse_args()
if args.backend == "vllm":