[Core] Optimize sampler get_logprobs (#4594)

This commit is contained in:
SangBin Cho 2024-05-09 00:42:28 +09:00 committed by GitHub
parent cc466a3290
commit d7740ea4dc
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -782,13 +782,14 @@ def _get_logprobs(
top_logprobs, top_token_ids = torch.topk(logprobs, top_logprobs, top_token_ids = torch.topk(logprobs,
largest_num_logprobs, largest_num_logprobs,
dim=-1) dim=-1)
top_logprobs = top_logprobs.cpu()
top_token_ids = top_token_ids.cpu()
else: else:
top_logprobs, top_token_ids = None, None top_logprobs, top_token_ids = None, None
selected_logprobs = selected_logprobs.cpu() selected_logprobs = selected_logprobs.to('cpu')
ranks = ranks.cpu() ranks = ranks.to('cpu')
if top_logprobs is not None and top_token_ids is not None:
top_logprobs = top_logprobs.to('cpu')
top_token_ids = top_token_ids.to('cpu')
# Find prompt/sample logprobs. # Find prompt/sample logprobs.
prompt_logprobs_per_seq_group: List[Optional[PromptLogprobs]] = [] prompt_logprobs_per_seq_group: List[Optional[PromptLogprobs]] = []
@ -828,37 +829,48 @@ def _get_prompt_logprob_if_needed(
# Find prompt logprobs # Find prompt logprobs
prompt_logprobs: Optional[PromptLogprobs] = None prompt_logprobs: Optional[PromptLogprobs] = None
if (is_prompt and sampling_params.prompt_logprobs is not None): if is_prompt and sampling_params.prompt_logprobs is not None:
prompt_logprobs = [] prompt_logprobs = []
num_logprobs = sampling_params.prompt_logprobs num_logprobs = sampling_params.prompt_logprobs
next_prompt_tokens = _get_next_prompt_tokens(seq_group) next_prompt_tokens = _get_next_prompt_tokens(seq_group)
for token_id in next_prompt_tokens: # Pre-select indexes and create a list. It is faster than calling .item
# repetitively.
selected_logprob_items = selected_logprobs[
selected_logprobs_idx:selected_logprobs_idx +
len(next_prompt_tokens)].tolist()
rank_items = ranks[selected_logprobs_idx:selected_logprobs_idx +
len(next_prompt_tokens)].tolist()
for idx, token_id in enumerate(next_prompt_tokens):
# Calculate the prompt logprob of the real prompt tokens. # Calculate the prompt logprob of the real prompt tokens.
# Use tuple here for performance (to use to_list()).
# {token_id: (logprob, rank_from_vocab)} # {token_id: (logprob, rank_from_vocab)}
prompt_logprobs_dict: Dict[int, Tuple[float, int]] = { prompt_logprobs_dict: Dict[int, Tuple[float, int]] = {
token_id: (selected_logprobs[selected_logprobs_idx].item(), token_id: (selected_logprob_items[idx], rank_items[idx])
ranks[selected_logprobs_idx].item())
} }
# Add top K prompt logprobs along with its rank. # Add top K prompt logprobs along with its rank.
if num_logprobs > 0: if num_logprobs > 0:
prompt_logprobs_dict.update( top_ids = top_token_ids[
zip( top_logprob_idx, :num_logprobs].tolist()
top_token_ids[top_logprob_idx, :num_logprobs].tolist(), top_probs = top_logprobs[
zip( top_logprob_idx, :num_logprobs].tolist()
top_logprobs[ # Top K is already sorted by rank, so we can use 1 ~
top_logprob_idx, :num_logprobs].tolist(), # num_logprobs + 1 for rank.
# This is ranks. Since top_logprob is sorted, top_ranks = range(1, num_logprobs + 1)
# we can just use a range here. prompt_logprobs_dict.update({
range(1, num_logprobs + 1)))) top_id: (top_prob, rank)
for top_id, top_prob, rank in zip(top_ids, top_probs,
top_ranks)
})
prompt_logprobs.append({ prompt_logprobs.append({
token_id: Logprob(*logprob_and_rank) token_id: Logprob(*logprob_and_rank)
for token_id, logprob_and_rank in prompt_logprobs_dict.items() for token_id, logprob_and_rank in prompt_logprobs_dict.items()
}) })
# + 1 to go to the next prompt token. # + 1 to go to the next prompt token.
top_logprob_idx += 1 top_logprob_idx += 1
selected_logprobs_idx += 1
# + len(next_prompt_tokens) to go to the next prompt.
selected_logprobs_idx += len(next_prompt_tokens)
return prompt_logprobs, top_logprob_idx, selected_logprobs_idx return prompt_logprobs, top_logprob_idx, selected_logprobs_idx
@ -874,47 +886,54 @@ def _get_sampled_logprob_if_needed(
): ):
"""Compute the sample logprob if needed.""" """Compute the sample logprob if needed."""
seq_ids = seq_group.seq_ids seq_ids = seq_group.seq_ids
num_logprobs = seq_group.sampling_params.logprobs num_logprobs = seq_group.sampling_params.logprobs or 0
if num_logprobs is None:
num_logprobs = 0
sampled_logprobs: SampleLogprobs = [] sampled_logprobs: SampleLogprobs = []
next_token_ids, parent_seq_ids = sample_result next_token_ids, parent_seq_ids = sample_result
if seq_group.do_sample: if seq_group.do_sample:
assert len(next_token_ids) > 0 assert len(next_token_ids) > 0
for (next_token_id, parent_id) in zip(next_token_ids, parent_seq_ids): # Pre-select items from tensor. tolist() is faster than repetitive
# Calculate the sample logprob of the real sampled tokens. # `.item()` calls.
# Use tuple here for performance (to use to_list()). selected_logprob_items = selected_logprobs[
# token_id: (logprob, rank_from_vocab) selected_logprobs_idx:selected_logprobs_idx +
sampled_logprobs_dict: Dict[int, Tuple[float, int]] = { len(next_token_ids)].tolist()
next_token_id: rank_items = ranks[selected_logprobs_idx:selected_logprobs_idx +
(selected_logprobs[selected_logprobs_idx].item(), len(next_token_ids)].tolist()
ranks[selected_logprobs_idx].item()) for idx, (next_token_id,
parent_id) in enumerate(zip(next_token_ids, parent_seq_ids)):
# Get the logprob of a sampled token.
sampled_logprobs_dict = {
next_token_id: (selected_logprob_items[idx], rank_items[idx])
} }
# +1 to go to the next sampled token. Note that # Get top K logprobs.
# selected_logprobs can contain duplicates unlike top_logprobs if num_logprobs > 0:
# when beam search is enabled. top_ids = top_token_ids[top_logprob_idx +
selected_logprobs_idx += 1 parent_id, :num_logprobs].tolist()
top_probs = top_logprobs[top_logprob_idx +
parent_id, :num_logprobs].tolist()
# Top K is already sorted by rank, so we can use 1 ~
# num_logprobs + 1 for rank.
top_ranks = range(1, num_logprobs + 1)
sampled_logprobs_dict.update({
top_id: (top_prob, rank)
for top_id, top_prob, rank in zip(top_ids, top_probs,
top_ranks)
})
# Second, add top K logprobs along with its rank.
if num_logprobs >= 0:
sampled_logprobs_dict.update(
zip(
top_token_ids[top_logprob_idx +
parent_id, :num_logprobs].tolist(),
zip(
top_logprobs[top_logprob_idx +
parent_id, :num_logprobs].tolist(),
# This is rank. Since top_logprob is sorted, we
# can just use a range here.
range(1, num_logprobs + 1))))
sampled_logprobs.append({ sampled_logprobs.append({
token_id: Logprob(*logprob_and_rank) token_id: Logprob(*logprob_and_rank)
for token_id, logprob_and_rank in for token_id, logprob_and_rank in
sampled_logprobs_dict.items() sampled_logprobs_dict.items()
}) })
# There are len(seq_ids) number of sampled tokens for the current
# sequence group in top_logprobs. Jump to the next seq_group. # NOTE: This part of code is not intuitive. `selected_logprobs` include
# logprobs for the current step, which has len(next_token_ids) tokens
# per sequence group. `logprobs` includes logprobs from the previous
# steps, which has len(seq_ids) tokens per sequence group.
# Iterate to the next sequence group in a batch.
selected_logprobs_idx += len(next_token_ids)
# Iterate to the next sequence group in a batch.
top_logprob_idx += len(seq_ids) top_logprob_idx += len(seq_ids)
return sampled_logprobs, top_logprob_idx, selected_logprobs_idx return sampled_logprobs, top_logprob_idx, selected_logprobs_idx