diff --git a/benchmarks/benchmark_dataset.py b/benchmarks/benchmark_dataset.py
index c2fbe2bb..1d61485e 100644
--- a/benchmarks/benchmark_dataset.py
+++ b/benchmarks/benchmark_dataset.py
@@ -752,3 +752,52 @@ class InstructCoderDataset(HuggingFaceDataset):
                 ))
         self.maybe_oversample_requests(sampled_requests, num_requests)
         return sampled_requests
+
+
+# -----------------------------------------------------------------------------
+# AIMO Dataset Implementation
+# -----------------------------------------------------------------------------
+
+
+class AIMODataset(HuggingFaceDataset):
+    """
+    Dataset class for processing an AIMO dataset with reasoning questions.
+    """
+    SUPPORTED_DATASET_PATHS = {
+        "AI-MO/aimo-validation-aime", "AI-MO/NuminaMath-1.5",
+        "AI-MO/NuminaMath-CoT"
+    }
+
+    def sample(self,
+               tokenizer: PreTrainedTokenizerBase,
+               num_requests: int,
+               output_len: Optional[int] = None,
+               **kwargs) -> list:
+        sampled_requests = []
+        dynamic_output = output_len is None
+
+        for item in self.data:
+            if len(sampled_requests) >= num_requests:
+                break
+            prompt, completion = item['problem'], item["solution"]
+
+            prompt_ids = tokenizer(prompt).input_ids
+            completion_ids = tokenizer(completion).input_ids
+            prompt_len = len(prompt_ids)
+            completion_len = len(completion_ids)
+            output_len = completion_len if dynamic_output else output_len
+            assert isinstance(output_len, int) and output_len > 0
+            if dynamic_output and not is_valid_sequence(prompt_len,
+                                                        completion_len,
+                                                        max_prompt_len=2048,
+                                                        max_total_len=32000):
+                continue
+            sampled_requests.append(
+                SampleRequest(
+                    prompt=prompt,
+                    prompt_len=prompt_len,
+                    expected_output_len=output_len,
+                    multi_modal_data=None,
+                ))
+        self.maybe_oversample_requests(sampled_requests, num_requests)
+        return sampled_requests
diff --git a/benchmarks/benchmark_serving.py b/benchmarks/benchmark_serving.py
index ec2ed1a1..59648222 100644
--- a/benchmarks/benchmark_serving.py
+++ b/benchmarks/benchmark_serving.py
@@ -49,10 +49,11 @@ try:
 except ImportError:
     from argparse import ArgumentParser as FlexibleArgumentParser
 
-from benchmark_dataset import (BurstGPTDataset, ConversationDataset,
-                               HuggingFaceDataset, InstructCoderDataset,
-                               RandomDataset, SampleRequest, ShareGPTDataset,
-                               SonnetDataset, VisionArenaDataset)
+from benchmark_dataset import (AIMODataset, BurstGPTDataset,
+                               ConversationDataset, HuggingFaceDataset,
+                               InstructCoderDataset, RandomDataset,
+                               SampleRequest, ShareGPTDataset, SonnetDataset,
+                               VisionArenaDataset)
 from benchmark_utils import convert_to_pytorch_benchmark_format, write_to_json
 
 MILLISECONDS_TO_SECONDS_CONVERSION = 1000
@@ -595,6 +596,9 @@ def main(args: argparse.Namespace):
             args.hf_split = "train"
         elif args.dataset_path in ConversationDataset.SUPPORTED_DATASET_PATHS:
             dataset_class = ConversationDataset
+        elif args.dataset_path in AIMODataset.SUPPORTED_DATASET_PATHS:
+            dataset_class = AIMODataset
+            args.hf_split = "train"
         else:
             supported_datasets = set([
                 dataset_name for cls in HuggingFaceDataset.__subclasses__()
@@ -610,10 +614,10 @@ def main(args: argparse.Namespace):
             dataset_path=args.dataset_path,
             dataset_subset=args.hf_subset,
             dataset_split=args.hf_split,
+            random_seed=args.seed,
         ).sample(
             num_requests=args.num_prompts,
             tokenizer=tokenizer,
-            random_seed=args.seed,
             output_len=args.hf_output_len,
         )
 
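
For reference, an invocation that exercises the new dataset class could look like the sketch below. This is an illustrative example, not part of the patch: the flag names are taken from the existing benchmark_serving.py CLI, and the model name is a placeholder.

    # sketch only: <your-served-model> is a placeholder; flags come from the existing CLI
    python benchmarks/benchmark_serving.py \
        --model <your-served-model> \
        --dataset-name hf \
        --dataset-path AI-MO/aimo-validation-aime \
        --num-prompts 100 \
        --seed 42

When --dataset-path matches an entry in AIMODataset.SUPPORTED_DATASET_PATHS, main() selects AIMODataset and forces the "train" split; the seed is now passed to the dataset constructor rather than to sample().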