[Core] Optimize sampler get_logprobs (#4594)

This commit is contained in:
SangBin Cho 2024-05-09 00:42:28 +09:00 committed by GitHub
parent cc466a3290
commit d7740ea4dc
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -782,13 +782,14 @@ def _get_logprobs(
top_logprobs, top_token_ids = torch.topk(logprobs,
largest_num_logprobs,
dim=-1)
top_logprobs = top_logprobs.cpu()
top_token_ids = top_token_ids.cpu()
else:
top_logprobs, top_token_ids = None, None
selected_logprobs = selected_logprobs.cpu()
ranks = ranks.cpu()
selected_logprobs = selected_logprobs.to('cpu')
ranks = ranks.to('cpu')
if top_logprobs is not None and top_token_ids is not None:
top_logprobs = top_logprobs.to('cpu')
top_token_ids = top_token_ids.to('cpu')
# Find prompt/sample logprobs.
prompt_logprobs_per_seq_group: List[Optional[PromptLogprobs]] = []
@ -828,37 +829,48 @@ def _get_prompt_logprob_if_needed(
# Find prompt logprobs
prompt_logprobs: Optional[PromptLogprobs] = None
if (is_prompt and sampling_params.prompt_logprobs is not None):
if is_prompt and sampling_params.prompt_logprobs is not None:
prompt_logprobs = []
num_logprobs = sampling_params.prompt_logprobs
next_prompt_tokens = _get_next_prompt_tokens(seq_group)
for token_id in next_prompt_tokens:
# Pre-select indexes and create a list. It is faster than calling .item
# repetitively.
selected_logprob_items = selected_logprobs[
selected_logprobs_idx:selected_logprobs_idx +
len(next_prompt_tokens)].tolist()
rank_items = ranks[selected_logprobs_idx:selected_logprobs_idx +
len(next_prompt_tokens)].tolist()
for idx, token_id in enumerate(next_prompt_tokens):
# Calculate the prompt logprob of the real prompt tokens.
# Use tuple here for performance (to use to_list()).
# {token_id: (logprob, rank_from_vocab)}
prompt_logprobs_dict: Dict[int, Tuple[float, int]] = {
token_id: (selected_logprobs[selected_logprobs_idx].item(),
ranks[selected_logprobs_idx].item())
token_id: (selected_logprob_items[idx], rank_items[idx])
}
# Add top K prompt logprobs along with its rank.
if num_logprobs > 0:
prompt_logprobs_dict.update(
zip(
top_token_ids[top_logprob_idx, :num_logprobs].tolist(),
zip(
top_logprobs[
top_logprob_idx, :num_logprobs].tolist(),
# This is ranks. Since top_logprob is sorted,
# we can just use a range here.
range(1, num_logprobs + 1))))
top_ids = top_token_ids[
top_logprob_idx, :num_logprobs].tolist()
top_probs = top_logprobs[
top_logprob_idx, :num_logprobs].tolist()
# Top K is already sorted by rank, so we can use 1 ~
# num_logprobs + 1 for rank.
top_ranks = range(1, num_logprobs + 1)
prompt_logprobs_dict.update({
top_id: (top_prob, rank)
for top_id, top_prob, rank in zip(top_ids, top_probs,
top_ranks)
})
prompt_logprobs.append({
token_id: Logprob(*logprob_and_rank)
for token_id, logprob_and_rank in prompt_logprobs_dict.items()
})
# + 1 to go to the next prompt token.
top_logprob_idx += 1
selected_logprobs_idx += 1
# + len(next_prompt_tokens) to go to the next prompt.
selected_logprobs_idx += len(next_prompt_tokens)
return prompt_logprobs, top_logprob_idx, selected_logprobs_idx
@ -874,47 +886,54 @@ def _get_sampled_logprob_if_needed(
):
"""Compute the sample logprob if needed."""
seq_ids = seq_group.seq_ids
num_logprobs = seq_group.sampling_params.logprobs
if num_logprobs is None:
num_logprobs = 0
num_logprobs = seq_group.sampling_params.logprobs or 0
sampled_logprobs: SampleLogprobs = []
next_token_ids, parent_seq_ids = sample_result
if seq_group.do_sample:
assert len(next_token_ids) > 0
for (next_token_id, parent_id) in zip(next_token_ids, parent_seq_ids):
# Calculate the sample logprob of the real sampled tokens.
# Use tuple here for performance (to use to_list()).
# token_id: (logprob, rank_from_vocab)
sampled_logprobs_dict: Dict[int, Tuple[float, int]] = {
next_token_id:
(selected_logprobs[selected_logprobs_idx].item(),
ranks[selected_logprobs_idx].item())
# Pre-select items from tensor. tolist() is faster than repetitive
# `.item()` calls.
selected_logprob_items = selected_logprobs[
selected_logprobs_idx:selected_logprobs_idx +
len(next_token_ids)].tolist()
rank_items = ranks[selected_logprobs_idx:selected_logprobs_idx +
len(next_token_ids)].tolist()
for idx, (next_token_id,
parent_id) in enumerate(zip(next_token_ids, parent_seq_ids)):
# Get the logprob of a sampled token.
sampled_logprobs_dict = {
next_token_id: (selected_logprob_items[idx], rank_items[idx])
}
# +1 to go to the next sampled token. Note that
# selected_logprobs can contain duplicates unlike top_logprobs
# when beam search is enabled.
selected_logprobs_idx += 1
# Get top K logprobs.
if num_logprobs > 0:
top_ids = top_token_ids[top_logprob_idx +
parent_id, :num_logprobs].tolist()
top_probs = top_logprobs[top_logprob_idx +
parent_id, :num_logprobs].tolist()
# Top K is already sorted by rank, so we can use 1 ~
# num_logprobs + 1 for rank.
top_ranks = range(1, num_logprobs + 1)
sampled_logprobs_dict.update({
top_id: (top_prob, rank)
for top_id, top_prob, rank in zip(top_ids, top_probs,
top_ranks)
})
# Second, add top K logprobs along with its rank.
if num_logprobs >= 0:
sampled_logprobs_dict.update(
zip(
top_token_ids[top_logprob_idx +
parent_id, :num_logprobs].tolist(),
zip(
top_logprobs[top_logprob_idx +
parent_id, :num_logprobs].tolist(),
# This is rank. Since top_logprob is sorted, we
# can just use a range here.
range(1, num_logprobs + 1))))
sampled_logprobs.append({
token_id: Logprob(*logprob_and_rank)
for token_id, logprob_and_rank in
sampled_logprobs_dict.items()
})
# There are len(seq_ids) number of sampled tokens for the current
# sequence group in top_logprobs. Jump to the next seq_group.
# NOTE: This part of code is not intuitive. `selected_logprobs` include
# logprobs for the current step, which has len(next_token_ids) tokens
# per sequence group. `logprobs` includes logprobs from the previous
# steps, which has len(seq_ids) tokens per sequence group.
# Iterate to the next sequence group in a batch.
selected_logprobs_idx += len(next_token_ids)
# Iterate to the next sequence group in a batch.
top_logprob_idx += len(seq_ids)
return sampled_logprobs, top_logprob_idx, selected_logprobs_idx