[Benchmark] Refactor sample_requests in benchmark_throughput (#3613)
Co-authored-by: Roger Wang <ywang@roblox.com>
This commit is contained in:
parent
819a309c0f
commit
b7782002e1
@ -29,22 +29,23 @@ def sample_requests(
|
|||||||
dataset = [(data["conversations"][0]["value"],
|
dataset = [(data["conversations"][0]["value"],
|
||||||
data["conversations"][1]["value"]) for data in dataset]
|
data["conversations"][1]["value"]) for data in dataset]
|
||||||
|
|
||||||
# Tokenize the prompts and completions.
|
# Shuffle the dataset.
|
||||||
prompts = [prompt for prompt, _ in dataset]
|
random.shuffle(dataset)
|
||||||
prompt_token_ids = tokenizer(prompts).input_ids
|
|
||||||
completions = [completion for _, completion in dataset]
|
|
||||||
completion_token_ids = tokenizer(completions).input_ids
|
|
||||||
tokenized_dataset = []
|
|
||||||
for i in range(len(dataset)):
|
|
||||||
output_len = len(completion_token_ids[i])
|
|
||||||
if fixed_output_len is not None:
|
|
||||||
output_len = fixed_output_len
|
|
||||||
tokenized_dataset.append((prompts[i], prompt_token_ids[i], output_len))
|
|
||||||
|
|
||||||
# Filter out too long sequences.
|
# Filter out sequences that are too long or too short
|
||||||
filtered_dataset: List[Tuple[str, int, int]] = []
|
filtered_dataset: List[Tuple[str, int, int]] = []
|
||||||
for prompt, prompt_token_ids, output_len in tokenized_dataset:
|
for i in range(len(dataset)):
|
||||||
|
if len(filtered_dataset) == num_requests:
|
||||||
|
break
|
||||||
|
|
||||||
|
# Tokenize the prompts and completions.
|
||||||
|
prompt = dataset[i][0]
|
||||||
|
prompt_token_ids = tokenizer(prompt).input_ids
|
||||||
|
completion = dataset[i][1]
|
||||||
|
completion_token_ids = tokenizer(completion).input_ids
|
||||||
prompt_len = len(prompt_token_ids)
|
prompt_len = len(prompt_token_ids)
|
||||||
|
output_len = len(completion_token_ids
|
||||||
|
) if fixed_output_len is None else fixed_output_len
|
||||||
if prompt_len < 4 or output_len < 4:
|
if prompt_len < 4 or output_len < 4:
|
||||||
# Prune too short sequences.
|
# Prune too short sequences.
|
||||||
continue
|
continue
|
||||||
@ -53,9 +54,7 @@ def sample_requests(
|
|||||||
continue
|
continue
|
||||||
filtered_dataset.append((prompt, prompt_len, output_len))
|
filtered_dataset.append((prompt, prompt_len, output_len))
|
||||||
|
|
||||||
# Sample the requests.
|
return filtered_dataset
|
||||||
sampled_requests = random.sample(filtered_dataset, num_requests)
|
|
||||||
return sampled_requests
|
|
||||||
|
|
||||||
|
|
||||||
def run_vllm(
|
def run_vllm(
|
||||||
|
Loading…
x
Reference in New Issue
Block a user