[TPU][V1] Fix exponential padding when max-num-batched-tokens is not a power of 2 (#16596)
Signed-off-by: NickLucche <nlucches@redhat.com>
parent aa29841ede
commit b3f2fddd17
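With padding_gap == 0, the TPU model runner builds its token-padding schedule by doubling from min_token_size. The old loop condition (while num <= max_token_size) stopped at the largest power of two not exceeding the limit, so a max-num-batched-tokens value that is not a power of two (e.g. 317) left the largest bucket (256) below the limit. The loop now keeps doubling until the value just appended is at least max_token_size, so the schedule always ends with a bucket that can hold a full batch.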
@@ -299,6 +299,18 @@ def test_get_paddings():
     actual_paddings = _get_token_paddings(min_token_size, max_token_size,
                                           padding_gap)
     assert actual_paddings == expected_paddings
+    # Exponential padding.
+    max_token_size, padding_gap = 1024, 0
+    expected_paddings = [16, 32, 64, 128, 256, 512, 1024]
+    actual_paddings = _get_token_paddings(min_token_size, max_token_size,
+                                          padding_gap)
+    assert actual_paddings == expected_paddings
+    # Exponential padding with max_token_size not a power of two.
+    max_token_size = 317
+    expected_paddings = [16, 32, 64, 128, 256, 512]
+    actual_paddings = _get_token_paddings(min_token_size, max_token_size,
+                                          padding_gap)
+    assert actual_paddings == expected_paddings


 def test_get_padded_token_len():
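The new cases cover both a power-of-two limit (1024, where the schedule ends exactly at 1024) and a non-power-of-two limit (317, where the schedule ends at 512, the first power of two at or above the limit).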
@@ -1040,9 +1040,11 @@ def _get_token_paddings(min_token_size: int, max_token_size: int,

     if padding_gap == 0:
         logger.info("Using exponential token paddings:")
-        while num <= max_token_size:
+        while True:
             logger.info("    %d", num)
             paddings.append(num)
+            if num >= max_token_size:
+                break
             num *= 2

     else:
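For illustration, here is a minimal, self-contained sketch of the fixed exponential branch (not the vLLM source; the function name is made up for this note), showing why a non-power-of-two limit such as 317 now gets a 512 bucket:

def exponential_token_paddings(min_token_size: int, max_token_size: int) -> list[int]:
    # Double from min_token_size until the last bucket covers max_token_size.
    assert min_token_size & (min_token_size - 1) == 0, "min_token_size must be a power of two"
    paddings = []
    num = min_token_size
    while True:
        paddings.append(num)
        if num >= max_token_size:
            break  # the final bucket is the first power of two >= max_token_size
        num *= 2
    return paddings


print(exponential_token_paddings(16, 1024))  # [16, 32, 64, 128, 256, 512, 1024]
print(exponential_token_paddings(16, 317))   # [16, 32, 64, 128, 256, 512]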