diff --git a/vllm/model_executor/layers/sampler.py b/vllm/model_executor/layers/sampler.py
index 1f19d205..e52e350d 100644
--- a/vllm/model_executor/layers/sampler.py
+++ b/vllm/model_executor/layers/sampler.py
@@ -782,13 +782,14 @@ def _get_logprobs(
         top_logprobs, top_token_ids = torch.topk(logprobs,
                                                  largest_num_logprobs,
                                                  dim=-1)
-        top_logprobs = top_logprobs.cpu()
-        top_token_ids = top_token_ids.cpu()
     else:
         top_logprobs, top_token_ids = None, None
 
-    selected_logprobs = selected_logprobs.cpu()
-    ranks = ranks.cpu()
+    selected_logprobs = selected_logprobs.to('cpu')
+    ranks = ranks.to('cpu')
+    if top_logprobs is not None and top_token_ids is not None:
+        top_logprobs = top_logprobs.to('cpu')
+        top_token_ids = top_token_ids.to('cpu')
 
     # Find prompt/sample logprobs.
     prompt_logprobs_per_seq_group: List[Optional[PromptLogprobs]] = []
@@ -828,37 +829,48 @@ def _get_prompt_logprob_if_needed(
 
     # Find prompt logprobs
     prompt_logprobs: Optional[PromptLogprobs] = None
-    if (is_prompt and sampling_params.prompt_logprobs is not None):
+    if is_prompt and sampling_params.prompt_logprobs is not None:
        prompt_logprobs = []
        num_logprobs = sampling_params.prompt_logprobs
        next_prompt_tokens = _get_next_prompt_tokens(seq_group)
-        for token_id in next_prompt_tokens:
+        # Pre-select the needed entries and convert them to a Python list in
+        # one call; this is faster than calling .item() repeatedly.
+        selected_logprob_items = selected_logprobs[
+            selected_logprobs_idx:selected_logprobs_idx +
+            len(next_prompt_tokens)].tolist()
+        rank_items = ranks[selected_logprobs_idx:selected_logprobs_idx +
+                           len(next_prompt_tokens)].tolist()
+
+        for idx, token_id in enumerate(next_prompt_tokens):
             # Calculate the prompt logprob of the real prompt tokens.
-            # Use tuple here for performance (to use to_list()).
             # {token_id: (logprob, rank_from_vocab)}
             prompt_logprobs_dict: Dict[int, Tuple[float, int]] = {
-                token_id: (selected_logprobs[selected_logprobs_idx].item(),
-                           ranks[selected_logprobs_idx].item())
+                token_id: (selected_logprob_items[idx], rank_items[idx])
             }
 
             # Add top K prompt logprobs along with its rank.
             if num_logprobs > 0:
-                prompt_logprobs_dict.update(
-                    zip(
-                        top_token_ids[top_logprob_idx, :num_logprobs].tolist(),
-                        zip(
-                            top_logprobs[
-                                top_logprob_idx, :num_logprobs].tolist(),
-                            # This is ranks. Since top_logprob is sorted,
-                            # we can just use a range here.
-                            range(1, num_logprobs + 1))))
+                top_ids = top_token_ids[
+                    top_logprob_idx, :num_logprobs].tolist()
+                top_probs = top_logprobs[
+                    top_logprob_idx, :num_logprobs].tolist()
+                # The top-K entries are already sorted by logprob, so their
+                # ranks are simply 1..num_logprobs.
+                top_ranks = range(1, num_logprobs + 1)
+                prompt_logprobs_dict.update({
+                    top_id: (top_prob, rank)
+                    for top_id, top_prob, rank in zip(top_ids, top_probs,
+                                                      top_ranks)
+                })
             prompt_logprobs.append({
                 token_id: Logprob(*logprob_and_rank)
                 for token_id, logprob_and_rank in
                 prompt_logprobs_dict.items()
             })
             # + 1 to go to the next prompt token.
             top_logprob_idx += 1
-            selected_logprobs_idx += 1
+
+        # + len(next_prompt_tokens) to go to the next prompt.
+        selected_logprobs_idx += len(next_prompt_tokens)
     return prompt_logprobs, top_logprob_idx, selected_logprobs_idx
 
@@ -874,47 +886,54 @@ def _get_sampled_logprob_if_needed(
 ):
     """Compute the sample logprob if needed."""
     seq_ids = seq_group.seq_ids
-    num_logprobs = seq_group.sampling_params.logprobs
-    if num_logprobs is None:
-        num_logprobs = 0
+    num_logprobs = seq_group.sampling_params.logprobs or 0
     sampled_logprobs: SampleLogprobs = []
     next_token_ids, parent_seq_ids = sample_result
 
     if seq_group.do_sample:
         assert len(next_token_ids) > 0
-        for (next_token_id, parent_id) in zip(next_token_ids, parent_seq_ids):
-            # Calculate the sample logprob of the real sampled tokens.
-            # Use tuple here for performance (to use to_list()).
-            # token_id: (logprob, rank_from_vocab)
-            sampled_logprobs_dict: Dict[int, Tuple[float, int]] = {
-                next_token_id:
-                (selected_logprobs[selected_logprobs_idx].item(),
-                 ranks[selected_logprobs_idx].item())
+        # Pre-select items from the tensor. A single tolist() call is faster
+        # than repeated `.item()` calls.
+        selected_logprob_items = selected_logprobs[
+            selected_logprobs_idx:selected_logprobs_idx +
+            len(next_token_ids)].tolist()
+        rank_items = ranks[selected_logprobs_idx:selected_logprobs_idx +
+                           len(next_token_ids)].tolist()
+        for idx, (next_token_id,
+                  parent_id) in enumerate(zip(next_token_ids, parent_seq_ids)):
+            # Get the logprob of the sampled token.
+            sampled_logprobs_dict = {
+                next_token_id: (selected_logprob_items[idx], rank_items[idx])
             }
-            # +1 to go to the next sampled token. Note that
-            # selected_logprobs can contain duplicates unlike top_logprobs
-            # when beam search is enabled.
-            selected_logprobs_idx += 1
+            # Get top K logprobs.
+            if num_logprobs > 0:
+                top_ids = top_token_ids[top_logprob_idx +
+                                        parent_id, :num_logprobs].tolist()
+                top_probs = top_logprobs[top_logprob_idx +
+                                         parent_id, :num_logprobs].tolist()
+                # The top-K entries are already sorted by logprob, so their
+                # ranks are simply 1..num_logprobs.
+                top_ranks = range(1, num_logprobs + 1)
+                sampled_logprobs_dict.update({
+                    top_id: (top_prob, rank)
+                    for top_id, top_prob, rank in zip(top_ids, top_probs,
+                                                      top_ranks)
+                })
 
-            # Second, add top K logprobs along with its rank.
-            if num_logprobs >= 0:
-                sampled_logprobs_dict.update(
-                    zip(
-                        top_token_ids[top_logprob_idx +
-                                      parent_id, :num_logprobs].tolist(),
-                        zip(
-                            top_logprobs[top_logprob_idx +
-                                         parent_id, :num_logprobs].tolist(),
-                            # This is rank. Since top_logprob is sorted, we
-                            # can just use a range here.
-                            range(1, num_logprobs + 1))))
             sampled_logprobs.append({
                 token_id: Logprob(*logprob_and_rank)
                 for token_id, logprob_and_rank in
                 sampled_logprobs_dict.items()
             })
 
-        # There are len(seq_ids) number of sampled tokens for the current
-        # sequence group in top_logprobs. Jump to the next seq_group.
+
+        # NOTE: This part of the code is not intuitive. `selected_logprobs`
+        # includes logprobs for the current step only (len(next_token_ids)
+        # tokens per sequence group), while `logprobs` also includes previous
+        # steps (len(seq_ids) tokens per sequence group).
+
+        # Advance to the next sequence group in the batch.
+        selected_logprobs_idx += len(next_token_ids)
+        # Likewise advance the top-logprob index to the next sequence group.
         top_logprob_idx += len(seq_ids)
     return sampled_logprobs, top_logprob_idx, selected_logprobs_idx
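
The central optimization in this patch is replacing per-token `tensor[i].item()` calls with one slice-plus-`tolist()` conversion per sequence group. A minimal standalone sketch of why this matters (the tensor size and the timing harness are illustrative, not part of the patch):

    import time

    import torch

    values = torch.randn(8192)

    # Old pattern: one tensor indexing + .item() call per element, each of
    # which builds a 0-dim tensor and converts it to a Python float.
    start = time.perf_counter()
    via_item = [values[i].item() for i in range(len(values))]
    print(f"per-element .item(): {time.perf_counter() - start:.4f}s")

    # New pattern: a single .tolist() converts the whole tensor at once.
    start = time.perf_counter()
    via_tolist = values.tolist()
    print(f"single .tolist():    {time.perf_counter() - start:.4f}s")

    assert via_item == via_tolist

On a CUDA tensor the gap is larger still, since each `.item()` forces its own device-to-host transfer; the patch sidesteps that by moving the tensors to CPU once (the `.to('cpu')` block in the first hunk) before any Python-side iteration.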
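The rewritten dict construction also leans on `torch.topk` returning its values sorted in descending order, so the ranks are just `range(1, num_logprobs + 1)`. A self-contained sketch of the `{token_id: (logprob, rank)}` pattern; the vocabulary size, `k`, and the sampled token id are made-up inputs:

    import torch

    torch.manual_seed(0)
    k = 5
    logprobs_row = torch.log_softmax(torch.randn(32000), dim=-1)
    sampled_token_id = 123

    # Rank of the sampled token within the vocab: 1 + the number of tokens
    # with a strictly higher logprob.
    sampled_logprob = logprobs_row[sampled_token_id].item()
    rank = int((logprobs_row > logprobs_row[sampled_token_id]).sum()) + 1

    # {token_id: (logprob, rank_from_vocab)}, seeded with the sampled token.
    logprobs_dict = {sampled_token_id: (sampled_logprob, rank)}

    # topk values are sorted in descending order, so the i-th entry has
    # rank i + 1 and a range() can supply the ranks with no extra work.
    top_probs, top_ids = torch.topk(logprobs_row, k, dim=-1)
    logprobs_dict.update({
        top_id: (top_prob, r)
        for top_id, top_prob, r in zip(top_ids.tolist(), top_probs.tolist(),
                                       range(1, k + 1))
    })
    print(logprobs_dict)

Compared with the old nested `zip(...)` expression, the explicit comprehension produces the same mapping while keeping the three parallel streams (ids, logprobs, ranks) visible at a glance.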
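The index bookkeeping at the end of each hunk is easy to misread: `selected_logprobs_idx` now advances by a whole sequence group at a time instead of by one token per loop iteration. A toy illustration with plain Python lists; the buffer contents and group layout are hypothetical:

    # Hypothetical flattened layout: the selected logprobs of every sequence
    # group in the batch are concatenated into one buffer, so each group owns
    # a contiguous run of len(next_token_ids) entries.
    selected_logprobs = [-0.1, -2.3, -0.7, -1.5, -0.4]
    batch = [["a", "b"], ["c"], ["d", "e"]]  # next_token_ids per group

    selected_logprobs_idx = 0
    for next_token_ids in batch:
        # Slice out this group's entries in one step, mirroring the patch's
        # tensor slice + tolist().
        items = selected_logprobs[
            selected_logprobs_idx:selected_logprobs_idx + len(next_token_ids)]
        print(dict(zip(next_token_ids, items)))
        # Jump past the whole group instead of incrementing per token.
        selected_logprobs_idx += len(next_token_ids)

`top_logprob_idx`, by contrast, still advances by `len(seq_ids)`, because the top-K tensors carry one row per sequence rather than one per sampled token, as the NOTE added in the last hunk explains.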