[misc] hide best_of from engine (#9261)
Co-authored-by: Brendan Wong <bjwpokemon@gmail.com>
parent 94bf9ae4e9
commit cbc2ef5529
@@ -70,7 +70,6 @@ EXPECTED_VALUES = {
         [("_sum", _NUM_REQUESTS * _NUM_GENERATION_TOKENS_PER_REQUEST),
          ("_count", _NUM_REQUESTS)],
     "vllm:request_params_n": [("_count", _NUM_REQUESTS)],
-    "vllm:request_params_best_of": [("_count", _NUM_REQUESTS)],
     "vllm:prompt_tokens": [("_total",
                             _NUM_REQUESTS * _NUM_PROMPT_TOKENS_PER_REQUEST)],
     "vllm:generation_tokens":
@@ -151,9 +150,6 @@ EXPECTED_METRICS = [
     "vllm:request_params_n_sum",
     "vllm:request_params_n_bucket",
     "vllm:request_params_n_count",
-    "vllm:request_params_best_of_sum",
-    "vllm:request_params_best_of_bucket",
-    "vllm:request_params_best_of_count",
     "vllm:num_preemptions_total",
     "vllm:prompt_tokens_total",
     "vllm:generation_tokens_total",
@@ -326,7 +326,6 @@ def assert_metrics(engine: LLMEngine, disable_log_stats: bool,
         "vllm:e2e_request_latency_seconds",
         "vllm:request_prompt_tokens",
         "vllm:request_generation_tokens",
-        "vllm:request_params_best_of",
         "vllm:request_params_n",
     ]
     for metric_name in request_histogram_metrics:
@@ -98,8 +98,6 @@ def test_traces(trace_service):
         SpanAttributes.LLM_REQUEST_TOP_P) == sampling_params.top_p
     assert attributes.get(
         SpanAttributes.LLM_REQUEST_MAX_TOKENS) == sampling_params.max_tokens
-    assert attributes.get(
-        SpanAttributes.LLM_REQUEST_BEST_OF) == sampling_params.best_of
     assert attributes.get(SpanAttributes.LLM_REQUEST_N) == sampling_params.n
     assert attributes.get(SpanAttributes.LLM_USAGE_PROMPT_TOKENS) == len(
         outputs[0].prompt_token_ids)
@@ -155,8 +153,6 @@ def test_traces_with_detailed_steps(trace_service):
         SpanAttributes.LLM_REQUEST_TOP_P) == sampling_params.top_p
     assert attributes.get(
         SpanAttributes.LLM_REQUEST_MAX_TOKENS) == sampling_params.max_tokens
-    assert attributes.get(
-        SpanAttributes.LLM_REQUEST_BEST_OF) == sampling_params.best_of
     assert attributes.get(SpanAttributes.LLM_REQUEST_N) == sampling_params.n
     assert attributes.get(SpanAttributes.LLM_USAGE_PROMPT_TOKENS) == len(
         outputs[0].prompt_token_ids)
@@ -1205,7 +1205,7 @@ class Scheduler:
         # async_output_proc is allowed only when we have a single sequence
         # in the sequence group
         no_single_seq = seq_group.sampling_params is None or (
-            seq_group.sampling_params.best_of == 1)
+            seq_group.sampling_params.n == 1)
         return no_single_seq

     def schedule(
@@ -767,7 +767,7 @@ class LLMEngine:
         Details:
             - Set arrival_time to the current time if it is None.
             - Set prompt_token_ids to the encoded prompt if it is None.
-            - Create `best_of` number of :class:`~vllm.Sequence` objects.
+            - Create `n` number of :class:`~vllm.Sequence` objects.
             - Create a :class:`~vllm.SequenceGroup` object
              from the list of :class:`~vllm.Sequence`.
             - Add the :class:`~vllm.SequenceGroup` object to the scheduler.
@@ -1242,8 +1242,7 @@ class LLMEngine:
             if seq_group_metadata.do_sample:
                 assert len(sequence_group_outputs.samples) == 1, (
                     "Async output processor expects a single sample"
-                    " (i.e sampling_params.n == 1 and no "
-                    "sampling_params.best_of > 1)")
+                    " (i.e sampling_params.n == 1)")
                 sample = sequence_group_outputs.samples[0]

                 assert len(seq_group.seqs) == 1
@@ -1612,7 +1611,6 @@ class LLMEngine:
         # Metadata
         num_prompt_tokens_requests: List[int] = []
         num_generation_tokens_requests: List[int] = []
-        best_of_requests: List[int] = []
         n_requests: List[int] = []
        finished_reason_requests: List[str] = []

@@ -1683,8 +1681,6 @@ class LLMEngine:
                     for seq in seq_group.get_finished_seqs()
                 ])
                 if seq_group.sampling_params is not None:
-                    best_of_requests.append(
-                        seq_group.sampling_params.best_of)
                     n_requests.append(seq_group.sampling_params.n)
                 finished_reason_requests.extend([
                     SequenceStatus.get_finished_reason(seq.status)
@@ -1737,7 +1733,6 @@ class LLMEngine:
             # Metadata
             num_prompt_tokens_requests=num_prompt_tokens_requests,
             num_generation_tokens_requests=num_generation_tokens_requests,
-            best_of_requests=best_of_requests,
             n_requests=n_requests,
             finished_reason_requests=finished_reason_requests,
         )
@@ -1824,8 +1819,6 @@ class LLMEngine:
                                    seq_group.sampling_params.top_p)
             seq_span.set_attribute(SpanAttributes.LLM_REQUEST_MAX_TOKENS,
                                    seq_group.sampling_params.max_tokens)
-            seq_span.set_attribute(SpanAttributes.LLM_REQUEST_BEST_OF,
-                                   seq_group.sampling_params.best_of)
             seq_span.set_attribute(SpanAttributes.LLM_REQUEST_N,
                                    seq_group.sampling_params.n)
             seq_span.set_attribute(SpanAttributes.LLM_USAGE_NUM_SEQUENCES,
@@ -134,12 +134,6 @@ class Metrics:
             labelnames=labelnames,
             buckets=build_1_2_5_buckets(max_model_len),
         )
-        self.histogram_best_of_request = self._histogram_cls(
-            name="vllm:request_params_best_of",
-            documentation="Histogram of the best_of request parameter.",
-            labelnames=labelnames,
-            buckets=[1, 2, 5, 10, 20],
-        )
         self.histogram_n_request = self._histogram_cls(
             name="vllm:request_params_n",
             documentation="Histogram of the n request parameter.",
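For context, a rough sketch of the histogram that survives this change, written against `prometheus_client` directly rather than vLLM's `_histogram_cls` wrapper; the `model_name` label and the `[1, 2, 5, 10, 20]` buckets for `vllm:request_params_n` are assumptions, since the hunk only shows those details for the removed `best_of` histogram.

```python
from prometheus_client import Histogram

# Assumed label set; the hunk does not show what `labelnames` contains.
labelnames = ["model_name"]

# The n-parameter histogram that remains after vllm:request_params_best_of
# is dropped. Buckets assumed to mirror the removed best_of histogram.
histogram_n_request = Histogram(
    name="vllm:request_params_n",
    documentation="Histogram of the n request parameter.",
    labelnames=labelnames,
    buckets=[1, 2, 5, 10, 20],
)

# Recording a request that asked for n=2 (best_of is no longer exported).
histogram_n_request.labels(model_name="example-model").observe(2)
```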
@@ -473,8 +467,6 @@ class PrometheusStatLogger(StatLoggerBase):
             self.metrics.histogram_num_generation_tokens_request,
             stats.num_generation_tokens_requests)
         self._log_histogram(self.metrics.histogram_n_request, stats.n_requests)
-        self._log_histogram(self.metrics.histogram_best_of_request,
-                            stats.best_of_requests)

     def _log_prometheus_interval(self, prompt_throughput: float,
                                  generation_throughput: float) -> None:
@@ -49,7 +49,6 @@ class Stats:
     # Metadata
     num_prompt_tokens_requests: List[int]
     num_generation_tokens_requests: List[int]
-    best_of_requests: List[int]
     n_requests: List[int]
     finished_reason_requests: List[str]

@@ -112,7 +112,7 @@ class SingleStepOutputProcessor(SequenceGroupOutputProcessor):
                         outputs: SequenceGroupOutput,
                         is_async: bool) -> None:
         sampling_params = seq_group.sampling_params
-        if sampling_params.best_of == 1:
+        if sampling_params.n == 1:
             # only have one output sample
             sample = outputs.samples[0]
             # only have one sequence
@@ -508,7 +508,7 @@ def _random_sample(
         same as the length of selected_seq_groups. If the corresponding
         seq_group has do_sample=False, tuple contains ([], [])
     """
-    # Find the maximum best_of value of the prompt phase requests.
+    # Find the maximum n value of the prompt phase requests.
     random_samples = random_samples.cpu()
     sample_idx = 0
     results: SampleResultType = []
@@ -523,9 +523,9 @@ def _random_sample(
         num_parent_seqs = len(seq_ids)
         if is_prompt:
             # Prompt phase.
-            parent_ids = [0] * sampling_params.best_of
+            parent_ids = [0] * sampling_params.n
             next_token_ids = random_samples[
-                sample_idx, :sampling_params.best_of].tolist()
+                sample_idx, :sampling_params.n].tolist()
         else:
             # Generation phase.
             parent_ids = list(range(num_parent_seqs))
@@ -570,7 +570,7 @@ def _beam_search_sample(
         is_prompt = seq_group.is_prompt
         seq_ids, sampling_params = seq_group.seq_ids, seq_group.sampling_params
         num_parent_seqs = len(seq_ids)
-        beam_width = sampling_params.best_of
+        beam_width = sampling_params.n
         seq_group_logprobs = logprobs[sample_idx:sample_idx + num_parent_seqs]
         if is_prompt:
             # Prompt phase.
@@ -797,12 +797,11 @@ def _sample_with_torch(
                                                  greedy_samples)

         elif sampling_type in (SamplingType.RANDOM, SamplingType.RANDOM_SEED):
-            max_best_of_in_batch = 1
+            max_n_in_batch = 1
             for seq_group in seq_groups:
                 if seq_group.is_prompt:
                     sampling_params = seq_group.sampling_params
-                    max_best_of_in_batch = max(max_best_of_in_batch,
-                                               sampling_params.best_of)
+                    max_n_in_batch = max(max_n_in_batch, sampling_params.n)
             seq_groups_arg = (None if sampling_type == SamplingType.RANDOM else
                               seq_groups)

@@ -812,13 +811,13 @@ def _sample_with_torch(
                     probs[long_sample_indices],
                     sampling_tensors.top_ks[long_sample_indices],
                     sampling_tensors.top_ps[long_sample_indices],
-                    max_best_of_in_batch,
+                    max_n_in_batch,
                     seq_groups_arg,
                 )
             else:
                 multinomial_samples[sampling_type] = _multinomial(
                     probs[long_sample_indices],
-                    max_best_of_in_batch,
+                    max_n_in_batch,
                     seq_groups=seq_groups_arg)

             if sampled_token_ids_tensor is not None:
@@ -141,7 +141,7 @@ class RequestOutput:
             top_n_seqs = seqs
         else:
             # Get the top-n sequences.
-            n = sampling_params.n
+            n = sampling_params._real_n or sampling_params.n
             sorting_key = lambda seq: seq.get_cumulative_logprob()
             sorted_seqs = sorted(seqs, key=sorting_key, reverse=True)
             top_n_seqs = sorted_seqs[:n]
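To make the `_real_n` fallback above concrete: when `best_of` inflated `n`, the group finishes with `n` sequences, but only the caller's original `_real_n` best ones should be returned. A toy sketch of that selection, with made-up sequence and logprob values (not the actual `RequestOutput` code):

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class FakeSeq:
    text: str
    cumulative_logprob: float


def top_n_seqs(seqs: List[FakeSeq], n: int,
               real_n: Optional[int]) -> List[FakeSeq]:
    # Mirrors the hunk: prefer the caller's original n (_real_n) when
    # best_of inflated n, otherwise fall back to n itself.
    keep = real_n or n
    ranked = sorted(seqs, key=lambda s: s.cumulative_logprob, reverse=True)
    return ranked[:keep]


# n was folded to 4 (best_of), _real_n preserved the requested 2.
seqs = [FakeSeq("a", -1.0), FakeSeq("b", -0.2),
        FakeSeq("c", -3.5), FakeSeq("d", -0.9)]
print([s.text for s in top_n_seqs(seqs, n=4, real_n=2)])  # ['b', 'd']
```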
@@ -106,9 +106,8 @@ class SamplingParams(
         n: Number of output sequences to return for the given prompt.
         best_of: Number of output sequences that are generated from the prompt.
             From these `best_of` sequences, the top `n` sequences are returned.
-            `best_of` must be greater than or equal to `n`. This is treated as
-            the beam width when `use_beam_search` is True. By default, `best_of`
-            is set to `n`.
+            `best_of` must be greater than or equal to `n`. By default,
+            `best_of` is set to `n`.
         presence_penalty: Float that penalizes new tokens based on whether they
             appear in the generated text so far. Values > 0 encourage the model
             to use new tokens, while values < 0 encourage the model to repeat
@@ -173,6 +172,7 @@ class SamplingParams(

     n: int = 1
     best_of: Optional[int] = None
+    _real_n: Optional[int] = None
     presence_penalty: float = 0.0
     frequency_penalty: float = 0.0
     repetition_penalty: float = 1.0
@@ -282,7 +282,19 @@ class SamplingParams(
         )

     def __post_init__(self) -> None:
-        self.best_of = self.best_of or self.n
+        # how we deal with `best_of``:
+        # if `best_of`` is not set, we default to `n`;
+        # if `best_of`` is set, we set `n`` to `best_of`,
+        # and set `_real_n`` to the original `n`.
+        # when we return the result, we will check
+        # if we need to return `n` or `_real_n` results
+        if self.best_of:
+            if self.best_of < self.n:
+                raise ValueError(
+                    f"best_of must be greater than or equal to n, "
+                    f"got n={self.n} and best_of={self.best_of}.")
+            self._real_n = self.n
+            self.n = self.best_of
         if 0 < self.temperature < _MAX_TEMP:
             logger.warning(
                 "temperature %s is less than %s, which may cause numerical "
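A minimal usage sketch of the folding implemented above, from the caller's side; the concrete numbers are illustrative only:

```python
from vllm import SamplingParams

# best_of left unset: nothing is folded and _real_n stays None.
plain = SamplingParams(n=2)
assert plain.n == 2 and plain._real_n is None

# best_of set: the engine now only ever sees n == best_of, while the
# caller's original n is parked in _real_n for trimming the output later.
folded = SamplingParams(n=2, best_of=4)
assert folded.n == 4        # the engine generates 4 sequences
assert folded._real_n == 2  # the caller still gets the top 2 back

# best_of smaller than n is rejected directly in __post_init__.
try:
    SamplingParams(n=4, best_of=2)
except ValueError as err:
    print(err)  # best_of must be greater than or equal to n, ...
```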
@@ -329,12 +341,6 @@ class SamplingParams(
                              f"type {type(self.n)}")
         if self.n < 1:
             raise ValueError(f"n must be at least 1, got {self.n}.")
-        if not isinstance(self.best_of, int):
-            raise ValueError(f"best_of must be an int, but is of "
-                             f"type {type(self.best_of)}")
-        if self.best_of < self.n:
-            raise ValueError(f"best_of must be greater than or equal to n, "
-                             f"got n={self.n} and best_of={self.best_of}.")
         if not -2.0 <= self.presence_penalty <= 2.0:
             raise ValueError("presence_penalty must be in [-2, 2], got "
                              f"{self.presence_penalty}.")
@@ -385,7 +391,7 @@ class SamplingParams(
             raise ValueError(
                 "stop strings are only supported when detokenize is True. "
                 "Set detokenize=True to use stop.")
-        if self.best_of != self.n and self.output_kind == (
+        if self.best_of != self._real_n and self.output_kind == (
                 RequestOutputKind.DELTA):
             raise ValueError("best_of must equal n to use output_kind=DELTA")

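After the folding, `self.n` already equals `best_of`, so the DELTA check above has to compare against `_real_n` to keep its original meaning ("best_of must equal the n the caller asked for"). A hedged illustration, assuming `RequestOutputKind` is importable from `vllm.sampling_params`:

```python
from vllm.sampling_params import RequestOutputKind, SamplingParams

# Allowed: best_of equals the originally requested n (_real_n == 2).
SamplingParams(n=2, best_of=2, output_kind=RequestOutputKind.DELTA)

# Rejected: best_of (4) differs from the original n (2); DELTA streaming
# cannot drop the extra sequences incrementally.
try:
    SamplingParams(n=2, best_of=4, output_kind=RequestOutputKind.DELTA)
except ValueError as err:
    print(err)  # best_of must equal n to use output_kind=DELTA
```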
@@ -393,10 +399,6 @@ class SamplingParams(
         if self.n > 1:
             raise ValueError("n must be 1 when using greedy sampling, "
                              f"got {self.n}.")
-        assert isinstance(self.best_of, int)
-        if self.best_of > 1:
-            raise ValueError("best_of must be 1 when using greedy sampling, "
-                             f"got {self.best_of}.")

     def update_from_generation_config(
         self,
@@ -453,7 +455,6 @@ class SamplingParams(
     def __repr__(self) -> str:
         return (
             f"SamplingParams(n={self.n}, "
-            f"best_of={self.best_of}, "
             f"presence_penalty={self.presence_penalty}, "
             f"frequency_penalty={self.frequency_penalty}, "
             f"repetition_penalty={self.repetition_penalty}, "
@@ -803,14 +803,14 @@ class SequenceGroup:
         """The maximum number of sequences running in parallel in the remaining
         lifetime of the request."""
         if self.sampling_params:
-            best_of = self.sampling_params.best_of
-            assert isinstance(best_of, int)
-            if best_of > self.num_seqs():
+            n = self.sampling_params.n
+            assert isinstance(n, int)
+            if n > self.num_seqs():
                 # At prompt stage, the sequence group is not yet filled up
                 # and only have one sequence running. However, in the
-                # generation stage, we will have `best_of` sequences
+                # generation stage, we will have `n` sequences
                 # running.
-                return best_of
+                return n
         # At sampling stages, return the number of actual sequences
         # that are not finished yet.
         return self.num_unfinished_seqs()
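The prompt-vs-decode reasoning in this hunk can be summarised with a small stand-alone helper (a sketch only; the real logic is the `SequenceGroup` method shown above):

```python
def max_num_running_seqs(n: int, num_seqs: int, num_unfinished: int) -> int:
    # Prompt stage: the group still holds a single sequence but will fan
    # out to n parallel sequences once generation starts.
    if n > num_seqs:
        return n
    # Generation stage: the fan-out already happened, so only sequences
    # that have not finished yet still count.
    return num_unfinished


print(max_num_running_seqs(n=4, num_seqs=1, num_unfinished=1))  # 4, prompt stage
print(max_num_running_seqs(n=4, num_seqs=4, num_unfinished=3))  # 3, decode stage
```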
@@ -96,7 +96,6 @@ class SpanAttributes(BaseSpanAttributes):
     # The following span attribute names are added here because they are missing
     # from the Semantic Conventions for LLM.
     LLM_REQUEST_ID = "gen_ai.request.id"
-    LLM_REQUEST_BEST_OF = "gen_ai.request.best_of"
     LLM_REQUEST_N = "gen_ai.request.n"
     LLM_USAGE_NUM_SEQUENCES = "gen_ai.usage.num_sequences"
     LLM_LATENCY_TIME_IN_QUEUE = "gen_ai.latency.time_in_queue"
@@ -49,7 +49,7 @@ class ModelInputForTPU(ModelRunnerInputBase):
     t: torch.Tensor
     p: torch.Tensor
     num_samples: int
-    best_of: List[int]
+    n: List[int]
     seq_groups: List[List[int]]
     is_first_multi_step: bool = True
     is_last_step: bool = True
@@ -65,7 +65,7 @@ class ModelInputForTPU(ModelRunnerInputBase):
             "t": self.t,
             "p": self.p,
             "num_samples": self.num_samples,
-            "best_of": self.best_of,
+            "n": self.n,
             "seq_groups": self.seq_groups,
             "is_first_multi_step": self.is_first_multi_step,
             "is_last_step": self.is_last_step,
@@ -435,7 +435,7 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):
         assert len(seq_group_metadata_list) > 0
         t = []
         p = []
-        best_of = []
+        n = []
         for seq_group_metadata in seq_group_metadata_list:
             sampling_params = seq_group_metadata.sampling_params
             t.append(sampling_params.temperature)
@@ -448,11 +448,11 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):
                 raise NotImplementedError(
                     "Top-k sampling is currently disabled for the TPU backend "
                     "due to performance issues.")
-            if sampling_params.best_of > _MAX_NUM_SAMPLES:
+            if sampling_params.n > _MAX_NUM_SAMPLES:
                 raise NotImplementedError(
                     f"Best of > {_MAX_NUM_SAMPLES} is not supported by the TPU "
                     "backend.")
-            best_of.append(sampling_params.best_of)
+            n.append(sampling_params.n)
             if sampling_params.logprobs is not None:
                 raise NotImplementedError(
                     "logprobs is not currently supported by the TPU backend.")
@@ -465,7 +465,7 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):
             num_seqs = len(seq_group_metadata.seq_data)
             t += [t[-1]] * (num_seqs - 1)
             p += [p[-1]] * (num_seqs - 1)
-            best_of += [best_of[-1]] * (num_seqs - 1)
+            n += [n[-1]] * (num_seqs - 1)

         num_paddings = padded_batch_size - len(t)
         t += [1.0] * num_paddings
@@ -473,7 +473,7 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):

         t = torch.tensor(t, dtype=torch.float32, device="cpu")
         p = torch.tensor(p, dtype=torch.float32, device="cpu")
-        return t, p, best_of
+        return t, p, n

     def prepare_model_input(
         self,
@@ -493,7 +493,7 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):
             inputs = self._prepare_decode(seq_group_metadata_list)
         input_tokens, input_positions, attn_metadata, input_lens = inputs
         padded_batch_size = input_tokens.shape[0]
-        t, p, best_of = self._prepare_sample(seq_group_metadata_list,
+        t, p, n = self._prepare_sample(seq_group_metadata_list,
                                        padded_batch_size)
         num_samples = _MAX_NUM_SAMPLES if is_prompt else 1

@@ -502,8 +502,7 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):
             for metadata in seq_group_metadata_list
         ]
         return ModelInputForTPU(input_tokens, input_positions, attn_metadata,
-                                input_lens, t, p, num_samples, best_of,
-                                seq_groups)
+                                input_lens, t, p, num_samples, n, seq_groups)

     def make_model_input_from_broadcasted_tensor_dict(
             self, tensor_dict: Dict[str, Any]) -> ModelInputForTPU:
@@ -609,7 +608,7 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):
             assert len(seq_ids) == 1
             seq_id = seq_ids[0]
             seq_outputs = []
-            for j in range(model_input.best_of[i]):
+            for j in range(model_input.n[i]):
                 next_token_id = next_token_ids[i][j]
                 seq_outputs.append(
                     SequenceOutput(seq_id, next_token_id,