[V1][TPU] Do not compile sampling more than needed (#15883)
Signed-off-by: NickLucche <nlucches@redhat.com>
This commit is contained in:
parent
01b6113659
commit
bd7599d34a
@ -862,7 +862,9 @@ class TPUModelRunner:
|
|||||||
out = self.model.sample_from_hidden(dummy_hidden,
|
out = self.model.sample_from_hidden(dummy_hidden,
|
||||||
sampling_meta)
|
sampling_meta)
|
||||||
out = out.cpu()
|
out = out.cpu()
|
||||||
if num_reqs_to_sample >= self.max_num_reqs:
|
# Requests can't be more than tokens. But do compile for the
|
||||||
|
# next bigger value in case num_tokens uses bucketed padding.
|
||||||
|
if num_reqs_to_sample >= min(num_tokens, self.max_num_reqs):
|
||||||
break
|
break
|
||||||
# Make sure to compile the `max_num_reqs` upper-limit case
|
# Make sure to compile the `max_num_reqs` upper-limit case
|
||||||
num_reqs_to_sample = _get_padded_num_reqs_with_upper_limit(
|
num_reqs_to_sample = _get_padded_num_reqs_with_upper_limit(
|
||||||
|
Loading…
x
Reference in New Issue
Block a user