[Core] Optimize sampler get_logprobs (#4594)
This commit is contained in:
parent
cc466a3290
commit
d7740ea4dc
@ -782,13 +782,14 @@ def _get_logprobs(
|
|||||||
top_logprobs, top_token_ids = torch.topk(logprobs,
|
top_logprobs, top_token_ids = torch.topk(logprobs,
|
||||||
largest_num_logprobs,
|
largest_num_logprobs,
|
||||||
dim=-1)
|
dim=-1)
|
||||||
top_logprobs = top_logprobs.cpu()
|
|
||||||
top_token_ids = top_token_ids.cpu()
|
|
||||||
else:
|
else:
|
||||||
top_logprobs, top_token_ids = None, None
|
top_logprobs, top_token_ids = None, None
|
||||||
|
|
||||||
selected_logprobs = selected_logprobs.cpu()
|
selected_logprobs = selected_logprobs.to('cpu')
|
||||||
ranks = ranks.cpu()
|
ranks = ranks.to('cpu')
|
||||||
|
if top_logprobs is not None and top_token_ids is not None:
|
||||||
|
top_logprobs = top_logprobs.to('cpu')
|
||||||
|
top_token_ids = top_token_ids.to('cpu')
|
||||||
|
|
||||||
# Find prompt/sample logprobs.
|
# Find prompt/sample logprobs.
|
||||||
prompt_logprobs_per_seq_group: List[Optional[PromptLogprobs]] = []
|
prompt_logprobs_per_seq_group: List[Optional[PromptLogprobs]] = []
|
||||||
@ -828,37 +829,48 @@ def _get_prompt_logprob_if_needed(
|
|||||||
|
|
||||||
# Find prompt logprobs
|
# Find prompt logprobs
|
||||||
prompt_logprobs: Optional[PromptLogprobs] = None
|
prompt_logprobs: Optional[PromptLogprobs] = None
|
||||||
if (is_prompt and sampling_params.prompt_logprobs is not None):
|
if is_prompt and sampling_params.prompt_logprobs is not None:
|
||||||
prompt_logprobs = []
|
prompt_logprobs = []
|
||||||
num_logprobs = sampling_params.prompt_logprobs
|
num_logprobs = sampling_params.prompt_logprobs
|
||||||
next_prompt_tokens = _get_next_prompt_tokens(seq_group)
|
next_prompt_tokens = _get_next_prompt_tokens(seq_group)
|
||||||
for token_id in next_prompt_tokens:
|
# Pre-select indexes and create a list. It is faster than calling .item
|
||||||
|
# repeatedly.
|
||||||
|
selected_logprob_items = selected_logprobs[
|
||||||
|
selected_logprobs_idx:selected_logprobs_idx +
|
||||||
|
len(next_prompt_tokens)].tolist()
|
||||||
|
rank_items = ranks[selected_logprobs_idx:selected_logprobs_idx +
|
||||||
|
len(next_prompt_tokens)].tolist()
|
||||||
|
|
||||||
|
for idx, token_id in enumerate(next_prompt_tokens):
|
||||||
# Calculate the prompt logprob of the real prompt tokens.
|
# Calculate the prompt logprob of the real prompt tokens.
|
||||||
# Use tuple here for performance (to use tolist()).
|
|
||||||
# {token_id: (logprob, rank_from_vocab)}
|
# {token_id: (logprob, rank_from_vocab)}
|
||||||
prompt_logprobs_dict: Dict[int, Tuple[float, int]] = {
|
prompt_logprobs_dict: Dict[int, Tuple[float, int]] = {
|
||||||
token_id: (selected_logprobs[selected_logprobs_idx].item(),
|
token_id: (selected_logprob_items[idx], rank_items[idx])
|
||||||
ranks[selected_logprobs_idx].item())
|
|
||||||
}
|
}
|
||||||
|
|
||||||
# Add top K prompt logprobs along with its rank.
|
# Add top K prompt logprobs along with its rank.
|
||||||
if num_logprobs > 0:
|
if num_logprobs > 0:
|
||||||
prompt_logprobs_dict.update(
|
top_ids = top_token_ids[
|
||||||
zip(
|
top_logprob_idx, :num_logprobs].tolist()
|
||||||
top_token_ids[top_logprob_idx, :num_logprobs].tolist(),
|
top_probs = top_logprobs[
|
||||||
zip(
|
top_logprob_idx, :num_logprobs].tolist()
|
||||||
top_logprobs[
|
# Top K is already sorted by rank, so the ranks are simply
|
||||||
top_logprob_idx, :num_logprobs].tolist(),
|
# 1 through num_logprobs (inclusive).
|
||||||
# This is ranks. Since top_logprob is sorted,
|
top_ranks = range(1, num_logprobs + 1)
|
||||||
# we can just use a range here.
|
prompt_logprobs_dict.update({
|
||||||
range(1, num_logprobs + 1))))
|
top_id: (top_prob, rank)
|
||||||
|
for top_id, top_prob, rank in zip(top_ids, top_probs,
|
||||||
|
top_ranks)
|
||||||
|
})
|
||||||
prompt_logprobs.append({
|
prompt_logprobs.append({
|
||||||
token_id: Logprob(*logprob_and_rank)
|
token_id: Logprob(*logprob_and_rank)
|
||||||
for token_id, logprob_and_rank in prompt_logprobs_dict.items()
|
for token_id, logprob_and_rank in prompt_logprobs_dict.items()
|
||||||
})
|
})
|
||||||
# + 1 to go to the next prompt token.
|
# + 1 to go to the next prompt token.
|
||||||
top_logprob_idx += 1
|
top_logprob_idx += 1
|
||||||
selected_logprobs_idx += 1
|
|
||||||
|
# + len(next_prompt_tokens) to go to the next prompt.
|
||||||
|
selected_logprobs_idx += len(next_prompt_tokens)
|
||||||
return prompt_logprobs, top_logprob_idx, selected_logprobs_idx
|
return prompt_logprobs, top_logprob_idx, selected_logprobs_idx
|
||||||
|
|
||||||
|
|
||||||
@ -874,47 +886,54 @@ def _get_sampled_logprob_if_needed(
|
|||||||
):
|
):
|
||||||
"""Compute the sample logprob if needed."""
|
"""Compute the sample logprob if needed."""
|
||||||
seq_ids = seq_group.seq_ids
|
seq_ids = seq_group.seq_ids
|
||||||
num_logprobs = seq_group.sampling_params.logprobs
|
num_logprobs = seq_group.sampling_params.logprobs or 0
|
||||||
if num_logprobs is None:
|
|
||||||
num_logprobs = 0
|
|
||||||
sampled_logprobs: SampleLogprobs = []
|
sampled_logprobs: SampleLogprobs = []
|
||||||
next_token_ids, parent_seq_ids = sample_result
|
next_token_ids, parent_seq_ids = sample_result
|
||||||
|
|
||||||
if seq_group.do_sample:
|
if seq_group.do_sample:
|
||||||
assert len(next_token_ids) > 0
|
assert len(next_token_ids) > 0
|
||||||
for (next_token_id, parent_id) in zip(next_token_ids, parent_seq_ids):
|
# Pre-select items from tensor. tolist() is faster than repetitive
|
||||||
# Calculate the sample logprob of the real sampled tokens.
|
# `.item()` calls.
|
||||||
# Use tuple here for performance (to use tolist()).
|
selected_logprob_items = selected_logprobs[
|
||||||
# token_id: (logprob, rank_from_vocab)
|
selected_logprobs_idx:selected_logprobs_idx +
|
||||||
sampled_logprobs_dict: Dict[int, Tuple[float, int]] = {
|
len(next_token_ids)].tolist()
|
||||||
next_token_id:
|
rank_items = ranks[selected_logprobs_idx:selected_logprobs_idx +
|
||||||
(selected_logprobs[selected_logprobs_idx].item(),
|
len(next_token_ids)].tolist()
|
||||||
ranks[selected_logprobs_idx].item())
|
for idx, (next_token_id,
|
||||||
|
parent_id) in enumerate(zip(next_token_ids, parent_seq_ids)):
|
||||||
|
# Get the logprob of a sampled token.
|
||||||
|
sampled_logprobs_dict = {
|
||||||
|
next_token_id: (selected_logprob_items[idx], rank_items[idx])
|
||||||
}
|
}
|
||||||
# +1 to go to the next sampled token. Note that
|
# Get top K logprobs.
|
||||||
# selected_logprobs can contain duplicates unlike top_logprobs
|
if num_logprobs > 0:
|
||||||
# when beam search is enabled.
|
top_ids = top_token_ids[top_logprob_idx +
|
||||||
selected_logprobs_idx += 1
|
parent_id, :num_logprobs].tolist()
|
||||||
|
top_probs = top_logprobs[top_logprob_idx +
|
||||||
|
parent_id, :num_logprobs].tolist()
|
||||||
|
# Top K is already sorted by rank, so the ranks are simply
|
||||||
|
# 1 through num_logprobs (inclusive).
|
||||||
|
top_ranks = range(1, num_logprobs + 1)
|
||||||
|
sampled_logprobs_dict.update({
|
||||||
|
top_id: (top_prob, rank)
|
||||||
|
for top_id, top_prob, rank in zip(top_ids, top_probs,
|
||||||
|
top_ranks)
|
||||||
|
})
|
||||||
|
|
||||||
# Second, add top K logprobs along with its rank.
|
|
||||||
if num_logprobs >= 0:
|
|
||||||
sampled_logprobs_dict.update(
|
|
||||||
zip(
|
|
||||||
top_token_ids[top_logprob_idx +
|
|
||||||
parent_id, :num_logprobs].tolist(),
|
|
||||||
zip(
|
|
||||||
top_logprobs[top_logprob_idx +
|
|
||||||
parent_id, :num_logprobs].tolist(),
|
|
||||||
# This is rank. Since top_logprob is sorted, we
|
|
||||||
# can just use a range here.
|
|
||||||
range(1, num_logprobs + 1))))
|
|
||||||
sampled_logprobs.append({
|
sampled_logprobs.append({
|
||||||
token_id: Logprob(*logprob_and_rank)
|
token_id: Logprob(*logprob_and_rank)
|
||||||
for token_id, logprob_and_rank in
|
for token_id, logprob_and_rank in
|
||||||
sampled_logprobs_dict.items()
|
sampled_logprobs_dict.items()
|
||||||
})
|
})
|
||||||
# There are len(seq_ids) number of sampled tokens for the current
|
|
||||||
# sequence group in top_logprobs. Jump to the next seq_group.
|
# NOTE: This part of code is not intuitive. `selected_logprobs` includes
|
||||||
|
# logprobs for the current step, which has len(next_token_ids) tokens
|
||||||
|
# per sequence group. `logprobs` includes logprobs from the previous
|
||||||
|
# steps, which has len(seq_ids) tokens per sequence group.
|
||||||
|
|
||||||
|
# Iterate to the next sequence group in a batch.
|
||||||
|
selected_logprobs_idx += len(next_token_ids)
|
||||||
|
# Iterate to the next sequence group in a batch.
|
||||||
top_logprob_idx += len(seq_ids)
|
top_logprob_idx += len(seq_ids)
|
||||||
return sampled_logprobs, top_logprob_idx, selected_logprobs_idx
|
return sampled_logprobs, top_logprob_idx, selected_logprobs_idx
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user