[TPU][V1] Refine tpu_model_runner to mitigate future recompilation issues (#16275)
Signed-off-by: Chengji Yao <chengjiyao@google.com>
parent 1bff42c4b7
commit a454748544
@@ -44,7 +44,7 @@ def test_tpu_compilation():
     assert generated_text.startswith(answer)

     compiled_codes = sorted(
-        glob.glob(os.path.join(temp_dir, "__transformed_code*.py")))
+        glob.glob(os.path.join(temp_dir, "__transformed_code*for_forward.py")))

     for i, compiled_code in enumerate(compiled_codes):
         print("{} file: {}".format(i + 1, compiled_code))
@@ -52,15 +52,21 @@ def test_tpu_compilation():
     # We should only trigger Dynamo compilation 2 times:
     # 1. Forward pass without kv_caches
     # 2. Forward pass with kv_caches
-    # Check we have 4 compiled codes
+    # Check we have 2 compiled codes
     assert len(compiled_codes) == 2

     kv_cache_prefix = "kv_cache"
     attn_prefix = "ragged_paged_attention"

+    def extract_compiled_index(s):
+        parts = s.replace(".", "_").split("_")
+        numbers = [int(part) for part in parts if part.isdigit()]
+        return numbers[0]
+
     # Check all the compilations are as expected
-    compiled_fns = sorted(
-        glob.glob(os.path.join(temp_dir, "__compiled_fn*Captured*.py")))
+    compiled_fns = sorted(glob.glob(
+        os.path.join(temp_dir, "__compiled_fn*Captured*.py")),
+                          key=lambda s: extract_compiled_index(s))

     for i, compiled_fn in enumerate(compiled_fns):
         print("{} file: {}".format(i + 1, compiled_fn))
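
Note on the sorting change in this hunk: once Dynamo emits ten or more artifacts, a plain lexicographic sorted() no longer returns them in compilation order, which is why the test now sorts by the integer embedded in the file name. A self-contained illustration (the file names below are made up for the example, not taken from a real Dynamo dump):

    files = ["__compiled_fn_11_Captured.py",
             "__compiled_fn_1_Captured.py",
             "__compiled_fn_2_Captured.py"]

    def extract_compiled_index(s):
        # Same idea as the helper added in the test: first integer in the name.
        parts = s.replace(".", "_").split("_")
        return [int(p) for p in parts if p.isdigit()][0]

    print([extract_compiled_index(f) for f in sorted(files)])
    # [11, 1, 2]  -- lexicographic order breaks once indices hit two digits
    print([extract_compiled_index(f) for f in sorted(files, key=extract_compiled_index)])
    # [1, 2, 11]
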
@@ -7,9 +7,9 @@ from vllm.config import CacheConfig, ModelConfig, SchedulerConfig, VllmConfig
 from vllm.sampling_params import SamplingParams
 from vllm.v1.core.sched.output import (CachedRequestData, NewRequestData,
                                        SchedulerOutput)
-from vllm.v1.worker.tpu_model_runner import (TPUModelRunner,
-                                             _get_padded_token_len,
-                                             _get_paddings)
+from vllm.v1.worker.tpu_model_runner import (
+    TPUModelRunner, _get_padded_num_reqs_with_upper_limit,
+    _get_padded_token_len, _get_req_paddings, _get_token_paddings)

 # Mock torch_xla module since it may not be available in the test environments
 torch_xla_patcher = mock.patch.dict(
@@ -296,16 +296,29 @@ def test_update_states_request_unscheduled(model_runner):
 def test_get_paddings():
     min_token_size, max_token_size, padding_gap = 16, 512, 64
     expected_paddings = [16, 32, 64, 128, 192, 256, 320, 384, 448, 512]
-    actual_paddings = _get_paddings(min_token_size, max_token_size,
+    actual_paddings = _get_token_paddings(min_token_size, max_token_size,
                                     padding_gap)
     assert actual_paddings == expected_paddings


 def test_get_padded_token_len():
     min_token_size, max_token_size, padding_gap = 16, 512, 64
-    paddings = _get_paddings(min_token_size, max_token_size, padding_gap)
+    paddings = _get_token_paddings(min_token_size, max_token_size, padding_gap)
     assert _get_padded_token_len(paddings, 1) == 16
     assert _get_padded_token_len(paddings, 16) == 16
     assert _get_padded_token_len(paddings, 20) == 32
     assert _get_padded_token_len(paddings, 300) == 320
     assert _get_padded_token_len(paddings, 512) == 512
+
+
+def test_get_padded_num_reqs_with_upper_limit():
+    assert _get_padded_num_reqs_with_upper_limit(3, 32) == 8
+    assert _get_padded_num_reqs_with_upper_limit(9, 32) == 16
+    assert _get_padded_num_reqs_with_upper_limit(19, 32) == 32
+    assert _get_padded_num_reqs_with_upper_limit(17, 28) == 28
+
+
+def test_get_req_paddings():
+    assert _get_req_paddings(1, 32) == [8, 16, 32]
+    assert _get_req_paddings(8, 32) == [8, 16, 32]
+    assert _get_req_paddings(8, 36) == [8, 16, 32, 36]
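
The request-padding rule these new tests pin down is: round the live request count up to the next power of two, never below MIN_NUM_SEQS, and clamp at the runner's maximum. A standalone sketch of that rule (the helper itself is added to tpu_model_runner.py later in this diff; MIN_NUM_SEQS = 8 is inferred from the expected values above rather than shown here):

    MIN_NUM_SEQS = 8  # assumed from the test expectations; its definition is not in this diff

    def padded_num_reqs_with_upper_limit(x, upper_limit):
        # Next power of two, floored at MIN_NUM_SEQS, clamped at upper_limit.
        res = MIN_NUM_SEQS if x <= MIN_NUM_SEQS else 1 << (x - 1).bit_length()
        return min(res, upper_limit)

    assert padded_num_reqs_with_upper_limit(3, 32) == 8     # under the floor
    assert padded_num_reqs_with_upper_limit(9, 32) == 16    # next power of two
    assert padded_num_reqs_with_upper_limit(17, 28) == 28   # clamped by the limit
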
@@ -3,7 +3,6 @@ from dataclasses import dataclass, field
 from typing import Optional

 import torch
-import torch_xla.core.xla_model as xm

 from vllm.v1.worker.gpu_input_batch import InputBatch

@@ -24,15 +23,15 @@ class TPUSupportedSamplingMetadata:
     # This class exposes a more xla-friendly interface than SamplingMetadata
     # on TPU, in particular all arguments should be traceable and no optionals
     # are allowed, to avoid graph recompilation on Nones.
-    temperature: torch.Tensor
+    temperature: torch.Tensor = None

-    min_p: torch.Tensor
+    min_p: torch.Tensor = None
     # Still too slow on forward_native!
     top_k: torch.Tensor = None
     top_p: torch.Tensor = None

     # Greedy sampling flag for compiling single xla graph.
-    all_greedy: torch.Tensor = None
+    all_greedy: bool = True

     # Generator not supported by xla
     generators: dict[int,
@@ -57,64 +56,58 @@ class TPUSupportedSamplingMetadata:

     allowed_token_ids_mask = None
     bad_words_token_ids = None
-    indices_do_sample: torch.Tensor = None

     @classmethod
     def from_input_batch(
-            cls, input_batch: InputBatch,
-            indices_do_sample: torch.Tensor) -> "TPUSupportedSamplingMetadata":
+        cls,
+        input_batch: InputBatch,
+        padded_num_reqs: int,
+        xla_device: torch.device,
+        generate_params_if_all_greedy: bool = False
+    ) -> "TPUSupportedSamplingMetadata":
         """
         Copy sampling tensors slices from `input_batch` to on device tensors.

         `InputBatch._make_sampling_metadata` causes recompilation on XLA as it
         slices dynamic shapes on device tensors. This impl moves the dynamic
-        ops to CPU and produces tensors of fixed `padded_num_reqs` size. It
-        also reuses the on-device persistent tensors managed in `input_batch`
-        to reduce waste.
+        ops to CPU and produces tensors of fixed `padded_num_reqs` size.

-        `indices_do_sample` contains the indices to be fed to the Sampler,
-        normally one per request, here padded to the closest pre-compiled shape
-        We expect sampling params tensors to be padded to the same fixed shape.
-
-        Eg. 3 requests, tensors padded to 4
-        temperature: [0.7, 0.2, 0.9]=>[0.7, 0.2, 0.9, 0.0]
-        sample indices: [4, 10, 11]=>indices_do_sample: [4, 10, 11, 0]
+        Args:
+            input_batch: The input batch containing sampling parameters.
+            padded_num_reqs: The padded number of requests.
+            xla_device: The XLA device.
+            generate_params_if_all_greedy: If True, generate sampling parameters
+                even if all requests are greedy. this is useful for cases where
+                we want to pre-compile a graph with sampling parameters, even if
+                they are not strictly needed for greedy decoding.
         """
-        num_reqs = input_batch.num_reqs
-        padded_num_reqs = len(indices_do_sample)
+        # Early return to avoid unnecessary cpu to tpu copy
+        if (input_batch.all_greedy is True
+                and generate_params_if_all_greedy is False):
+            return cls(all_greedy=True)

-        def copy_slice(cpu_tensor: torch.Tensor, tpu_tensor: torch.Tensor,
-                       fill_val) -> torch.Tensor:
-            # Copy slice from CPU to corresponding TPU pre-allocated tensor.
+        num_reqs = input_batch.num_reqs
+
+        def fill_slice(cpu_tensor: torch.Tensor, fill_val) -> torch.Tensor:
             # Pad value is the default one.
             cpu_tensor[num_reqs:padded_num_reqs] = fill_val
-            # Subtle compilation: len(tpu_tensor) must be >= `padded_num_reqs`
-            tpu_tensor[:padded_num_reqs] = cpu_tensor[:padded_num_reqs]

-        # NOTE NickLucche The sync CPU-TPU graph we produce here must be
-        # consistent. We can't have flags to skip copies or we'll end up
-        # recompiling.
-        copy_slice(input_batch.temperature_cpu_tensor, input_batch.temperature,
+        fill_slice(input_batch.temperature_cpu_tensor,
                    DEFAULT_SAMPLING_PARAMS["temperature"])
         # TODO Temporarily disabled until sampling options are enabled
-        # copy_slice(input_batch.top_p_cpu_tensor, input_batch.top_p)
-        # copy_slice(input_batch.top_k_cpu_tensor, input_batch.top_k)
-        copy_slice(input_batch.min_p_cpu_tensor, input_batch.min_p,
+        # fill_slice(input_batch.top_p_cpu_tensor)
+        # fill_slice(input_batch.top_k_cpu_tensor)
+        fill_slice(input_batch.min_p_cpu_tensor,
                    DEFAULT_SAMPLING_PARAMS["min_p"])

-        xm.mark_step()
-        xm.wait_device_ops()
-
-        # Slice persistent device tensors to a fixed pre-compiled padded shape.
         return cls(
-            temperature=input_batch.temperature[:padded_num_reqs],
-            # Scalar tensor for xla-friendly tracing.
-            all_greedy=torch.tensor(input_batch.all_greedy,
-                                    dtype=torch.bool,
-                                    device=input_batch.device),
+            temperature=input_batch.temperature_cpu_tensor[:padded_num_reqs].
+            to(xla_device),
+            all_greedy=input_batch.all_greedy,
             # TODO enable more and avoid returning None values
             top_p=None,  # input_batch.top_p[:padded_num_reqs],
             top_k=None,  # input_batch.top_k[:padded_num_reqs],
-            min_p=input_batch.min_p[:padded_num_reqs],
-            generators=input_batch.generators,
-            indices_do_sample=indices_do_sample)
+            min_p=input_batch.min_p_cpu_tensor[:padded_num_reqs].to(
+                xla_device),
+            generators=input_batch.generators)
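
The rewrite of from_input_batch in this hunk is the core of the recompilation fix: every dynamic-length operation now happens on host-side CPU tensors, and the device only ever receives slices whose length is one of the pre-compiled padded request counts. A CPU-only sketch of that shape discipline (names and values are illustrative, and the final .to(xla_device) copy is elided):

    import torch

    max_num_reqs = 32
    num_reqs = 3                 # requests actually scheduled this step
    padded_num_reqs = 8          # nearest pre-compiled request bucket
    default_temperature = 1.0    # stand-in for DEFAULT_SAMPLING_PARAMS["temperature"]

    temperature_cpu = torch.zeros(max_num_reqs)
    temperature_cpu[:num_reqs] = torch.tensor([0.7, 0.2, 0.9])

    # Pad the tail of the bucket with the default value on the CPU side...
    temperature_cpu[num_reqs:padded_num_reqs] = default_temperature
    # ...and hand the device a slice whose length is always padded_num_reqs,
    # so the traced graph sees a fixed shape regardless of num_reqs.
    temperature_for_device = temperature_cpu[:padded_num_reqs]
    assert temperature_for_device.shape == (padded_num_reqs,)
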
@@ -32,7 +32,7 @@ from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
 from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
                                         KVCacheSpec, SlidingWindowSpec)
 from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, LogprobsTensors,
-                             ModelRunnerOutput, SamplerOutput)
+                             ModelRunnerOutput)
 from vllm.v1.sample.tpu.metadata import TPUSupportedSamplingMetadata
 from vllm.v1.sample.tpu.sampler import Sampler as TPUSampler
 from vllm.v1.utils import bind_kv_cache
@@ -177,10 +177,12 @@ class TPUModelRunner:
         # Range tensor with values [0 .. self.max_num_tokens - 1].
         # Used to initialize positions / context_lens / seq_lens
         self.arange_np = np.arange(self.max_num_tokens, dtype=np.int32)
-        self.num_tokens_paddings = _get_paddings(
+        self.num_tokens_paddings = _get_token_paddings(
            min_token_size=16,
            max_token_size=self.max_num_tokens,
            padding_gap=envs.VLLM_TPU_BUCKET_PADDING_GAP)
+        self.num_reqs_paddings = _get_req_paddings(
+            min_req_size=MIN_NUM_SEQS, max_req_size=self.max_num_reqs)

     def _update_num_xla_graphs(self, case_str):
         check_comp = self.check_recompilation and not self.enforce_eager
@@ -508,7 +510,7 @@ class TPUModelRunner:
         # Padded to avoid recompiling when `num_reqs` varies.
         logits_indices = self.query_start_loc_cpu[1:padded_num_reqs + 1] - 1
         logits_indices = logits_indices.to(self.device)
-        return attn_metadata, logits_indices
+        return attn_metadata, logits_indices, padded_num_reqs

     def _scatter_placeholders(
             self,
@@ -663,7 +665,8 @@ class TPUModelRunner:
             mm_embeds = []

         # Prepare inputs
-        attn_metadata, logits_indices = self._prepare_inputs(scheduler_output)
+        attn_metadata, logits_indices, padded_num_reqs = self._prepare_inputs(
+            scheduler_output)
         if self.is_multimodal_model:
             # NOTE(woosuk): To unify token ids and soft tokens (vision
             # embeddings), we always use embeddings (rather than token ids)
@@ -682,11 +685,6 @@ class TPUModelRunner:
             input_ids = self.input_ids
             inputs_embeds = None
         num_reqs = self.input_batch.num_reqs
-        # NOTE (NickLucche) here we sync with TPU: sampling params tensors
-        # are copied to device in chunks of pre-compiled padded shape to
-        # avoid recompilations.
-        tpu_sampling_metadata = TPUSupportedSamplingMetadata.\
-            from_input_batch(self.input_batch, logits_indices)
         # Run the decoder
         with set_forward_context(attn_metadata, self.vllm_config):
             hidden_states = self.model(
@@ -694,6 +692,10 @@ class TPUModelRunner:
                 positions=self.position_ids,
                 inputs_embeds=inputs_embeds,
             )
+        hidden_states = self.select_hidden_states(hidden_states,
+                                                  logits_indices)
+        tpu_sampling_metadata = TPUSupportedSamplingMetadata.\
+            from_input_batch(self.input_batch, padded_num_reqs, self.device)
         selected_token_ids = self.sample_from_hidden(hidden_states,
                                                      tpu_sampling_metadata)
         # Remove padding on cpu and keep dynamic op outside of xla graph.
@@ -857,60 +859,78 @@ class TPUModelRunner:
                          inputs_embeds=inputs_embeds)
         self._hidden_states_dtype = out.dtype

-    def capture_model(self) -> None:
-        """Compile the model."""
-
+    def _precompile_backbone(self) -> None:
         logger.info("Compiling the model with different input shapes.")
-
         start = time.perf_counter()
         for num_tokens in self.num_tokens_paddings:
             logger.info(" -- num_tokens: %d", num_tokens)
             self._dummy_run(num_tokens)
             xm.mark_step()
         xm.wait_device_ops()
         end = time.perf_counter()
-
         logger.info("Compilation finished in in %.2f [secs].", end - start)
-        self._update_num_xla_graphs("model")
+        self._update_num_xla_graphs("model backbone")

+    def _precompile_select_hidden_states(self) -> None:
+        # Compile hidden state selection function for bucketed
+        # n_tokens x max_num_reqs. Graph is really small so this is fine.
+        logger.info(
+            "Compiling select_hidden_states with different input shapes.")
+        start = time.perf_counter()
+        hsize = self.model_config.get_hidden_size()
+        for num_tokens in self.num_tokens_paddings:
+            dummy_hidden = torch.zeros((num_tokens, hsize),
+                                       device=self.device,
+                                       dtype=self._hidden_states_dtype)
+            torch._dynamo.mark_dynamic(dummy_hidden, 0)
+            for num_reqs in self.num_reqs_paddings:
+                indices = torch.zeros(num_reqs,
+                                      dtype=torch.int32,
+                                      device=self.device)
+                torch._dynamo.mark_dynamic(indices, 0)
+                self.select_hidden_states(dummy_hidden, indices)
+            logger.info(" -- num_tokens: %d", num_tokens)
+        xm.wait_device_ops()
+        end = time.perf_counter()
+        logger.info("Compilation finished in in %.2f [secs].", end - start)
+        self._update_num_xla_graphs("select_hidden_states")
+
+    def _precompile_sample_from_hidden(self) -> None:
         logger.info("Compiling sampling with different input shapes.")
         start = time.perf_counter()
         hsize = self.model_config.get_hidden_size()
-        device = self.device
-        # Compile sampling step for different model+sampler outputs in bucketed
-        # n_tokens x max_num_reqs. Graph is really small so this is fine.
-        for num_tokens in self.num_tokens_paddings:
-            num_reqs_to_sample = MIN_NUM_SEQS
-            dummy_hidden = torch.randn((num_tokens, hsize),
-                                       device=device,
+        for num_reqs in self.num_reqs_paddings:
+            dummy_hidden = torch.zeros((num_reqs, hsize),
+                                       device=self.device,
                                        dtype=self._hidden_states_dtype)
-            # Compile for [8, 16, .., 128,.., `self.max_num_reqs`]
-            while True:
-                indices = torch.zeros(
-                    num_reqs_to_sample,
-                    dtype=torch.int32,
-                    device=device,
-                )
-                xm.mark_step()
-                sampling_meta = TPUSupportedSamplingMetadata.\
-                    from_input_batch(self.input_batch, indices)
-                logger.info(" -- num_tokens: %d, num_seqs: %d", num_tokens,
-                            num_reqs_to_sample)
-                out = self.sample_from_hidden(dummy_hidden, sampling_meta)
-                out = out.cpu()
-                # Requests can't be more than tokens. But do compile for the
-                # next bigger value in case num_tokens uses bucketed padding.
-                if num_reqs_to_sample >= min(num_tokens, self.max_num_reqs):
-                    break
-                # Make sure to compile the `max_num_reqs` upper-limit case
-                num_reqs_to_sample = _get_padded_num_reqs_with_upper_limit(
-                    num_reqs_to_sample + 1, self.max_num_reqs)
+            # The first dimension of dummy_hidden cannot be mark_dynamic because
+            # some operations in the sampler require it to be static.
+            for all_greedy in [False, True]:
+                generate_params_if_all_greedy = not all_greedy
+                sampling_metadata = (
+                    TPUSupportedSamplingMetadata.from_input_batch(
+                        self.input_batch,
+                        num_reqs,
+                        self.device,
+                        generate_params_if_all_greedy,
+                    ))
+                sampling_metadata.all_greedy = all_greedy
+                self.sample_from_hidden(dummy_hidden, sampling_metadata)
+            logger.info(" -- num_seqs: %d", num_reqs)
         xm.wait_device_ops()
         end = time.perf_counter()
-
         logger.info("Compilation finished in in %.2f [secs].", end - start)
         self._update_num_xla_graphs("sampling")

+    def capture_model(self) -> None:
+        """
+        Precompile all the subgraphs with possible input shapes.
+        """
+        # TODO: precompile encoder
+        self._precompile_backbone()
+        self._precompile_select_hidden_states()
+        self._precompile_sample_from_hidden()
+
     def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
         """
         Initialize KV cache based on `kv_cache_config`.
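
The three _precompile_* methods in this hunk trade a longer warm-up for a closed set of shapes at serving time: the backbone is traced once per token bucket, select_hidden_states once per (token bucket, request bucket) pair, and sample_from_hidden twice per request bucket (greedy and non-greedy). A rough tally using the padding lists asserted in the tests; the real numbers depend on max_num_tokens, max_num_reqs and VLLM_TPU_BUCKET_PADDING_GAP in the deployment config:

    num_tokens_paddings = [16, 32, 64, 128, 192, 256, 320, 384, 448, 512]
    num_reqs_paddings = [8, 16, 32]   # _get_req_paddings(8, 32) from the tests

    backbone_shapes = len(num_tokens_paddings)                                # 10
    select_hidden_shapes = len(num_tokens_paddings) * len(num_reqs_paddings)  # 30
    sampling_shapes = 2 * len(num_reqs_paddings)                              # 6
    print(backbone_shapes, select_hidden_shapes, sampling_shapes)
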
@@ -962,47 +982,54 @@ class TPUModelRunner:
                 compiled_model.original_code_object)
             compiled_model.compiled_codes.clear()

+    @torch.compile(backend="openxla", fullgraph=True, dynamic=False)
+    def select_hidden_states(self, hidden_states, indices_do_sample):
+        return hidden_states[indices_do_sample]
+
     @torch.compile(backend="openxla", fullgraph=True, dynamic=False)
     def sample_from_hidden(
         self,
-        hidden_states: torch.Tensor,
+        sample_hidden_states: torch.Tensor,
         sampling_metadata: TPUSupportedSamplingMetadata,
     ) -> torch.Tensor:
         """
         Sample with xla-friendly function. This function is to be traced
-        separately for lighter compilation overhead.
+        separately from `forward` for lighter compilation overhead.
         """
         # Tensor `sample_hidden_states` is of fixed pre-compiled size.
-        sample_hidden_states = \
-            hidden_states[sampling_metadata.indices_do_sample]
         # SamplingMetadata here for pruning output in LogitsProcessor, disabled.
         logits = self.model.compute_logits(sample_hidden_states, None)
-
-        def sample(
-            logits: torch.Tensor,
-            sampling_metadata: TPUSupportedSamplingMetadata
-        ) -> SamplerOutput:
-            sampler_out = self.sampler(logits, sampling_metadata)
-            return sampler_out
-
-        # Optimized greedy sampling branch, tracing both paths in a single pass
-        # NOTE all_greedy is a scalar, this is just an optimized if/else.
-        out_tokens = torch.where(
-            sampling_metadata.all_greedy,
-            torch.argmax(logits, dim=-1, keepdim=True),
-            sample(logits, sampling_metadata).sampled_token_ids)
+        if sampling_metadata.all_greedy:
+            out_tokens = torch.argmax(logits, dim=-1, keepdim=True)
+        else:
+            out_tokens = self.sampler(logits,
+                                      sampling_metadata).sampled_token_ids
         return out_tokens

     def get_multimodal_embeddings(self, *args, **kwargs):
         return self.model.get_multimodal_embeddings(*args, **kwargs)

-def _get_padded_number(n: int, multiple: int) -> int:
-    return ((n + multiple - 1) // multiple) * multiple
+    def get_input_embeddings(self, *args, **kwargs):
+        return self.model.get_input_embeddings(*args, **kwargs)


-def _get_padded_num_reqs_with_upper_limit(x, upper_limit) -> int:
+def _get_req_paddings(min_req_size: int, max_req_size: int) -> list[int]:
+    logger.info("Preparing request paddings:")
+    # assert min_req_size is power of 2
+    assert (min_req_size & (min_req_size - 1) == 0) and min_req_size > 0
+    paddings: list = []
+    num = max(MIN_NUM_SEQS, min_req_size)
+    while num <= max_req_size and (len(paddings) == 0 or paddings[-1] != num):
+        paddings.append(num)
+        logger.info(" %d", num)
+        num = _get_padded_num_reqs_with_upper_limit(num + 1, max_req_size)
+    return paddings
+
+
+def _get_padded_num_reqs_with_upper_limit(x: int, upper_limit: int) -> int:
     res = MIN_NUM_SEQS if x <= MIN_NUM_SEQS else 1 << (x - 1).bit_length()
     return min(res, upper_limit)


-def _get_paddings(min_token_size: int, max_token_size: int,
+def _get_token_paddings(min_token_size: int, max_token_size: int,
                   padding_gap: int) -> list[int]:
     """Generate a list of padding size, starting from min_token_size,
     ending with a number that can cover max_token_size
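
After this hunk, hidden-state selection and sampling are separate traced functions, and the greedy choice is a plain Python bool, so each branch becomes its own stable graph instead of a traced torch.where over both paths. A CPU-only toy version of that flow (the matmul and multinomial below merely stand in for compute_logits and the TPU sampler; none of this is the runner's real API):

    import torch

    num_tokens, hidden, vocab = 16, 4, 11
    hidden_states = torch.randn(num_tokens, hidden)
    logits_indices = torch.tensor([4, 10, 11, 0])   # 3 real requests padded to a bucket of 4
    lm_head = torch.randn(hidden, vocab)

    def select_hidden_states(hs, indices):
        return hs[indices]                          # same gather the new method traces

    def sample_from_hidden(sample_hidden_states, all_greedy: bool):
        logits = sample_hidden_states @ lm_head
        if all_greedy:                              # host bool: resolved when tracing
            return torch.argmax(logits, dim=-1, keepdim=True)
        return torch.multinomial(torch.softmax(logits, dim=-1), num_samples=1)

    tokens = sample_from_hidden(select_hidden_states(hidden_states, logits_indices), True)
    assert tokens.shape == (4, 1)
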
@@ -1013,18 +1040,20 @@ def _get_paddings(min_token_size: int, max_token_size: int,
     first increase the size to twice,
     then increase the padding size by padding_gap.
     """
+    # assert min_token_size is power of 2
+    assert (min_token_size & (min_token_size - 1) == 0) and min_token_size > 0
     paddings = []
     num = min_token_size

     if padding_gap == 0:
-        logger.info("Using exponential paddings:")
+        logger.info("Using exponential token paddings:")
         while num <= max_token_size:
             logger.info(" %d", num)
             paddings.append(num)
             num *= 2

     else:
-        logger.info("Using incremental paddings:")
+        logger.info("Using incremental token paddings:")
         while num <= padding_gap:
             logger.info(" %d", num)
             paddings.append(num)
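
Only part of the incremental branch of _get_token_paddings is visible in this hunk. As a behavioural sketch, the following reproduces the expected_paddings asserted in the test file (doubling while within padding_gap, then fixed steps of padding_gap); it is not necessarily the exact vLLM implementation:

    def token_paddings(min_token_size, max_token_size, padding_gap):
        paddings, num = [], min_token_size
        if padding_gap == 0:
            # exponential paddings, as in the branch shown above
            while num <= max_token_size:
                paddings.append(num)
                num *= 2
            return paddings
        # incremental paddings: double while within the gap, then step by the gap
        while num <= padding_gap:
            paddings.append(num)
            num *= 2
        num //= 2
        while num < max_token_size:
            num += padding_gap
            paddings.append(num)
        return paddings

    assert token_paddings(16, 512, 64) == [
        16, 32, 64, 128, 192, 256, 320, 384, 448, 512]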