[Benchmark] Add AIMO Dataset to Benchmark (#15955)
Signed-off-by: Ziji Shi <shi.ziji.sm@gmail.com> Signed-off-by: StevenShi-23 <shi.ziji.sm@gmail.com>
This commit is contained in:
parent
57a810db9c
commit
06f21ce7a5
@ -752,3 +752,52 @@ class InstructCoderDataset(HuggingFaceDataset):
|
|||||||
))
|
))
|
||||||
self.maybe_oversample_requests(sampled_requests, num_requests)
|
self.maybe_oversample_requests(sampled_requests, num_requests)
|
||||||
return sampled_requests
|
return sampled_requests
|
||||||
|
|
||||||
|
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
# AIMO Dataset Implementation
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class AIMODataset(HuggingFaceDataset):
|
||||||
|
"""
|
||||||
|
Dataset class for processing a AIMO dataset with reasoning questions.
|
||||||
|
"""
|
||||||
|
SUPPORTED_DATASET_PATHS = {
|
||||||
|
"AI-MO/aimo-validation-aime", "AI-MO/NuminaMath-1.5",
|
||||||
|
"AI-MO/NuminaMath-CoT"
|
||||||
|
}
|
||||||
|
|
||||||
|
def sample(self,
|
||||||
|
tokenizer: PreTrainedTokenizerBase,
|
||||||
|
num_requests: int,
|
||||||
|
output_len: Optional[int] = None,
|
||||||
|
**kwargs) -> list:
|
||||||
|
sampled_requests = []
|
||||||
|
dynamic_output = output_len is None
|
||||||
|
|
||||||
|
for item in self.data:
|
||||||
|
if len(sampled_requests) >= num_requests:
|
||||||
|
break
|
||||||
|
prompt, completion = item['problem'], item["solution"]
|
||||||
|
|
||||||
|
prompt_ids = tokenizer(prompt).input_ids
|
||||||
|
completion_ids = tokenizer(completion).input_ids
|
||||||
|
prompt_len = len(prompt_ids)
|
||||||
|
completion_len = len(completion_ids)
|
||||||
|
output_len = completion_len if dynamic_output else output_len
|
||||||
|
assert isinstance(output_len, int) and output_len > 0
|
||||||
|
if dynamic_output and not is_valid_sequence(prompt_len,
|
||||||
|
completion_len,
|
||||||
|
max_prompt_len=2048,
|
||||||
|
max_total_len=32000):
|
||||||
|
continue
|
||||||
|
sampled_requests.append(
|
||||||
|
SampleRequest(
|
||||||
|
prompt=prompt,
|
||||||
|
prompt_len=prompt_len,
|
||||||
|
expected_output_len=output_len,
|
||||||
|
multi_modal_data=None,
|
||||||
|
))
|
||||||
|
self.maybe_oversample_requests(sampled_requests, num_requests)
|
||||||
|
return sampled_requests
|
||||||
|
@ -49,10 +49,11 @@ try:
|
|||||||
except ImportError:
|
except ImportError:
|
||||||
from argparse import ArgumentParser as FlexibleArgumentParser
|
from argparse import ArgumentParser as FlexibleArgumentParser
|
||||||
|
|
||||||
from benchmark_dataset import (BurstGPTDataset, ConversationDataset,
|
from benchmark_dataset import (AIMODataset, BurstGPTDataset,
|
||||||
HuggingFaceDataset, InstructCoderDataset,
|
ConversationDataset, HuggingFaceDataset,
|
||||||
RandomDataset, SampleRequest, ShareGPTDataset,
|
InstructCoderDataset, RandomDataset,
|
||||||
SonnetDataset, VisionArenaDataset)
|
SampleRequest, ShareGPTDataset, SonnetDataset,
|
||||||
|
VisionArenaDataset)
|
||||||
from benchmark_utils import convert_to_pytorch_benchmark_format, write_to_json
|
from benchmark_utils import convert_to_pytorch_benchmark_format, write_to_json
|
||||||
|
|
||||||
MILLISECONDS_TO_SECONDS_CONVERSION = 1000
|
MILLISECONDS_TO_SECONDS_CONVERSION = 1000
|
||||||
@ -595,6 +596,9 @@ def main(args: argparse.Namespace):
|
|||||||
args.hf_split = "train"
|
args.hf_split = "train"
|
||||||
elif args.dataset_path in ConversationDataset.SUPPORTED_DATASET_PATHS:
|
elif args.dataset_path in ConversationDataset.SUPPORTED_DATASET_PATHS:
|
||||||
dataset_class = ConversationDataset
|
dataset_class = ConversationDataset
|
||||||
|
elif args.dataset_path in AIMODataset.SUPPORTED_DATASET_PATHS:
|
||||||
|
dataset_class = AIMODataset
|
||||||
|
args.hf_split = "train"
|
||||||
else:
|
else:
|
||||||
supported_datasets = set([
|
supported_datasets = set([
|
||||||
dataset_name for cls in HuggingFaceDataset.__subclasses__()
|
dataset_name for cls in HuggingFaceDataset.__subclasses__()
|
||||||
@ -610,10 +614,10 @@ def main(args: argparse.Namespace):
|
|||||||
dataset_path=args.dataset_path,
|
dataset_path=args.dataset_path,
|
||||||
dataset_subset=args.hf_subset,
|
dataset_subset=args.hf_subset,
|
||||||
dataset_split=args.hf_split,
|
dataset_split=args.hf_split,
|
||||||
|
random_seed=args.seed,
|
||||||
).sample(
|
).sample(
|
||||||
num_requests=args.num_prompts,
|
num_requests=args.num_prompts,
|
||||||
tokenizer=tokenizer,
|
tokenizer=tokenizer,
|
||||||
random_seed=args.seed,
|
|
||||||
output_len=args.hf_output_len,
|
output_len=args.hf_output_len,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user