[TPU][V1] Refine tpu_model_runner to mitigate future recompilation issues (#16275)

Signed-off-by: Chengji Yao <chengjiyao@google.com>
Author: Chengji Yao
Date: 2025-04-09 17:51:51 -07:00
Committed by: GitHub
parent 1bff42c4b7
commit a454748544
4 changed files with 166 additions and 125 deletions

View File

@@ -44,7 +44,7 @@ def test_tpu_compilation():
         assert generated_text.startswith(answer)
 
     compiled_codes = sorted(
-        glob.glob(os.path.join(temp_dir, "__transformed_code*.py")))
+        glob.glob(os.path.join(temp_dir, "__transformed_code*for_forward.py")))
 
     for i, compiled_code in enumerate(compiled_codes):
         print("{} file: {}".format(i + 1, compiled_code))
@@ -52,15 +52,21 @@ def test_tpu_compilation():
     # We should only trigger Dynamo compilation 2 times:
     # 1. Forward pass without kv_caches
    # 2. Forward pass with kv_caches
-    # Check we have 4 compiled codes
+    # Check we have 2 compiled codes
     assert len(compiled_codes) == 2
 
     kv_cache_prefix = "kv_cache"
     attn_prefix = "ragged_paged_attention"
 
+    def extract_compiled_index(s):
+        parts = s.replace(".", "_").split("_")
+        numbers = [int(part) for part in parts if part.isdigit()]
+        return numbers[0]
+
     # Check all the compilations are as expected
-    compiled_fns = sorted(
-        glob.glob(os.path.join(temp_dir, "__compiled_fn*Captured*.py")))
+    compiled_fns = sorted(glob.glob(
+        os.path.join(temp_dir, "__compiled_fn*Captured*.py")),
+                          key=lambda s: extract_compiled_index(s))
     for i, compiled_fn in enumerate(compiled_fns):
         print("{} file: {}".format(i + 1, compiled_fn))

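Note on the sorting change above: Dynamo names its debug artifacts with an unpadded integer index, so a plain lexicographic sort goes wrong once the index reaches double digits. A minimal, self-contained sketch of the problem and the fix (the file names below are illustrative, not the exact artifacts Dynamo writes):

names = [
    "__compiled_fn_11_Captured.py",
    "__compiled_fn_2_Captured.py",
    "__compiled_fn_1_Captured.py",
]

def extract_compiled_index(s):
    # Same idea as the helper added in the test: take the first integer
    # embedded in the underscore-separated file name.
    parts = s.replace(".", "_").split("_")
    numbers = [int(part) for part in parts if part.isdigit()]
    return numbers[0]

# Lexicographic order puts "...fn_11..." before "...fn_1..." and "...fn_2...":
assert sorted(names)[0] == "__compiled_fn_11_Captured.py"
# Sorting by the extracted index restores compilation order 1, 2, 11:
assert sorted(names, key=extract_compiled_index)[-1] == "__compiled_fn_11_Captured.py"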
View File

@@ -7,9 +7,9 @@ from vllm.config import CacheConfig, ModelConfig, SchedulerConfig, VllmConfig
 from vllm.sampling_params import SamplingParams
 from vllm.v1.core.sched.output import (CachedRequestData, NewRequestData,
                                        SchedulerOutput)
-from vllm.v1.worker.tpu_model_runner import (TPUModelRunner,
-                                             _get_padded_token_len,
-                                             _get_paddings)
+from vllm.v1.worker.tpu_model_runner import (
+    TPUModelRunner, _get_padded_num_reqs_with_upper_limit,
+    _get_padded_token_len, _get_req_paddings, _get_token_paddings)
 
 # Mock torch_xla module since it may not be available in the test environments
 torch_xla_patcher = mock.patch.dict(
@@ -296,16 +296,29 @@ def test_update_states_request_unscheduled(model_runner):
 def test_get_paddings():
     min_token_size, max_token_size, padding_gap = 16, 512, 64
     expected_paddings = [16, 32, 64, 128, 192, 256, 320, 384, 448, 512]
-    actual_paddings = _get_paddings(min_token_size, max_token_size,
-                                    padding_gap)
+    actual_paddings = _get_token_paddings(min_token_size, max_token_size,
+                                          padding_gap)
     assert actual_paddings == expected_paddings
 
 
 def test_get_padded_token_len():
     min_token_size, max_token_size, padding_gap = 16, 512, 64
-    paddings = _get_paddings(min_token_size, max_token_size, padding_gap)
+    paddings = _get_token_paddings(min_token_size, max_token_size, padding_gap)
     assert _get_padded_token_len(paddings, 1) == 16
     assert _get_padded_token_len(paddings, 16) == 16
     assert _get_padded_token_len(paddings, 20) == 32
     assert _get_padded_token_len(paddings, 300) == 320
     assert _get_padded_token_len(paddings, 512) == 512
+
+
+def test_get_padded_num_reqs_with_upper_limit():
+    assert _get_padded_num_reqs_with_upper_limit(3, 32) == 8
+    assert _get_padded_num_reqs_with_upper_limit(9, 32) == 16
+    assert _get_padded_num_reqs_with_upper_limit(19, 32) == 32
+    assert _get_padded_num_reqs_with_upper_limit(17, 28) == 28
+
+
+def test_get_req_paddings():
+    assert _get_req_paddings(1, 32) == [8, 16, 32]
+    assert _get_req_paddings(8, 32) == [8, 16, 32]
+    assert _get_req_paddings(8, 36) == [8, 16, 32, 36]

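The new request-padding tests pin down a simple rounding rule: round the request count up to the next power of two, never below MIN_NUM_SEQS, never above the configured maximum. A self-contained sketch of that rule, assuming MIN_NUM_SEQS == 8 (which is what the expected values above imply):

MIN_NUM_SEQS = 8  # assumed; the expected values above only make sense for 8

def _get_padded_num_reqs_with_upper_limit(x, upper_limit):
    # Round up to the next power of two, with a floor of MIN_NUM_SEQS and a
    # ceiling of upper_limit (mirrors the helper added in this commit).
    res = MIN_NUM_SEQS if x <= MIN_NUM_SEQS else 1 << (x - 1).bit_length()
    return min(res, upper_limit)

assert _get_padded_num_reqs_with_upper_limit(3, 32) == 8    # floor at MIN_NUM_SEQS
assert _get_padded_num_reqs_with_upper_limit(9, 32) == 16   # next power of two
assert _get_padded_num_reqs_with_upper_limit(17, 28) == 28  # capped at the limit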
View File

@@ -3,7 +3,6 @@ from dataclasses import dataclass, field
 from typing import Optional
 
 import torch
-import torch_xla.core.xla_model as xm
 
 from vllm.v1.worker.gpu_input_batch import InputBatch
@@ -24,15 +23,15 @@ class TPUSupportedSamplingMetadata:
     # This class exposes a more xla-friendly interface than SamplingMetadata
     # on TPU, in particular all arguments should be traceable and no optionals
     # are allowed, to avoid graph recompilation on Nones.
-    temperature: torch.Tensor
+    temperature: torch.Tensor = None
 
-    min_p: torch.Tensor
+    min_p: torch.Tensor = None
     # Still too slow on forward_native!
     top_k: torch.Tensor = None
     top_p: torch.Tensor = None
 
     # Greedy sampling flag for compiling single xla graph.
-    all_greedy: torch.Tensor = None
+    all_greedy: bool = True
 
     # Generator not supported by xla
     generators: dict[int,
@@ -57,64 +56,58 @@ class TPUSupportedSamplingMetadata:
     allowed_token_ids_mask = None
     bad_words_token_ids = None
-    indices_do_sample: torch.Tensor = None
 
     @classmethod
     def from_input_batch(
-            cls, input_batch: InputBatch,
-            indices_do_sample: torch.Tensor) -> "TPUSupportedSamplingMetadata":
+        cls,
+        input_batch: InputBatch,
+        padded_num_reqs: int,
+        xla_device: torch.device,
+        generate_params_if_all_greedy: bool = False
+    ) -> "TPUSupportedSamplingMetadata":
         """
         Copy sampling tensors slices from `input_batch` to on device tensors.
 
         `InputBatch._make_sampling_metadata` causes recompilation on XLA as it
         slices dynamic shapes on device tensors. This impl moves the dynamic
-        ops to CPU and produces tensors of fixed `padded_num_reqs` size. It
-        also reuses the on-device persistent tensors managed in `input_batch`
-        to reduce waste.
+        ops to CPU and produces tensors of fixed `padded_num_reqs` size.
 
-        `indices_do_sample` contains the indices to be fed to the Sampler,
-        normally one per request, here padded to the closest pre-compiled shape
-        We expect sampling params tensors to be padded to the same fixed shape.
-
-        Eg. 3 requests, tensors padded to 4
-        temperature: [0.7, 0.2, 0.9]=>[0.7, 0.2, 0.9, 0.0]
-        sample indices: [4, 10, 11]=>indices_do_sample: [4, 10, 11, 0]
+        Args:
+            input_batch: The input batch containing sampling parameters.
+            padded_num_reqs: The padded number of requests.
+            xla_device: The XLA device.
+            generate_params_if_all_greedy: If True, generate sampling parameters
+                even if all requests are greedy. this is useful for cases where
+                we want to pre-compile a graph with sampling parameters, even if
+                they are not strictly needed for greedy decoding.
         """
-        num_reqs = input_batch.num_reqs
-        padded_num_reqs = len(indices_do_sample)
+        # Early return to avoid unnecessary cpu to tpu copy
+        if (input_batch.all_greedy is True
+                and generate_params_if_all_greedy is False):
+            return cls(all_greedy=True)
 
-        def copy_slice(cpu_tensor: torch.Tensor, tpu_tensor: torch.Tensor,
-                       fill_val) -> torch.Tensor:
-            # Copy slice from CPU to corresponding TPU pre-allocated tensor.
+        num_reqs = input_batch.num_reqs
+
+        def fill_slice(cpu_tensor: torch.Tensor, fill_val) -> torch.Tensor:
             # Pad value is the default one.
             cpu_tensor[num_reqs:padded_num_reqs] = fill_val
-            # Subtle compilation: len(tpu_tensor) must be >= `padded_num_reqs`
-            tpu_tensor[:padded_num_reqs] = cpu_tensor[:padded_num_reqs]
 
-        # NOTE NickLucche The sync CPU-TPU graph we produce here must be
-        # consistent. We can't have flags to skip copies or we'll end up
-        # recompiling.
-        copy_slice(input_batch.temperature_cpu_tensor, input_batch.temperature,
+        fill_slice(input_batch.temperature_cpu_tensor,
                    DEFAULT_SAMPLING_PARAMS["temperature"])
         # TODO Temporarily disabled until sampling options are enabled
-        # copy_slice(input_batch.top_p_cpu_tensor, input_batch.top_p)
-        # copy_slice(input_batch.top_k_cpu_tensor, input_batch.top_k)
-        copy_slice(input_batch.min_p_cpu_tensor, input_batch.min_p,
+        # fill_slice(input_batch.top_p_cpu_tensor)
+        # fill_slice(input_batch.top_k_cpu_tensor)
+        fill_slice(input_batch.min_p_cpu_tensor,
                    DEFAULT_SAMPLING_PARAMS["min_p"])
 
-        xm.mark_step()
-        xm.wait_device_ops()
-
         # Slice persistent device tensors to a fixed pre-compiled padded shape.
         return cls(
-            temperature=input_batch.temperature[:padded_num_reqs],
-            # Scalar tensor for xla-friendly tracing.
-            all_greedy=torch.tensor(input_batch.all_greedy,
-                                    dtype=torch.bool,
-                                    device=input_batch.device),
+            temperature=input_batch.temperature_cpu_tensor[:padded_num_reqs].
+            to(xla_device),
+            all_greedy=input_batch.all_greedy,
             # TODO enable more and avoid returning None values
            top_p=None,  # input_batch.top_p[:padded_num_reqs],
            top_k=None,  # input_batch.top_k[:padded_num_reqs],
-            min_p=input_batch.min_p[:padded_num_reqs],
-            generators=input_batch.generators,
-            indices_do_sample=indices_do_sample)
+            min_p=input_batch.min_p_cpu_tensor[:padded_num_reqs].to(
+                xla_device),
+            generators=input_batch.generators)

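The behaviour worth calling out in the hunk above is the all-greedy fast path: when every request decodes greedily and generate_params_if_all_greedy is False, from_input_batch now skips the CPU-to-TPU copies entirely and returns a metadata object with all_greedy=True. A self-contained sketch of that control flow (FakeInputBatch and SketchSamplingMetadata are stand-ins invented for illustration, not vLLM classes):

from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class FakeInputBatch:  # stand-in for vllm.v1.worker.gpu_input_batch.InputBatch
    all_greedy: bool
    temperature_cpu_tensor: torch.Tensor
    num_reqs: int


@dataclass
class SketchSamplingMetadata:
    temperature: Optional[torch.Tensor] = None
    all_greedy: bool = True

    @classmethod
    def from_input_batch(cls, input_batch, padded_num_reqs, device,
                         generate_params_if_all_greedy=False):
        # Fast path: nothing to copy when every request decodes greedily.
        if input_batch.all_greedy and not generate_params_if_all_greedy:
            return cls(all_greedy=True)
        # Otherwise pad the CPU slice to the pre-compiled size, then copy once.
        cpu = input_batch.temperature_cpu_tensor
        cpu[input_batch.num_reqs:padded_num_reqs] = 1.0  # default fill value
        return cls(temperature=cpu[:padded_num_reqs].to(device),
                   all_greedy=False)


batch = FakeInputBatch(all_greedy=True,
                       temperature_cpu_tensor=torch.full((8, ), 0.7),
                       num_reqs=3)
meta = SketchSamplingMetadata.from_input_batch(batch, padded_num_reqs=8,
                                               device="cpu")
assert meta.all_greedy and meta.temperature is None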
View File

@@ -32,7 +32,7 @@ from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
 from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
                                         KVCacheSpec, SlidingWindowSpec)
 from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, LogprobsTensors,
-                             ModelRunnerOutput, SamplerOutput)
+                             ModelRunnerOutput)
 from vllm.v1.sample.tpu.metadata import TPUSupportedSamplingMetadata
 from vllm.v1.sample.tpu.sampler import Sampler as TPUSampler
 from vllm.v1.utils import bind_kv_cache
@@ -177,10 +177,12 @@ class TPUModelRunner:
         # Range tensor with values [0 .. self.max_num_tokens - 1].
         # Used to initialize positions / context_lens / seq_lens
         self.arange_np = np.arange(self.max_num_tokens, dtype=np.int32)
-        self.num_tokens_paddings = _get_paddings(
+        self.num_tokens_paddings = _get_token_paddings(
             min_token_size=16,
             max_token_size=self.max_num_tokens,
             padding_gap=envs.VLLM_TPU_BUCKET_PADDING_GAP)
+        self.num_reqs_paddings = _get_req_paddings(
+            min_req_size=MIN_NUM_SEQS, max_req_size=self.max_num_reqs)
 
     def _update_num_xla_graphs(self, case_str):
         check_comp = self.check_recompilation and not self.enforce_eager
@@ -508,7 +510,7 @@
         # Padded to avoid recompiling when `num_reqs` varies.
         logits_indices = self.query_start_loc_cpu[1:padded_num_reqs + 1] - 1
         logits_indices = logits_indices.to(self.device)
-        return attn_metadata, logits_indices
+        return attn_metadata, logits_indices, padded_num_reqs
 
     def _scatter_placeholders(
         self,
@@ -663,7 +665,8 @@
             mm_embeds = []
 
         # Prepare inputs
-        attn_metadata, logits_indices = self._prepare_inputs(scheduler_output)
+        attn_metadata, logits_indices, padded_num_reqs = self._prepare_inputs(
+            scheduler_output)
         if self.is_multimodal_model:
             # NOTE(woosuk): To unify token ids and soft tokens (vision
             # embeddings), we always use embeddings (rather than token ids)
@@ -682,11 +685,6 @@
             input_ids = self.input_ids
             inputs_embeds = None
         num_reqs = self.input_batch.num_reqs
-        # NOTE (NickLucche) here we sync with TPU: sampling params tensors
-        # are copied to device in chunks of pre-compiled padded shape to
-        # avoid recompilations.
-        tpu_sampling_metadata = TPUSupportedSamplingMetadata.\
-            from_input_batch(self.input_batch, logits_indices)
 
         # Run the decoder
         with set_forward_context(attn_metadata, self.vllm_config):
             hidden_states = self.model(
@@ -694,6 +692,10 @@
                 positions=self.position_ids,
                 inputs_embeds=inputs_embeds,
             )
+        hidden_states = self.select_hidden_states(hidden_states,
+                                                  logits_indices)
+        tpu_sampling_metadata = TPUSupportedSamplingMetadata.\
+            from_input_batch(self.input_batch, padded_num_reqs, self.device)
         selected_token_ids = self.sample_from_hidden(hidden_states,
                                                      tpu_sampling_metadata)
         # Remove padding on cpu and keep dynamic op outside of xla graph.
@@ -857,60 +859,78 @@
                           inputs_embeds=inputs_embeds)
         self._hidden_states_dtype = out.dtype
 
-    def capture_model(self) -> None:
-        """Compile the model."""
+    def _precompile_backbone(self) -> None:
         logger.info("Compiling the model with different input shapes.")
         start = time.perf_counter()
         for num_tokens in self.num_tokens_paddings:
             logger.info("  -- num_tokens: %d", num_tokens)
             self._dummy_run(num_tokens)
-            xm.mark_step()
         xm.wait_device_ops()
         end = time.perf_counter()
         logger.info("Compilation finished in in %.2f [secs].", end - start)
-        self._update_num_xla_graphs("model")
+        self._update_num_xla_graphs("model backbone")
+
+    def _precompile_select_hidden_states(self) -> None:
+        # Compile hidden state selection function for bucketed
+        # n_tokens x max_num_reqs. Graph is really small so this is fine.
+        logger.info(
+            "Compiling select_hidden_states with different input shapes.")
+        start = time.perf_counter()
+        hsize = self.model_config.get_hidden_size()
+        for num_tokens in self.num_tokens_paddings:
+            dummy_hidden = torch.zeros((num_tokens, hsize),
+                                       device=self.device,
+                                       dtype=self._hidden_states_dtype)
+            torch._dynamo.mark_dynamic(dummy_hidden, 0)
+            for num_reqs in self.num_reqs_paddings:
+                indices = torch.zeros(num_reqs,
+                                      dtype=torch.int32,
+                                      device=self.device)
+                torch._dynamo.mark_dynamic(indices, 0)
+                self.select_hidden_states(dummy_hidden, indices)
+            logger.info("  -- num_tokens: %d", num_tokens)
+        xm.wait_device_ops()
+        end = time.perf_counter()
+        logger.info("Compilation finished in in %.2f [secs].", end - start)
+        self._update_num_xla_graphs("select_hidden_states")
 
+    def _precompile_sample_from_hidden(self) -> None:
         logger.info("Compiling sampling with different input shapes.")
         start = time.perf_counter()
         hsize = self.model_config.get_hidden_size()
-        device = self.device
-        # Compile sampling step for different model+sampler outputs in bucketed
-        # n_tokens x max_num_reqs. Graph is really small so this is fine.
-        for num_tokens in self.num_tokens_paddings:
-            num_reqs_to_sample = MIN_NUM_SEQS
-            dummy_hidden = torch.randn((num_tokens, hsize),
-                                       device=device,
-                                       dtype=self._hidden_states_dtype)
-            # Compile for [8, 16, .., 128,.., `self.max_num_reqs`]
-            while True:
-                indices = torch.zeros(
-                    num_reqs_to_sample,
-                    dtype=torch.int32,
-                    device=device,
-                )
-                xm.mark_step()
-                sampling_meta = TPUSupportedSamplingMetadata.\
-                    from_input_batch(self.input_batch, indices)
-                logger.info("  -- num_tokens: %d, num_seqs: %d", num_tokens,
-                            num_reqs_to_sample)
-                out = self.sample_from_hidden(dummy_hidden, sampling_meta)
-                out = out.cpu()
-                # Requests can't be more than tokens. But do compile for the
-                # next bigger value in case num_tokens uses bucketed padding.
-                if num_reqs_to_sample >= min(num_tokens, self.max_num_reqs):
-                    break
-                # Make sure to compile the `max_num_reqs` upper-limit case
-                num_reqs_to_sample = _get_padded_num_reqs_with_upper_limit(
-                    num_reqs_to_sample + 1, self.max_num_reqs)
+        for num_reqs in self.num_reqs_paddings:
+            dummy_hidden = torch.zeros((num_reqs, hsize),
+                                       device=self.device,
+                                       dtype=self._hidden_states_dtype)
+            # The first dimension of dummy_hidden cannot be mark_dynamic because
+            # some operations in the sampler require it to be static.
+            for all_greedy in [False, True]:
+                generate_params_if_all_greedy = not all_greedy
+                sampling_metadata = (
+                    TPUSupportedSamplingMetadata.from_input_batch(
+                        self.input_batch,
+                        num_reqs,
+                        self.device,
+                        generate_params_if_all_greedy,
+                    ))
+                sampling_metadata.all_greedy = all_greedy
+                self.sample_from_hidden(dummy_hidden, sampling_metadata)
+            logger.info("  -- num_seqs: %d", num_reqs)
         xm.wait_device_ops()
         end = time.perf_counter()
         logger.info("Compilation finished in in %.2f [secs].", end - start)
         self._update_num_xla_graphs("sampling")
 
+    def capture_model(self) -> None:
+        """
+        Precompile all the subgraphs with possible input shapes.
+        """
+        # TODO: precompile encoder
+        self._precompile_backbone()
+        self._precompile_select_hidden_states()
+        self._precompile_sample_from_hidden()
+
     def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
         """
         Initialize KV cache based on `kv_cache_config`.
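
Aside on the bucketing behind the precompile loops above: every runtime shape is snapped to one of a fixed list of padded sizes, so XLA only ever sees shapes that were compiled ahead of time. A self-contained sketch, using the bucket values from the unit tests earlier in this commit and a bisect-based lookup that is assumed (not confirmed) to match _get_padded_token_len:

import bisect

# Token buckets produced by _get_token_paddings(16, 512, 64) according to the
# unit test expectations (min 16, max 512, gap 64).
num_tokens_paddings = [16, 32, 64, 128, 192, 256, 320, 384, 448, 512]

def padded_token_len(paddings, x):
    # Smallest bucket that can hold x tokens (assumed lookup rule; the
    # expected values in the tests behave this way).
    return paddings[bisect.bisect_left(paddings, x)]

assert padded_token_len(num_tokens_paddings, 20) == 32
assert padded_token_len(num_tokens_paddings, 300) == 320
assert padded_token_len(num_tokens_paddings, 512) == 512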
@@ -962,48 +982,55 @@
                                         compiled_model.original_code_object)
             compiled_model.compiled_codes.clear()
 
+    @torch.compile(backend="openxla", fullgraph=True, dynamic=False)
+    def select_hidden_states(self, hidden_states, indices_do_sample):
+        return hidden_states[indices_do_sample]
+
     @torch.compile(backend="openxla", fullgraph=True, dynamic=False)
     def sample_from_hidden(
         self,
-        hidden_states: torch.Tensor,
+        sample_hidden_states: torch.Tensor,
         sampling_metadata: TPUSupportedSamplingMetadata,
     ) -> torch.Tensor:
         """
         Sample with xla-friendly function. This function is to be traced
-        separately for lighter compilation overhead.
+        separately from `forward` for lighter compilation overhead.
         """
-        # Tensor `sample_hidden_states` is of fixed pre-compiled size.
-        sample_hidden_states = \
-            hidden_states[sampling_metadata.indices_do_sample]
-        # SamplingMetadata here for pruning output in LogitsProcessor, disabled.
         logits = self.model.compute_logits(sample_hidden_states, None)
-
-        def sample(
-            logits: torch.Tensor,
-            sampling_metadata: TPUSupportedSamplingMetadata
-        ) -> SamplerOutput:
-            sampler_out = self.sampler(logits, sampling_metadata)
-            return sampler_out
-
-        # Optimized greedy sampling branch, tracing both paths in a single pass
-        # NOTE all_greedy is a scalar, this is just an optimized if/else.
-        out_tokens = torch.where(
-            sampling_metadata.all_greedy,
-            torch.argmax(logits, dim=-1, keepdim=True),
-            sample(logits, sampling_metadata).sampled_token_ids)
+        if sampling_metadata.all_greedy:
+            out_tokens = torch.argmax(logits, dim=-1, keepdim=True)
+        else:
+            out_tokens = self.sampler(logits,
+                                      sampling_metadata).sampled_token_ids
         return out_tokens
 
+    def get_multimodal_embeddings(self, *args, **kwargs):
+        return self.model.get_multimodal_embeddings(*args, **kwargs)
+
+    def get_input_embeddings(self, *args, **kwargs):
+        return self.model.get_input_embeddings(*args, **kwargs)
+
 
-def _get_padded_number(n: int, multiple: int) -> int:
-    return ((n + multiple - 1) // multiple) * multiple
+def _get_req_paddings(min_req_size: int, max_req_size: int) -> list[int]:
+    logger.info("Preparing request paddings:")
+    # assert min_req_size is power of 2
+    assert (min_req_size & (min_req_size - 1) == 0) and min_req_size > 0
+    paddings: list = []
+    num = max(MIN_NUM_SEQS, min_req_size)
+    while num <= max_req_size and (len(paddings) == 0 or paddings[-1] != num):
+        paddings.append(num)
+        logger.info("    %d", num)
+        num = _get_padded_num_reqs_with_upper_limit(num + 1, max_req_size)
+    return paddings
 
 
-def _get_padded_num_reqs_with_upper_limit(x, upper_limit) -> int:
+def _get_padded_num_reqs_with_upper_limit(x: int, upper_limit: int) -> int:
     res = MIN_NUM_SEQS if x <= MIN_NUM_SEQS else 1 << (x - 1).bit_length()
     return min(res, upper_limit)
 
 
-def _get_paddings(min_token_size: int, max_token_size: int,
-                  padding_gap: int) -> list[int]:
+def _get_token_paddings(min_token_size: int, max_token_size: int,
+                        padding_gap: int) -> list[int]:
     """Generate a list of padding size, starting from min_token_size,
     ending with a number that can cover max_token_size
@@ -1013,18 +1040,20 @@ def _get_paddings(min_token_size: int, max_token_size: int,
     first increase the size to twice,
     then increase the padding size by padding_gap.
     """
+    # assert min_token_size is power of 2
+    assert (min_token_size & (min_token_size - 1) == 0) and min_token_size > 0
     paddings = []
     num = min_token_size
     if padding_gap == 0:
-        logger.info("Using exponential paddings:")
+        logger.info("Using exponential token paddings:")
         while num <= max_token_size:
             logger.info("    %d", num)
             paddings.append(num)
             num *= 2
     else:
-        logger.info("Using incremental paddings:")
+        logger.info("Using incremental token paddings:")
         while num <= padding_gap:
             logger.info("    %d", num)
             paddings.append(num)