support trust_remote_code in benchmark (#518)
commit cf21a9bd5c (parent 16c3e295a8)
@@ -67,12 +67,14 @@ def run_vllm(
     seed: int,
     n: int,
     use_beam_search: bool,
+    trust_remote_code: bool,
 ) -> float:
     llm = LLM(
         model=model,
         tokenizer=tokenizer,
         tensor_parallel_size=tensor_parallel_size,
         seed=seed,
+        trust_remote_code=trust_remote_code
     )
 
     # Add the requests to the engine.
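Reviewer note (not part of the patch): on the vLLM path the new parameter is simply forwarded to the LLM constructor, which hands it on to the Hugging Face loaders. A minimal sketch of what the benchmark can now load once the flag is set (the model name is only an example of a repo that ships custom modeling code):

from vllm import LLM, SamplingParams

# Repos such as mosaicml/mpt-7b bundle their own modeling code; without
# trust_remote_code=True, transformers refuses to load them.
llm = LLM(model="mosaicml/mpt-7b", trust_remote_code=True)

outputs = llm.generate(["Hello, my name is"],
                       SamplingParams(temperature=0.8, max_tokens=32))
print(outputs[0].outputs[0].text)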
@@ -106,9 +108,10 @@ def run_hf(
     n: int,
     use_beam_search: bool,
     max_batch_size: int,
+    trust_remote_code: bool,
 ) -> float:
     assert not use_beam_search
-    llm = AutoModelForCausalLM.from_pretrained(model, torch_dtype=torch.float16)
+    llm = AutoModelForCausalLM.from_pretrained(model, torch_dtype=torch.float16, trust_remote_code=trust_remote_code)
     if llm.config.model_type == "llama":
         # To enable padding in the HF backend.
         tokenizer.pad_token = tokenizer.eos_token
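Note: on the HF backend the flag maps directly onto the standard transformers loading API. A self-contained equivalent of the new call, with an illustrative model name:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "mosaicml/mpt-7b"  # illustrative

# trust_remote_code=True lets transformers download and execute the modeling
# code stored in the Hub repo instead of requiring a class built into the
# library; leave it False for repos you do not trust.
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(model_name,
                                             torch_dtype=torch.float16,
                                             trust_remote_code=True)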
@@ -161,13 +164,13 @@ def main(args: argparse.Namespace):
     random.seed(args.seed)
 
     # Sample the requests.
-    tokenizer = get_tokenizer(args.tokenizer)
+    tokenizer = get_tokenizer(args.tokenizer, trust_remote_code=args.trust_remote_code)
     requests = sample_requests(args.dataset, args.num_prompts, tokenizer)
 
     if args.backend == "vllm":
         elapsed_time = run_vllm(
             requests, args.model, args.tokenizer, args.tensor_parallel_size,
-            args.seed, args.n, args.use_beam_search)
+            args.seed, args.n, args.use_beam_search, args.trust_remote_code)
     elif args.backend == "hf":
         assert args.tensor_parallel_size == 1
         elapsed_time = run_hf(requests, args.model, tokenizer, args.n,
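Note: this one-line change assumes get_tokenizer forwards extra keyword arguments to the underlying Hugging Face call. A rough, hypothetical sketch of the helper's shape under that assumption (the real implementation in vllm.transformers_utils.tokenizer handles more cases):

from transformers import AutoTokenizer

def get_tokenizer(tokenizer_name: str, **kwargs):
    # Hypothetical simplification: kwargs such as trust_remote_code are
    # passed straight through to Hugging Face, which is what makes the
    # change in main() sufficient.
    return AutoTokenizer.from_pretrained(tokenizer_name, **kwargs)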
@@ -199,6 +202,9 @@ if __name__ == "__main__":
     parser.add_argument("--seed", type=int, default=0)
     parser.add_argument("--hf-max-batch-size", type=int, default=None,
                         help="Maximum batch size for HF backend.")
+    parser.add_argument('--trust-remote-code',
+                        action='store_true',
+                        help='trust remote code from huggingface')
     args = parser.parse_args()
 
     if args.backend == "vllm":
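Note: with the flag wired through both backends, remote-code models can be benchmarked end to end, for example (model and dataset names illustrative):

python benchmark_throughput.py --backend vllm --model mosaicml/mpt-7b --dataset ShareGPT_V3_unfiltered_cleaned_split.json --trust-remote-code

Since action='store_true' defaults to False, omitting --trust-remote-code forwards False everywhere and the previous behavior is unchanged.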