[DOC][TPU] Add core idea about avoiding recompilation after warmup (#16614)
Signed-off-by: Chengji Yao <chengjiyao@google.com>
This commit is contained in:
parent
c64ee87267
commit
1eb3c2ed48
@ -53,6 +53,41 @@ INVALID_TOKEN_ID = -1
|
|||||||
MIN_NUM_SEQS = 8
|
MIN_NUM_SEQS = 8
|
||||||
|
|
||||||
|
|
||||||
|
#########################################################
|
||||||
|
# Ways to avoid recompilation
|
||||||
|
#########################################################
|
||||||
|
#
|
||||||
|
# The model executor has two primary components:
|
||||||
|
# 1. preparing the model and sampler inputs
|
||||||
|
# 2. executing the model and sampler.
|
||||||
|
# The core idea is to avoid any TPU computation during input preparation. For
|
||||||
|
# better compilation tracking and increased flexibility, the model execution and
|
||||||
|
# sampler are divided into several distinct components.
|
||||||
|
#
|
||||||
|
# Below are the detailed steps:
|
||||||
|
#
|
||||||
|
# Step 1
|
||||||
|
# It is recommended to avoid TPU operations when preparing the model and sampler
|
||||||
|
# inputs. CPU tensors can be prepared and transferred to the XLA device using
|
||||||
|
# cpu_tensor.to(xla_device), which only triggers CPU to TPU transfers and avoids
|
||||||
|
# compilation.
|
||||||
|
#
|
||||||
|
# Step 2
|
||||||
|
# The TPU execution should be decomposed into subgraphs (4 at the moment):
|
||||||
|
# 1. the main model
|
||||||
|
# 2. selecting hidden states for each request
|
||||||
|
# 3. sampler
|
||||||
|
# 4. encoder.
|
||||||
|
# Each subgraph should be decorated in a torch.compile. This is used to make
|
||||||
|
# sure that we have the same subgraph topology in both dummy_run and
|
||||||
|
# xecute_model. The results from these subgraphs should either be passed to
|
||||||
|
# other subgraphs, or transferred from TPU to CPU using xla_tensor.cpu() for
|
||||||
|
# subsequent processing on the CPU.
|
||||||
|
#
|
||||||
|
# Step 3
|
||||||
|
# The dummy_run should be comprehensive, ensuring all potential input shapes and
|
||||||
|
# branch predictions are included as subgraph inputs to facilitate
|
||||||
|
# pre-compilation.
|
||||||
class TPUModelRunner:
|
class TPUModelRunner:
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
|
Loading…
x
Reference in New Issue
Block a user