[Bugfix] Add random_seed to sample_hf_requests in benchmark_serving script (#9013)
Co-authored-by: Isotr0py <2037008807@qq.com>
This commit is contained in:
parent
eca2c5f7c0
commit
d65049daab
@@ -202,6 +202,7 @@ def sample_hf_requests(
|
|||||||
dataset_split: str,
|
dataset_split: str,
|
||||||
num_requests: int,
|
num_requests: int,
|
||||||
tokenizer: PreTrainedTokenizerBase,
|
tokenizer: PreTrainedTokenizerBase,
|
||||||
|
random_seed: int,
|
||||||
fixed_output_len: Optional[int] = None,
|
fixed_output_len: Optional[int] = None,
|
||||||
) -> List[Tuple[str, str, int, Optional[Dict[str, Collection[str]]]]]:
|
) -> List[Tuple[str, str, int, Optional[Dict[str, Collection[str]]]]]:
|
||||||
dataset = load_dataset(dataset_path,
|
dataset = load_dataset(dataset_path,
|
||||||
@@ -210,8 +211,8 @@ def sample_hf_requests(
|
|||||||
streaming=True)
|
streaming=True)
|
||||||
assert "conversations" in dataset.features, (
|
assert "conversations" in dataset.features, (
|
||||||
"HF Dataset must have 'conversations' column.")
|
"HF Dataset must have 'conversations' column.")
|
||||||
filtered_dataset = dataset.shuffle().filter(
|
filter_func = lambda x: len(x["conversations"]) >= 2
|
||||||
lambda x: len(x["conversations"]) >= 2)
|
filtered_dataset = dataset.shuffle(seed=random_seed).filter(filter_func)
|
||||||
sampled_requests: List[Tuple[str, int, int, Dict[str,
|
sampled_requests: List[Tuple[str, int, int, Dict[str,
|
||||||
Collection[str]]]] = []
|
Collection[str]]]] = []
|
||||||
for data in filtered_dataset:
|
for data in filtered_dataset:
|
||||||
@@ -646,6 +647,7 @@ def main(args: argparse.Namespace):
|
|||||||
dataset_split=args.hf_split,
|
dataset_split=args.hf_split,
|
||||||
num_requests=args.num_prompts,
|
num_requests=args.num_prompts,
|
||||||
tokenizer=tokenizer,
|
tokenizer=tokenizer,
|
||||||
|
random_seed=args.seed,
|
||||||
fixed_output_len=args.hf_output_len,
|
fixed_output_len=args.hf_output_len,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user