[TPU] Correctly profile peak memory usage & Upgrade PyTorch XLA (#9438)
parent 6aa6020f9b
commit 211fe91aa8
@@ -1,4 +1,4 @@
-ARG NIGHTLY_DATE="20240828"
+ARG NIGHTLY_DATE="20241017"
 ARG BASE_IMAGE="us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.10_tpuvm_$NIGHTLY_DATE"
 
 FROM $BASE_IMAGE
@@ -56,8 +56,8 @@ First, install the dependencies:
 $ pip uninstall torch torch-xla -y
 
 $ # Install PyTorch and PyTorch XLA.
-$ export DATE="20240828"
-$ export TORCH_VERSION="2.5.0"
+$ export DATE="20241017"
+$ export TORCH_VERSION="2.6.0"
 $ pip install https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch-${TORCH_VERSION}.dev${DATE}-cp310-cp310-linux_x86_64.whl
 $ pip install https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-${TORCH_VERSION}.dev${DATE}-cp310-cp310-linux_x86_64.whl
 
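If you upgrade an existing TPU VM by hand with the commands above, a quick sanity check (not part of this change; the exact version strings shown are assumptions based on the new nightly wheels) is to import both packages and query the XLA device:

# Optional post-install sanity check on a TPU VM (illustrative, not part of the diff).
import torch
import torch_xla
import torch_xla.core.xla_model as xm

print(torch.__version__)      # expected: a 2.6.0.dev20241017 nightly build
print(torch_xla.__version__)  # expected: a matching 2.6.0 nightly build
print(xm.xla_device())        # expected: an XLA device such as xla:0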
@@ -133,18 +133,19 @@ class TPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
         # Synchronize before measuring the memory usage.
         xm.wait_device_ops()
 
-        dtype_btyes = get_dtype_size(self.cache_dtype)
-        block_size = self.cache_config.block_size
-        block_size_bytes = (dtype_btyes * block_size * num_layers * 2 *
-                            head_size * num_kv_heads)
-
-        # Calculate the TPU KV cache size based on profiling.
+        # Get the maximum amount of memory used by the model weights and
+        # intermediate activations.
         m = xm.get_memory_info(self.device)
         total_memory_size = m["bytes_limit"]
+        profiled = m["peak_bytes_used"]  # Weights + intermediate activations.
+
+        # Calculate the TPU KV cache size based on profiling.
         usable_memory_size = int(total_memory_size *
                                  self.cache_config.gpu_memory_utilization)
-        profiled = m["bytes_used"]  # Weights + intermediate activations.
         tpu_kv_cache_bytes = max(usable_memory_size - profiled, 0)
+        dtype_btyes = get_dtype_size(self.cache_dtype)
+        block_size_bytes = (dtype_btyes * self.cache_config.block_size *
+                            num_layers * 2 * head_size * num_kv_heads)
         num_tpu_blocks = tpu_kv_cache_bytes // block_size_bytes
         num_tpu_blocks = (num_tpu_blocks // 8) * 8  # Round down to 8.
 
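The substantive change is reading m["peak_bytes_used"] instead of m["bytes_used"]: the latter reflects only the bytes still held when get_memory_info() is called, after transient activations from the profiling run may already have been freed, so it can understate the model's real footprint and oversize the KV cache. Below is a standalone sketch of the new sizing arithmetic with hypothetical numbers (16 GiB per TPU core, an 8 GiB profiled peak, a 2-byte bf16 cache dtype, block_size=16, 32 layers, 8 KV heads of head size 128) to show how the block count falls out; the variable names mirror the worker code, but every value is made up for illustration:

# Standalone sketch of the new KV-cache sizing arithmetic with made-up numbers;
# the real values come from xm.get_memory_info() and the vLLM cache config.
GiB = 1024 ** 3

total_memory_size = 16 * GiB      # stands in for m["bytes_limit"] (hypothetical)
profiled = 8 * GiB                # stands in for m["peak_bytes_used"] (hypothetical)
gpu_memory_utilization = 0.90     # stands in for cache_config.gpu_memory_utilization

usable_memory_size = int(total_memory_size * gpu_memory_utilization)
tpu_kv_cache_bytes = max(usable_memory_size - profiled, 0)

# Per-block KV cache size: dtype_bytes * block_size * num_layers * 2 (K and V)
# * head_size * num_kv_heads, using a hypothetical model shape.
dtype_bytes, block_size, num_layers, head_size, num_kv_heads = 2, 16, 32, 128, 8
block_size_bytes = (dtype_bytes * block_size * num_layers * 2 *
                    head_size * num_kv_heads)

num_tpu_blocks = tpu_kv_cache_bytes // block_size_bytes
num_tpu_blocks = (num_tpu_blocks // 8) * 8  # Round down to a multiple of 8, as in the worker.
print(block_size_bytes, num_tpu_blocks)

With these numbers the sketch prints a 2097152-byte (2 MiB) block and 3272 usable KV cache blocks.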