[misc] hide best_of from engine (#9261)

Co-authored-by: Brendan Wong <bjwpokemon@gmail.com>
youkaichao 2024-10-10 21:30:44 -07:00 committed by GitHub
parent 94bf9ae4e9
commit cbc2ef5529
14 changed files with 46 additions and 73 deletions

View File

@@ -70,7 +70,6 @@ EXPECTED_VALUES = {
     [("_sum", _NUM_REQUESTS * _NUM_GENERATION_TOKENS_PER_REQUEST),
      ("_count", _NUM_REQUESTS)],
     "vllm:request_params_n": [("_count", _NUM_REQUESTS)],
-    "vllm:request_params_best_of": [("_count", _NUM_REQUESTS)],
     "vllm:prompt_tokens": [("_total",
                             _NUM_REQUESTS * _NUM_PROMPT_TOKENS_PER_REQUEST)],
     "vllm:generation_tokens":
@@ -151,9 +150,6 @@ EXPECTED_METRICS = [
     "vllm:request_params_n_sum",
     "vllm:request_params_n_bucket",
     "vllm:request_params_n_count",
-    "vllm:request_params_best_of_sum",
-    "vllm:request_params_best_of_bucket",
-    "vllm:request_params_best_of_count",
     "vllm:num_preemptions_total",
     "vllm:prompt_tokens_total",
     "vllm:generation_tokens_total",

View File

@@ -326,7 +326,6 @@ def assert_metrics(engine: LLMEngine, disable_log_stats: bool,
         "vllm:e2e_request_latency_seconds",
         "vllm:request_prompt_tokens",
         "vllm:request_generation_tokens",
-        "vllm:request_params_best_of",
         "vllm:request_params_n",
     ]
     for metric_name in request_histogram_metrics:

View File

@@ -98,8 +98,6 @@ def test_traces(trace_service):
         SpanAttributes.LLM_REQUEST_TOP_P) == sampling_params.top_p
     assert attributes.get(
         SpanAttributes.LLM_REQUEST_MAX_TOKENS) == sampling_params.max_tokens
-    assert attributes.get(
-        SpanAttributes.LLM_REQUEST_BEST_OF) == sampling_params.best_of
     assert attributes.get(SpanAttributes.LLM_REQUEST_N) == sampling_params.n
     assert attributes.get(SpanAttributes.LLM_USAGE_PROMPT_TOKENS) == len(
         outputs[0].prompt_token_ids)
@@ -155,8 +153,6 @@ def test_traces_with_detailed_steps(trace_service):
         SpanAttributes.LLM_REQUEST_TOP_P) == sampling_params.top_p
     assert attributes.get(
         SpanAttributes.LLM_REQUEST_MAX_TOKENS) == sampling_params.max_tokens
-    assert attributes.get(
-        SpanAttributes.LLM_REQUEST_BEST_OF) == sampling_params.best_of
     assert attributes.get(SpanAttributes.LLM_REQUEST_N) == sampling_params.n
     assert attributes.get(SpanAttributes.LLM_USAGE_PROMPT_TOKENS) == len(
         outputs[0].prompt_token_ids)

View File

@@ -1205,7 +1205,7 @@ class Scheduler:
         # async_output_proc is allowed only when we have a single sequence
         # in the sequence group
         no_single_seq = seq_group.sampling_params is None or (
-            seq_group.sampling_params.best_of == 1)
+            seq_group.sampling_params.n == 1)
         return no_single_seq

     def schedule(

View File

@@ -767,7 +767,7 @@ class LLMEngine:
         Details:
             - Set arrival_time to the current time if it is None.
             - Set prompt_token_ids to the encoded prompt if it is None.
-            - Create `best_of` number of :class:`~vllm.Sequence` objects.
+            - Create `n` number of :class:`~vllm.Sequence` objects.
             - Create a :class:`~vllm.SequenceGroup` object
               from the list of :class:`~vllm.Sequence`.
             - Add the :class:`~vllm.SequenceGroup` object to the scheduler.
@@ -1242,8 +1242,7 @@ class LLMEngine:
             if seq_group_metadata.do_sample:
                 assert len(sequence_group_outputs.samples) == 1, (
                     "Async output processor expects a single sample"
-                    " (i.e sampling_params.n == 1 and no "
-                    "sampling_params.best_of > 1)")
+                    " (i.e sampling_params.n == 1)")
                 sample = sequence_group_outputs.samples[0]

             assert len(seq_group.seqs) == 1
@@ -1612,7 +1611,6 @@ class LLMEngine:
         # Metadata
         num_prompt_tokens_requests: List[int] = []
         num_generation_tokens_requests: List[int] = []
-        best_of_requests: List[int] = []
         n_requests: List[int] = []
         finished_reason_requests: List[str] = []
@@ -1683,8 +1681,6 @@ class LLMEngine:
                         for seq in seq_group.get_finished_seqs()
                     ])
                     if seq_group.sampling_params is not None:
-                        best_of_requests.append(
-                            seq_group.sampling_params.best_of)
                         n_requests.append(seq_group.sampling_params.n)
                     finished_reason_requests.extend([
                         SequenceStatus.get_finished_reason(seq.status)
@@ -1737,7 +1733,6 @@ class LLMEngine:
             # Metadata
             num_prompt_tokens_requests=num_prompt_tokens_requests,
             num_generation_tokens_requests=num_generation_tokens_requests,
-            best_of_requests=best_of_requests,
             n_requests=n_requests,
             finished_reason_requests=finished_reason_requests,
         )
@@ -1824,8 +1819,6 @@ class LLMEngine:
                                        seq_group.sampling_params.top_p)
                 seq_span.set_attribute(SpanAttributes.LLM_REQUEST_MAX_TOKENS,
                                        seq_group.sampling_params.max_tokens)
-                seq_span.set_attribute(SpanAttributes.LLM_REQUEST_BEST_OF,
-                                       seq_group.sampling_params.best_of)
                 seq_span.set_attribute(SpanAttributes.LLM_REQUEST_N,
                                        seq_group.sampling_params.n)
                 seq_span.set_attribute(SpanAttributes.LLM_USAGE_NUM_SEQUENCES,

View File

@@ -134,12 +134,6 @@ class Metrics:
             labelnames=labelnames,
             buckets=build_1_2_5_buckets(max_model_len),
         )
-        self.histogram_best_of_request = self._histogram_cls(
-            name="vllm:request_params_best_of",
-            documentation="Histogram of the best_of request parameter.",
-            labelnames=labelnames,
-            buckets=[1, 2, 5, 10, 20],
-        )
         self.histogram_n_request = self._histogram_cls(
             name="vllm:request_params_n",
             documentation="Histogram of the n request parameter.",
@@ -473,8 +467,6 @@ class PrometheusStatLogger(StatLoggerBase):
             self.metrics.histogram_num_generation_tokens_request,
             stats.num_generation_tokens_requests)
         self._log_histogram(self.metrics.histogram_n_request, stats.n_requests)
-        self._log_histogram(self.metrics.histogram_best_of_request,
-                            stats.best_of_requests)

     def _log_prometheus_interval(self, prompt_throughput: float,
                                  generation_throughput: float) -> None:

View File

@@ -49,7 +49,6 @@ class Stats:
     # Metadata
     num_prompt_tokens_requests: List[int]
     num_generation_tokens_requests: List[int]
-    best_of_requests: List[int]
     n_requests: List[int]
     finished_reason_requests: List[str]

View File

@@ -112,7 +112,7 @@ class SingleStepOutputProcessor(SequenceGroupOutputProcessor):
                               outputs: SequenceGroupOutput,
                               is_async: bool) -> None:
         sampling_params = seq_group.sampling_params
-        if sampling_params.best_of == 1:
+        if sampling_params.n == 1:
             # only have one output sample
             sample = outputs.samples[0]
             # only have one sequence

View File

@@ -508,7 +508,7 @@ def _random_sample(
         same as the length of selected_seq_groups. If the corresponding
         seq_group has do_sample=False, tuple contains ([], [])
     """
-    # Find the maximum best_of value of the prompt phase requests.
+    # Find the maximum n value of the prompt phase requests.
    random_samples = random_samples.cpu()
    sample_idx = 0
    results: SampleResultType = []
@@ -523,9 +523,9 @@ def _random_sample(
         num_parent_seqs = len(seq_ids)
         if is_prompt:
             # Prompt phase.
-            parent_ids = [0] * sampling_params.best_of
+            parent_ids = [0] * sampling_params.n
             next_token_ids = random_samples[
-                sample_idx, :sampling_params.best_of].tolist()
+                sample_idx, :sampling_params.n].tolist()
         else:
             # Generation phase.
             parent_ids = list(range(num_parent_seqs))
@@ -570,7 +570,7 @@ def _beam_search_sample(
         is_prompt = seq_group.is_prompt
         seq_ids, sampling_params = seq_group.seq_ids, seq_group.sampling_params
         num_parent_seqs = len(seq_ids)
-        beam_width = sampling_params.best_of
+        beam_width = sampling_params.n
         seq_group_logprobs = logprobs[sample_idx:sample_idx + num_parent_seqs]
         if is_prompt:
             # Prompt phase.
@@ -797,12 +797,11 @@ def _sample_with_torch(
                                                   greedy_samples)

         elif sampling_type in (SamplingType.RANDOM, SamplingType.RANDOM_SEED):
-            max_best_of_in_batch = 1
+            max_n_in_batch = 1
             for seq_group in seq_groups:
                 if seq_group.is_prompt:
                     sampling_params = seq_group.sampling_params
-                    max_best_of_in_batch = max(max_best_of_in_batch,
-                                               sampling_params.best_of)
+                    max_n_in_batch = max(max_n_in_batch, sampling_params.n)
             seq_groups_arg = (None if sampling_type == SamplingType.RANDOM else
                               seq_groups)

@@ -812,13 +811,13 @@ def _sample_with_torch(
                 probs[long_sample_indices],
                 sampling_tensors.top_ks[long_sample_indices],
                 sampling_tensors.top_ps[long_sample_indices],
-                max_best_of_in_batch,
+                max_n_in_batch,
                 seq_groups_arg,
             )
         else:
             multinomial_samples[sampling_type] = _multinomial(
                 probs[long_sample_indices],
-                max_best_of_in_batch,
+                max_n_in_batch,
                 seq_groups=seq_groups_arg)

         if sampled_token_ids_tensor is not None:

View File

@@ -141,7 +141,7 @@ class RequestOutput:
             top_n_seqs = seqs
         else:
             # Get the top-n sequences.
-            n = sampling_params.n
+            n = sampling_params._real_n or sampling_params.n
             sorting_key = lambda seq: seq.get_cumulative_logprob()
             sorted_seqs = sorted(seqs, key=sorting_key, reverse=True)
             top_n_seqs = sorted_seqs[:n]
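Since `n` now carries the old `best_of` value, the engine may generate more sequences than the user asked for; the selection above trims the output back to `_real_n` when it is set. A minimal sketch of that top-n pruning, with a hypothetical stand-in for vLLM's sequence objects:

from dataclasses import dataclass
from typing import List, Optional

@dataclass
class FakeSeq:                      # hypothetical stand-in for vllm.Sequence
    text: str
    cumulative_logprob: float

    def get_cumulative_logprob(self) -> float:
        return self.cumulative_logprob

def top_n(seqs: List[FakeSeq], n: int, real_n: Optional[int]) -> List[FakeSeq]:
    # Mirrors the selection above: prefer the user's original n (_real_n)
    # when best_of inflated n, otherwise fall back to n.
    keep = real_n or n
    return sorted(seqs, key=lambda s: s.get_cumulative_logprob(),
                  reverse=True)[:keep]

# best_of=3 generated three candidates, but the user asked for n=1.
seqs = [FakeSeq("a", -1.2), FakeSeq("b", -0.4), FakeSeq("c", -2.0)]
assert [s.text for s in top_n(seqs, n=3, real_n=1)] == ["b"]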

View File

@@ -106,9 +106,8 @@ class SamplingParams(
         n: Number of output sequences to return for the given prompt.
         best_of: Number of output sequences that are generated from the prompt.
             From these `best_of` sequences, the top `n` sequences are returned.
-            `best_of` must be greater than or equal to `n`. This is treated as
-            the beam width when `use_beam_search` is True. By default, `best_of`
-            is set to `n`.
+            `best_of` must be greater than or equal to `n`. By default,
+            `best_of` is set to `n`.
         presence_penalty: Float that penalizes new tokens based on whether they
             appear in the generated text so far. Values > 0 encourage the model
             to use new tokens, while values < 0 encourage the model to repeat
@@ -173,6 +172,7 @@ class SamplingParams(
     n: int = 1
     best_of: Optional[int] = None
+    _real_n: Optional[int] = None
     presence_penalty: float = 0.0
     frequency_penalty: float = 0.0
     repetition_penalty: float = 1.0
@@ -282,7 +282,19 @@ class SamplingParams(
         )

     def __post_init__(self) -> None:
-        self.best_of = self.best_of or self.n
+        # how we deal with `best_of``:
+        # if `best_of`` is not set, we default to `n`;
+        # if `best_of`` is set, we set `n`` to `best_of`,
+        # and set `_real_n`` to the original `n`.
+        # when we return the result, we will check
+        # if we need to return `n` or `_real_n` results
+        if self.best_of:
+            if self.best_of < self.n:
+                raise ValueError(
+                    f"best_of must be greater than or equal to n, "
+                    f"got n={self.n} and best_of={self.best_of}.")
+            self._real_n = self.n
+            self.n = self.best_of
         if 0 < self.temperature < _MAX_TEMP:
             logger.warning(
                 "temperature %s is less than %s, which may cause numerical "
@@ -329,12 +341,6 @@ class SamplingParams(
                              f"type {type(self.n)}")
         if self.n < 1:
             raise ValueError(f"n must be at least 1, got {self.n}.")
-        if not isinstance(self.best_of, int):
-            raise ValueError(f"best_of must be an int, but is of "
-                             f"type {type(self.best_of)}")
-        if self.best_of < self.n:
-            raise ValueError(f"best_of must be greater than or equal to n, "
-                             f"got n={self.n} and best_of={self.best_of}.")
         if not -2.0 <= self.presence_penalty <= 2.0:
             raise ValueError("presence_penalty must be in [-2, 2], got "
                              f"{self.presence_penalty}.")
@@ -385,7 +391,7 @@ class SamplingParams(
             raise ValueError(
                 "stop strings are only supported when detokenize is True. "
                 "Set detokenize=True to use stop.")
-        if self.best_of != self.n and self.output_kind == (
+        if self.best_of != self._real_n and self.output_kind == (
                 RequestOutputKind.DELTA):
             raise ValueError("best_of must equal n to use output_kind=DELTA")
@@ -393,10 +399,6 @@ class SamplingParams(
         if self.n > 1:
             raise ValueError("n must be 1 when using greedy sampling, "
                              f"got {self.n}.")
-        assert isinstance(self.best_of, int)
-        if self.best_of > 1:
-            raise ValueError("best_of must be 1 when using greedy sampling, "
-                             f"got {self.best_of}.")

     def update_from_generation_config(
         self,
@@ -453,7 +455,6 @@ class SamplingParams(
     def __repr__(self) -> str:
         return (
             f"SamplingParams(n={self.n}, "
-            f"best_of={self.best_of}, "
             f"presence_penalty={self.presence_penalty}, "
             f"frequency_penalty={self.frequency_penalty}, "
             f"repetition_penalty={self.repetition_penalty}, "
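Based on the `__post_init__` logic in this diff, `best_of` is absorbed into `n` at construction time and `_real_n` remembers the caller's original request. A small sketch of the resulting attribute values, assuming a vLLM build that includes this change:

from vllm import SamplingParams

# best_of is folded into n; _real_n keeps the user's original n.
sp = SamplingParams(n=2, best_of=4)
assert sp.n == 4 and sp._real_n == 2

# Without best_of, nothing changes and _real_n stays unset.
sp = SamplingParams(n=2)
assert sp.n == 2 and sp._real_n is None

# best_of < n is rejected up front.
try:
    SamplingParams(n=3, best_of=2)
except ValueError as e:
    print(e)  # best_of must be greater than or equal to n, got n=3 and best_of=2.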

View File

@@ -803,14 +803,14 @@ class SequenceGroup:
         """The maximum number of sequences running in parallel in the remaining
         lifetime of the request."""
         if self.sampling_params:
-            best_of = self.sampling_params.best_of
-            assert isinstance(best_of, int)
-            if best_of > self.num_seqs():
+            n = self.sampling_params.n
+            assert isinstance(n, int)
+            if n > self.num_seqs():
                 # At prompt stage, the sequence group is not yet filled up
                 # and only have one sequence running. However, in the
-                # generation stage, we will have `best_of` sequences
+                # generation stage, we will have `n` sequences
                 # running.
-                return best_of
+                return n
         # At sampling stages, return the number of actual sequences
         # that are not finished yet.
         return self.num_unfinished_seqs()
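The renamed field keeps the scheduler's capacity estimate intact: at the prompt stage only one sequence exists, so the group reserves `n` slots (which now includes any `best_of` inflation); during generation it reports the sequences still unfinished. A toy sketch of that branch structure, not vLLM's actual SequenceGroup:

from dataclasses import dataclass, field
from typing import List

@dataclass
class ToyParams:      # stand-in: n already includes any best_of inflation
    n: int

@dataclass
class ToyGroup:       # stand-in mirroring the branch structure above
    sampling_params: ToyParams
    running: List[str] = field(default_factory=lambda: ["seq0"])

    def num_seqs(self) -> int:
        return len(self.running)

    def num_unfinished_seqs(self) -> int:
        return len(self.running)

    def get_max_num_running_seqs(self) -> int:
        n = self.sampling_params.n
        if n > self.num_seqs():
            return n                       # prompt stage: will fan out to n
        return self.num_unfinished_seqs()  # generation stage

group = ToyGroup(ToyParams(n=4))
assert group.get_max_num_running_seqs() == 4       # prompt: one seq, reserves 4
group.running = ["seq0", "seq1", "seq2", "seq3"]
assert group.get_max_num_running_seqs() == 4       # generation: 4 unfinished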

View File

@@ -96,7 +96,6 @@ class SpanAttributes(BaseSpanAttributes):
     # The following span attribute names are added here because they are missing
     # from the Semantic Conventions for LLM.
     LLM_REQUEST_ID = "gen_ai.request.id"
-    LLM_REQUEST_BEST_OF = "gen_ai.request.best_of"
     LLM_REQUEST_N = "gen_ai.request.n"
     LLM_USAGE_NUM_SEQUENCES = "gen_ai.usage.num_sequences"
     LLM_LATENCY_TIME_IN_QUEUE = "gen_ai.latency.time_in_queue"

View File

@@ -49,7 +49,7 @@ class ModelInputForTPU(ModelRunnerInputBase):
     t: torch.Tensor
     p: torch.Tensor
     num_samples: int
-    best_of: List[int]
+    n: List[int]
     seq_groups: List[List[int]]
     is_first_multi_step: bool = True
     is_last_step: bool = True
@@ -65,7 +65,7 @@ class ModelInputForTPU(ModelRunnerInputBase):
             "t": self.t,
             "p": self.p,
             "num_samples": self.num_samples,
-            "best_of": self.best_of,
+            "n": self.n,
             "seq_groups": self.seq_groups,
             "is_first_multi_step": self.is_first_multi_step,
             "is_last_step": self.is_last_step,
@@ -435,7 +435,7 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):
         assert len(seq_group_metadata_list) > 0
         t = []
         p = []
-        best_of = []
+        n = []
         for seq_group_metadata in seq_group_metadata_list:
             sampling_params = seq_group_metadata.sampling_params
             t.append(sampling_params.temperature)
@@ -448,11 +448,11 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):
                 raise NotImplementedError(
                     "Top-k sampling is currently disabled for the TPU backend "
                     "due to performance issues.")
-            if sampling_params.best_of > _MAX_NUM_SAMPLES:
+            if sampling_params.n > _MAX_NUM_SAMPLES:
                 raise NotImplementedError(
                     f"Best of > {_MAX_NUM_SAMPLES} is not supported by the TPU "
                     "backend.")
-            best_of.append(sampling_params.best_of)
+            n.append(sampling_params.n)
             if sampling_params.logprobs is not None:
                 raise NotImplementedError(
                     "logprobs is not currently supported by the TPU backend.")
@@ -465,7 +465,7 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):
             num_seqs = len(seq_group_metadata.seq_data)
             t += [t[-1]] * (num_seqs - 1)
             p += [p[-1]] * (num_seqs - 1)
-            best_of += [best_of[-1]] * (num_seqs - 1)
+            n += [n[-1]] * (num_seqs - 1)

         num_paddings = padded_batch_size - len(t)
         t += [1.0] * num_paddings
@@ -473,7 +473,7 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):
         t = torch.tensor(t, dtype=torch.float32, device="cpu")
         p = torch.tensor(p, dtype=torch.float32, device="cpu")
-        return t, p, best_of
+        return t, p, n

     def prepare_model_input(
         self,
@@ -493,7 +493,7 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):
             inputs = self._prepare_decode(seq_group_metadata_list)
         input_tokens, input_positions, attn_metadata, input_lens = inputs
         padded_batch_size = input_tokens.shape[0]
-        t, p, best_of = self._prepare_sample(seq_group_metadata_list,
-                                             padded_batch_size)
+        t, p, n = self._prepare_sample(seq_group_metadata_list,
+                                       padded_batch_size)
         num_samples = _MAX_NUM_SAMPLES if is_prompt else 1
@@ -502,8 +502,7 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):
             for metadata in seq_group_metadata_list
         ]
         return ModelInputForTPU(input_tokens, input_positions, attn_metadata,
-                                input_lens, t, p, num_samples, best_of,
-                                seq_groups)
+                                input_lens, t, p, num_samples, n, seq_groups)

     def make_model_input_from_broadcasted_tensor_dict(
         self, tensor_dict: Dict[str, Any]) -> ModelInputForTPU:
@@ -609,7 +608,7 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):
             assert len(seq_ids) == 1
             seq_id = seq_ids[0]
             seq_outputs = []
-            for j in range(model_input.best_of[i]):
+            for j in range(model_input.n[i]):
                 next_token_id = next_token_ids[i][j]
                 seq_outputs.append(
                     SequenceOutput(seq_id, next_token_id,