[V1][TPU] Do not compile sampling more than needed (#15883)

Signed-off-by: NickLucche <nlucches@redhat.com>
This commit is contained in:
Nicolò Lucchesi 2025-04-03 03:36:01 +02:00 committed by GitHub
parent 01b6113659
commit bd7599d34a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -862,7 +862,9 @@ class TPUModelRunner:
out = self.model.sample_from_hidden(dummy_hidden,
sampling_meta)
out = out.cpu()
if num_reqs_to_sample >= self.max_num_reqs:
# Requests can't be more than tokens. But do compile for the
# next bigger value in case num_tokens uses bucketed padding.
if num_reqs_to_sample >= min(num_tokens, self.max_num_reqs):
break
# Make sure to compile the `max_num_reqs` upper-limit case
num_reqs_to_sample = _get_padded_num_reqs_with_upper_limit(