[TPU] Set per-rank XLA cache (#7533)
parent 2ecf7b1757
commit 951fdd66d3
@@ -102,12 +102,12 @@ class TPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
         # 30-40 graphs for decode. 128 is an arbitrary safe number.
         torch._dynamo.config.cache_size_limit = 128
         # Use persistent cache to avoid XLA recompilation.
-        # NOTE(woosuk): This does not completely eliminate the recompilation
-        # overhead because dynamo does not cache the compiled results.
-        # NOTE(woosuk): Set readonly=False only for the rank 0 process to avoid
-        # race conditions.
-        xr.initialize_cache(envs.VLLM_XLA_CACHE_PATH,
-                            readonly=not self.is_driver_worker)
+        # NOTE(woosuk): Set per-rank cache path since different ranks
+        # can have slightly different XLA graphs.
+        world_size = self.parallel_config.world_size
+        per_rank_path = os.path.join(envs.VLLM_XLA_CACHE_PATH,
+                                     f"tp{world_size}_rank{self.rank}")
+        xr.initialize_cache(per_rank_path, readonly=False)
 
     def load_model(self):
         self.model_runner.load_model()
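The change keys the persistent XLA compilation cache on the tensor-parallel world size and the worker's rank, so every process gets its own cache directory and can open it writable (readonly=False) without racing other ranks, which may compile slightly different XLA graphs. Below is a minimal sketch of the same pattern outside vLLM; the cache_root location and the use of xr.world_size()/xr.global_ordinal() to derive world size and rank are assumptions for illustration, not taken from the commit, while xr.initialize_cache is the torch_xla.runtime API the diff itself calls.

    # Sketch only: per-rank persistent XLA compilation cache.
    # Assumes torch_xla is installed; cache_root is a hypothetical location.
    import os

    import torch_xla.runtime as xr

    cache_root = os.path.expanduser("~/.cache/my_xla_cache")  # assumed path
    world_size = xr.world_size()    # number of participating XLA processes
    rank = xr.global_ordinal()      # this process's global rank

    # Each rank writes to its own subdirectory, so ranks whose XLA graphs
    # differ slightly never overwrite or race on the same cache entries.
    per_rank_path = os.path.join(cache_root, f"tp{world_size}_rank{rank}")
    os.makedirs(per_rank_path, exist_ok=True)

    # Initialize the cache before the first XLA computation runs.
    xr.initialize_cache(per_rank_path, readonly=False)

In the diff above the same idea is expressed with vLLM's own state: world_size comes from self.parallel_config and the rank from the worker object, with the cache rooted at envs.VLLM_XLA_CACHE_PATH.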
|