[Misc]add coding benchmark for speculative decoding (#15303)
Signed-off-by: CXIAAAAA <cxia0209@gmail.com>
This commit is contained in:
parent
4ae17bf1e2
commit
e7f720ea56
@ -715,3 +715,66 @@ class VisionArenaDataset(HuggingFaceDataset):
|
||||
))
|
||||
self.maybe_oversample_requests(sampled_requests, num_requests)
|
||||
return sampled_requests
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Instruct Coder Dataset Implementation
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
||||
class InstructCoderDataset(HuggingFaceDataset):
    """
    InstructCoder Dataset.
    https://huggingface.co/datasets/likaixin/InstructCoder

    InstructCoder is the dataset designed for general code editing.
    It consists of 114,239 instruction-input-output triplets,
    and covers multiple distinct code editing scenarios.
    """

    DEFAULT_OUTPUT_LEN = 200  # this is the average default output length
    DEFAULT_NUM_REQUESTS = 1000
    INSTRUCT_CODER_DATASET_PATH = "likaixin/InstructCoder"

    def __init__(
        self,
        **kwargs,
    ) -> None:
        """Initialise the dataset and validate its path and split.

        All keyword arguments are forwarded unchanged to the parent
        constructor.

        Raises:
            ValueError: if ``dataset_path`` is not
                ``INSTRUCT_CODER_DATASET_PATH``, or if no subset is set and
                ``dataset_split`` is anything other than ``"train"``.
        """
        super().__init__(**kwargs)
        if self.dataset_path != self.INSTRUCT_CODER_DATASET_PATH:
            # Adjacent string literals are used instead of a backslash
            # continuation so that source indentation does not leak into
            # the error message.
            raise ValueError(
                f"Only support {self.INSTRUCT_CODER_DATASET_PATH} dataset. "
                f"This data path {self.dataset_path} is not valid.")
        if self.dataset_subset is None and self.dataset_split != "train":
            raise ValueError("Dataset split must be 'train'.")

    def load_data(self) -> None:
        """Open the dataset in streaming mode and store it shuffled.

        The shuffled stream (seeded with ``self.random_seed``) is stored in
        ``self.data``.
        """
        dataset = load_dataset(
            self.dataset_path,
            name=self.dataset_subset,
            split=self.dataset_split,
            streaming=True,
        )
        self.data = dataset.shuffle(seed=self.random_seed)

    def sample(self,
               tokenizer: PreTrainedTokenizerBase,
               num_requests: int,
               output_len: Optional[int] = None,
               enable_multimodal_chat: bool = False,
               **kwargs) -> list:
        """Build up to ``num_requests`` text-only sample requests.

        Args:
            tokenizer: used to count the number of tokens in each prompt.
            num_requests: number of requests to draw from ``self.data``.
            output_len: expected output length recorded on every request;
                falls back to ``DEFAULT_OUTPUT_LEN`` when ``None``.
            enable_multimodal_chat: accepted for signature compatibility;
                not used by this method.
            **kwargs: ignored.

        Returns:
            The list of ``SampleRequest`` objects, after being passed
            through ``maybe_oversample_requests``.
        """
        output_len = (output_len
                      if output_len is not None else self.DEFAULT_OUTPUT_LEN)
        sampled_requests = []
        for item in self.data:
            if len(sampled_requests) >= num_requests:
                break
            # Prompt layout: the instruction, a colon, then the input
            # on the following line.
            prompt = f"{item['instruction']}:\n{item['input']}"
            prompt_len = len(tokenizer(prompt).input_ids)
            sampled_requests.append(
                SampleRequest(
                    prompt=prompt,
                    prompt_len=prompt_len,
                    expected_output_len=output_len,
                ))
        self.maybe_oversample_requests(sampled_requests, num_requests)
        return sampled_requests
|
||||
|
@ -53,8 +53,9 @@ except ImportError:
|
||||
from argparse import ArgumentParser as FlexibleArgumentParser
|
||||
|
||||
from benchmark_dataset import (BurstGPTDataset, HuggingFaceDataset,
|
||||
RandomDataset, SampleRequest, ShareGPTDataset,
|
||||
SonnetDataset, VisionArenaDataset)
|
||||
InstructCoderDataset, RandomDataset,
|
||||
SampleRequest, ShareGPTDataset, SonnetDataset,
|
||||
VisionArenaDataset)
|
||||
from benchmark_utils import convert_to_pytorch_benchmark_format, write_to_json
|
||||
|
||||
MILLISECONDS_TO_SECONDS_CONVERSION = 1000
|
||||
@ -588,9 +589,14 @@ def main(args: argparse.Namespace):
|
||||
elif args.dataset_name == "hf":
|
||||
# Choose between VisionArenaDataset
|
||||
# and HuggingFaceDataset based on provided parameters.
|
||||
dataset_class = (VisionArenaDataset if args.dataset_path
|
||||
== VisionArenaDataset.VISION_ARENA_DATASET_PATH
|
||||
and args.hf_subset is None else HuggingFaceDataset)
|
||||
dataset_class = HuggingFaceDataset
|
||||
if args.dataset_path == VisionArenaDataset.VISION_ARENA_DATASET_PATH:
|
||||
assert args.hf_subset is None, "VisionArenaDataset needs hf_subset to be None." #noqa: E501
|
||||
dataset_class = VisionArenaDataset
|
||||
elif args.dataset_path == "likaixin/InstructCoder":
|
||||
dataset_class = InstructCoderDataset
|
||||
args.hf_split = "train"
|
||||
|
||||
input_requests = dataset_class(
|
||||
dataset_path=args.dataset_path,
|
||||
dataset_subset=args.hf_subset,
|
||||
|
@ -12,8 +12,9 @@ from typing import Any, Optional, Union
|
||||
import torch
|
||||
import uvloop
|
||||
from benchmark_dataset import (BurstGPTDataset, HuggingFaceDataset,
|
||||
RandomDataset, SampleRequest, ShareGPTDataset,
|
||||
SonnetDataset, VisionArenaDataset)
|
||||
InstructCoderDataset, RandomDataset,
|
||||
SampleRequest, ShareGPTDataset, SonnetDataset,
|
||||
VisionArenaDataset)
|
||||
from benchmark_utils import convert_to_pytorch_benchmark_format, write_to_json
|
||||
from tqdm import tqdm
|
||||
from transformers import (AutoModelForCausalLM, AutoTokenizer,
|
||||
@ -300,6 +301,7 @@ def get_requests(args, tokenizer):
|
||||
"input_len": args.input_len,
|
||||
"output_len": args.output_len,
|
||||
}
|
||||
|
||||
if args.dataset_path is None or args.dataset_name == "random":
|
||||
sample_kwargs["range_ratio"] = args.random_range_ratio
|
||||
sample_kwargs["prefix_len"] = args.prefix_len
|
||||
@ -317,17 +319,21 @@ def get_requests(args, tokenizer):
|
||||
elif args.dataset_name == "burstgpt":
|
||||
dataset_cls = BurstGPTDataset
|
||||
elif args.dataset_name == "hf":
|
||||
if args.backend != "vllm-chat":
|
||||
raise ValueError(
|
||||
"hf datasets only are supported by vllm-chat backend")
|
||||
# Choose between VisionArenaDataset and HuggingFaceDataset based on
|
||||
# provided parameters.
|
||||
dataset_cls = (VisionArenaDataset if args.dataset_path
|
||||
== VisionArenaDataset.VISION_ARENA_DATASET_PATH
|
||||
and args.hf_subset is None else HuggingFaceDataset)
|
||||
common_kwargs['dataset_subset'] = args.hf_subset
|
||||
common_kwargs['dataset_split'] = args.hf_split
|
||||
sample_kwargs["enable_multimodal_chat"] = True
|
||||
if args.dataset_path == VisionArenaDataset.VISION_ARENA_DATASET_PATH:
|
||||
if args.args.backend == "vllm-chat":
|
||||
raise ValueError(
|
||||
"hf datasets only are supported by vllm-chat backend")
|
||||
# Choose between VisionArenaDataset and HuggingFaceDataset based on
|
||||
# provided parameters.
|
||||
dataset_cls = (VisionArenaDataset if args.dataset_path
|
||||
== VisionArenaDataset.VISION_ARENA_DATASET_PATH
|
||||
and args.hf_subset is None else HuggingFaceDataset)
|
||||
common_kwargs['dataset_subset'] = args.hf_subset
|
||||
common_kwargs['dataset_split'] = args.hf_split
|
||||
sample_kwargs["enable_multimodal_chat"] = True
|
||||
elif args.dataset_path == "likaixin/InstructCoder":
|
||||
dataset_cls = InstructCoderDataset
|
||||
common_kwargs['dataset_split'] = "train"
|
||||
|
||||
else:
|
||||
raise ValueError(f"Unknown dataset name: {args.dataset_name}")
|
||||
@ -462,9 +468,14 @@ def validate_args(args):
|
||||
warnings.warn("--hf-subset and --hf-split will be ignored \
|
||||
since --dataset-name is not 'hf'.",
|
||||
stacklevel=2)
|
||||
elif args.dataset_name == "hf" and args.backend != "vllm-chat":
|
||||
raise ValueError(
|
||||
"When --dataset-name is 'hf', backend must be 'vllm-chat'")
|
||||
elif args.dataset_name == "hf":
|
||||
if args.dataset_path == VisionArenaDataset.VISION_ARENA_DATASET_PATH:
|
||||
assert args.backend == "vllm-chat", "VisionArenaDataset needs to use vllm-chat as the backend." #noqa: E501
|
||||
elif args.dataset_path == "likaixin/InstructCoder":
|
||||
assert args.backend == "vllm", "InstructCoder dataset needs to use vllm as the backend." #noqa: E501
|
||||
else:
|
||||
raise ValueError(
|
||||
f"{args.dataset_path} is not supported by hf dataset.")
|
||||
|
||||
# --random-range-ratio: only used when dataset_name is 'random'
|
||||
if args.dataset_name != 'random' and args.random_range_ratio is not None:
|
||||
|
Loading…
x
Reference in New Issue
Block a user