Memory Footprint of a Neural Net During Backpropagation
In this article we discuss the memory footprint of a neural network during backpropagation: how backprop works, what affects the memory footprint, code that demonstrates it, and ways to reduce it.
Backpropagation
Let’s look at backpropagation in a three-layer network. Let $f(\cdot)$ denote a generic layer function of its input and weights, $x_1$ the input to the first layer, $x_2$ the input to the second layer (and the output of the first layer), $x_3$ the input to the third layer, and $x_4$ the output of the network. So we have: $x_2=f(x_1, w_1) \quad|\quad x_3=f(x_2,w_2) \quad|\quad x_4=f(x_3,w_3) \quad|\quad L=f(x_4)$
Now, to find the gradient of the loss with respect to $w_1$, we apply the chain rule: $\frac{\partial L}{\partial w_1} = \frac{\partial L}{\partial x_4} \frac{\partial x_4}{\partial x_3} \frac{\partial x_3}{\partial x_2} \frac{\partial x_2}{\partial w_1}$
Let’s mark the arguments of $f(\cdot)$ in bold when they are needed in the general case, and in light font when they are needed only for some functions. For example, if $x_4=x_3 w_3$ then $\frac{\partial x_4}{\partial x_3}=w_3$, so the derivative is a function of $w_3$ only and does not depend on the layer’s input activation $x_3$. However, if $x_4=\sigma (x_3 w_3)$, where $\sigma$ is a sigmoid or ReLU, then the derivative depends on both variables; this dependence on the input is exactly what distinguishes a nonlinear layer from a linear one. So let’s go through the terms of the chain rule and, for each one, check whether it always depends on the layer’s input $x$ or only sometimes:
\[\newcommand{\mb}[1]{\mathbf{#1}} \newcommand{\mi}[1]{\textit{#1}} \frac{\partial L}{\partial w_1} = \underbrace{ \frac{\partial L}{\partial x_4}}_{f(\mathbf{x_4})} \underbrace{ \frac{\partial x_4}{\partial x_3} }_{f(x_3, \mb{w_3})} \underbrace{\frac{\partial x_3}{\partial x_2} }_{f(x_2, \mb{w_2})} \underbrace{ \frac{\partial x_2}{\partial w_1}}_{f(\mb{x_1}, w_1)}\]
Observation 1: The first term (which represents the last layer), the effect of the network’s output on the loss, always depends on the network’s output, as expected. The last term belongs to the layer whose weights we want to optimize, and there the layer’s input is required, unless, for example, the weight we’re interested in is the bias, whose derivative does not depend on the input.
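To make the bold/light distinction concrete, here is a small numerical check in plain Python, using scalars instead of tensors for brevity (an illustrative assumption; `w3` and the helper names are hypothetical). The finite-difference derivative of the linear layer with respect to $x_3$ is the same for every input, while the derivative through the sigmoid changes with the input.

```python
import math

def numerical_derivative(f, x, eps=1e-6):
    """Central finite-difference estimate of df/dx at x."""
    return (f(x + eps) - f(x - eps)) / (2 * eps)

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

w3 = 0.5  # hypothetical fixed weight of the third layer

def linear(x3):
    return x3 * w3            # x4 = x3 * w3

def nonlinear(x3):
    return sigmoid(x3 * w3)   # x4 = sigmoid(x3 * w3)

for x3 in (0.0, 1.0, 3.0):
    print(f"x3={x3}: linear {numerical_derivative(linear, x3):.4f}, "
          f"sigmoid {numerical_derivative(nonlinear, x3):.4f}")
```

The linear column stays at $w_3 = 0.5$ for every $x_3$, so $x_3$ need not be kept to evaluate it; the sigmoid column does not.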
Observation 2: What information do we need to store before backprop starts in order to compute the gradient of $w_1$? In the general case we need all layer inputs and outputs, i.e., all of $x_1, \dots, x_4$. That means that after the forward() pass we must store all activations of the network: the input, the hidden representations, and the output. However, in some cases, for example when an inner layer is linear with no nonlinearity, we do not need its input to propagate the gradient through it. This will be demonstrated using the code below.
Combining the two observations, we conclude that in the special case where (1) a layer is frozen (we do not want to optimize its weights) and (2) the frozen layer is linear, we can skip storing its input activation during the forward pass: it is not needed to optimize the layer itself (since it’s frozen), and it is not needed to update the upstream weights in the DAG/computation graph either.
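We can observe this in PyTorch with `torch.autograd.graph.saved_tensors_hooks`, which calls a hook for every tensor autograd saves for backward. A minimal sketch, assuming PyTorch 2.x on CPU; `build_model` and `count_saved_activations` are hypothetical helpers, and the count skips parameters (and views of them, such as transposed weights) so that only activations are tallied. The first and last layers stay trainable so gradients must flow through the middle stack.

```python
import torch
import torch.nn as nn

def build_model(num_middle=4, hidden=16, add_relu=False, freeze=True):
    """Trainable first/last layers around a stack of (optionally frozen) linear layers."""
    layers = [nn.Linear(8, hidden)]  # trainable, so gradients must flow through the stack
    for _ in range(num_middle):
        block = [nn.Linear(hidden, hidden, bias=False)]
        if add_relu:
            block.append(nn.ReLU())
        for module in block:
            module.requires_grad_(not freeze)
        layers.extend(block)
    layers.append(nn.Linear(hidden, 1))
    return nn.Sequential(*layers)

def count_saved_activations(model, x):
    """Count distinct non-parameter tensors that autograd saves for backward during forward()."""
    param_ptrs = {p.data_ptr() for p in model.parameters()}
    saved = set()

    def pack(t):
        # Saved weights (or transposed views of them) share memory with parameters: skip them
        if t.data_ptr() not in param_ptrs:
            saved.add(t.data_ptr())
        return t

    with torch.autograd.graph.saved_tensors_hooks(pack, lambda t: t):
        model(x).sum()
    return len(saved)

torch.manual_seed(0)
x = torch.randn(32, 8)
frozen_linear = count_saved_activations(build_model(add_relu=False), x)
frozen_relu = count_saved_activations(build_model(add_relu=True), x)
trainable_linear = count_saved_activations(build_model(freeze=False), x)
print(frozen_linear, frozen_relu, trainable_linear)
```

On recent PyTorch versions, which skip saving inputs that no required gradient needs, the frozen linear stack should report the fewest saved activations; adding ReLUs brings activations back, because each ReLU saves its output for its own derivative.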
For efficiency, backprop starts from the end of the network. Here are the steps:
- We first compute and keep as state $s_4 = \underbrace{ \frac{\partial L}{\partial x_4}}_{f(\mathbf{x_4})}$. We can now release the activation $x_4$ from memory, as we will not use it anymore.
- If the third layer is unfrozen, compute the gradient of its weight \(\nabla w_3 = s_4 \underbrace{ \frac{\partial x_4}{\partial w_3}}_{f(\mb{x_3},w_3)}\). Then compute $s_3 = s_4 \underbrace{ \frac{\partial x_4}{\partial x_3} }_{f(x_3, \mb{w_3})}$ and release the activation $x_3$ from memory. Note that if the layer is both frozen and linear, $x_3$ is not used at all, so we do not need to store it in the first place.
- If the second layer is unfrozen, compute the gradient \(\nabla w_2 = s_3 \underbrace{ \frac{\partial x_3}{\partial w_2}}_{f(\mb{x_2},w_2)}\). Then compute \(s_2 = s_3 \underbrace{\frac{\partial x_3}{\partial x_2} }_{f(x_2, \mb{w_2})}\) and release the activation $x_2$ from memory.
- If the first layer is unfrozen, compute the gradient \(\nabla w_1 = s_2 \underbrace{ \frac{\partial x_2}{\partial w_1}}_{f(\mb{x_1}, w_1)}\). We can then release the network input $x_1$, and we’re done.
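The steps above can be sketched in plain Python for a scalar version of the three-layer network. The concrete layers (linear, tanh of a linear map, linear) and the loss $L = \tfrac{1}{2}x_4^2$ are assumptions chosen for illustration. Activations live in a dict and are popped as soon as the backward walk has used them, mirroring the release points listed above.

```python
import math

def forward(x1, w):
    """Scalar three-layer net: x2 = w1*x1, x3 = tanh(w2*x2), x4 = w3*x3, L = 0.5*x4**2."""
    x2 = w[1] * x1
    x3 = math.tanh(w[2] * x2)
    x4 = w[3] * x3
    activations = {"x1": x1, "x2": x2, "x3": x3, "x4": x4}
    return activations, 0.5 * x4 ** 2

def backward(acts, w, trainable):
    """Chain rule from the loss back to w1, releasing each activation after its last use."""
    grads = {}
    s = acts.pop("x4")                       # s4 = dL/dx4 = x4; x4 is released
    x3 = acts.pop("x3")                      # read only if layer 3 is trainable (it is linear)
    if trainable[3]:
        grads[3] = s * x3                    # dx4/dw3 = x3
    s = s * w[3]                             # s3: dx4/dx3 = w3 does not need x3
    x2 = acts.pop("x2")
    local = 1.0 - math.tanh(w[2] * x2) ** 2  # tanh' needs the input x2
    if trainable[2]:
        grads[2] = s * local * x2            # dx3/dw2
    s = s * local * w[2]                     # s2 = s3 * dx3/dx2
    x1 = acts.pop("x1")
    if trainable[1]:
        grads[1] = s * x1                    # dx2/dw1 = x1
    return grads

w = {1: 0.7, 2: -1.3, 3: 0.4}
acts, loss = forward(1.5, w)
grads = backward(acts, w, trainable={1: True, 2: True, 3: True})
print(grads, acts)  # acts is now empty: every activation was released
```

Passing `trainable={1: True, 2: False, 3: False}` returns only the gradient of $w_1$, yet the walk still needs $s_3$ and $s_2$, and therefore still reads $x_2$ because of the tanh.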
Observation 3: We do not need to keep the input activation of a frozen layer when all of its upstream weights (its ancestors in the DAG) are frozen too, since nothing will consume the propagated gradient $s$ computed through it.
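In PyTorch we can observe this directly: when neither the input nor any weight of a frozen prefix requires gradients, no autograd graph is recorded for it, so none of its activations are saved, even with nonlinearities. A minimal sketch, assuming PyTorch; the `frozen_prefix`/`head` split is an illustrative choice.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
frozen_prefix = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 16), nn.ReLU())
head = nn.Linear(16, 1)
frozen_prefix.requires_grad_(False)  # freeze every weight in the prefix

x = torch.randn(4, 8)                # network input, does not require grad
h = frozen_prefix(x)
print(h.requires_grad, h.grad_fn)    # no graph was recorded for the prefix

head(h).sum().backward()             # only the trainable head takes part in backward
print(head.weight.grad is not None)
```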
To conclude:
Activation memory allocation: if a layer is frozen AND (the derivative of its output w.r.t. its input does not depend on the input, as in the linear case, OR all upstream dependent weights are frozen too), we can save memory by not storing its input activation.
Activation memory deallocation: during backprop, we can release each activation as soon as it has been used, to free memory.
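The allocation rule can also be written as a small predicate. `must_store_input` is a hypothetical helper, not a PyTorch API; it assumes a trainable layer always needs its input, which holds in the general case (the bias being the exception noted in Observation 1).

```python
def must_store_input(layer_frozen, derivative_needs_input, all_upstream_frozen):
    """Hypothetical helper: does the forward pass need to keep this layer's input?"""
    if not layer_frozen:
        return True                    # needed for the layer's own weight gradient
    if all_upstream_frozen:
        return False                   # nothing consumes the propagated gradient s
    return derivative_needs_input      # e.g. False for a linear layer, True for sigmoid/ReLU

print(must_store_input(layer_frozen=True, derivative_needs_input=False, all_upstream_frozen=False))
```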
Network Memory Footprint
When training with backprop in PyTorch, we have five basic steps:
- Load the network into memory. If the network has 100M parameters stored as 32-bit floats (4 bytes each), it takes 400MB.
- Compute the $\mi{forward()}$ pass and store some activations, depending on the conclusions above. Activations are stored for each sample in the batch, so their memory footprint grows with the batch size. With mixed-precision training, the forward activations are largely kept in 16-bit instead of 32-bit precision, which roughly halves this footprint.
- Compute the $\mi{backward()}$ pass: allocate gradient storage for the unfrozen parameters, use the stored activations to compute the gradients of the unfrozen layers, and release the used activations.
- Run the optimizer step, $\mi{optimizer.step()}$, for the unfrozen layers: it uses the computed gradients and stores and updates its internal state (moments) only for the unfrozen parameters. With the Adam optimizer, two moments are kept for each parameter. The batch size does not affect the memory of the gradients or of the optimizer state, since the contributions of all samples in the batch are summed into a single .grad tensor per parameter.
- Release the accumulated gradients using $\mi{zero_grad()}$ (since PyTorch 2.0, zero_grad() sets the gradients to None by default, which frees their memory).
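The five steps can also be observed on a CPU, assuming PyTorch 2.x, by measuring tensor sizes directly (numel() * element_size()) instead of GPU allocator statistics. The tiny model, the frozen first layer, and the `nbytes` helper are illustrative choices, not part of the demonstration code below.

```python
import torch
import torch.nn as nn

def nbytes(tensors):
    """Total bytes held by an iterable of tensors."""
    return sum(t.numel() * t.element_size() for t in tensors)

torch.manual_seed(0)
model = nn.Sequential(nn.Linear(100, 100), nn.ReLU(), nn.Linear(100, 10))
model[0].requires_grad_(False)  # freeze the first layer
optimizer = torch.optim.Adam([p for p in model.parameters() if p.requires_grad], lr=1e-3)

# Step 1: the model itself, 4 bytes per float32 parameter (frozen or not)
param_bytes = nbytes(model.parameters())

# Steps 2-3: forward() then backward(); .grad is allocated only for trainable parameters
model(torch.randn(64, 100)).sum().backward()
grad_bytes = nbytes(p.grad for p in model.parameters() if p.grad is not None)

# Step 4: Adam lazily creates two moment tensors per trainable parameter on its first step()
optimizer.step()
moment_bytes = nbytes(state[key] for state in optimizer.state.values()
                      for key in ("exp_avg", "exp_avg_sq"))

# Step 5: zero_grad(set_to_none=True) drops the .grad tensors entirely
optimizer.zero_grad(set_to_none=True)
grads_left = sum(p.grad is not None for p in model.parameters())

print(param_bytes, grad_bytes, moment_bytes, grads_left)
```

With these sizes it prints 44440 4040 8080 0: the frozen layer still occupies memory as part of the model, but gets neither gradients nor optimizer state.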
Demonstration Code
Run this code:
```python
import torch
import torch.nn as nn


def test_memory(in_size=100, out_size=10, num_layers=200, freeze_start=0, freeze_end=0,
                hidden_size=100, optimizer_type=torch.optim.Adam, batch_size=1,
                device=0, add_relu=True):
    sample_input = torch.randn(batch_size, in_size)
    layers = [nn.Linear(in_size, hidden_size)]
    for layer_index in range(num_layers):
        layers_to_append = [nn.Linear(hidden_size, hidden_size, bias=False)]
        if add_relu:
            layers_to_append.append(nn.ReLU())
        # Selectively freeze some layers
        if freeze_start <= layer_index < freeze_end:
            for layer in layers_to_append:
                for param in layer.parameters():
                    param.requires_grad = False
        layers.extend(layers_to_append)
    layers.append(nn.Linear(hidden_size, out_size))
    print(f"number of layers: {len(layers)}")
    model = nn.Sequential(*layers)

    start = torch.cuda.memory_allocated(device)
    print("Starting at 0 memory usage as baseline.")
    model.to(device)
    after_model = torch.cuda.memory_allocated(device) - start
    print(f"1: After model to device: {after_model:,}")
    print("")
    # Create the optimizer after moving the model; Adam allocates its state lazily on the first step()
    optimizer = optimizer_type(model.parameters(), lr=.001)

    for i in range(3):
        print("Iteration", i)
        a = torch.cuda.memory_allocated(device) - start
        # Forward pass: activations needed for backward are saved, for every sample in the batch
        out = model(sample_input.to(device)).sum()
        b = torch.cuda.memory_allocated(device) - start
        # Change is about batch * num layers * hidden_size * 4 bytes per float
        print(f"2: Memory consumed after forward pass (activations stored, depends on batch size): {b:,} change: {b - a:,}")
        # Backward pass: allocate (unless already allocated) and fill the .grad of every
        # non-frozen parameter, and release the saved activations as they are consumed.
        # Net change = +non-frozen parameter size (if .grad was unallocated) - saved activations.
        # Gradients accumulate in place in .grad, so their size does not depend on the batch size.
        out.backward()
        c = torch.cuda.memory_allocated(device) - start
        print(f"3: After backward pass (activations released, grad stored) {c:,} change: {c - b:,}")
        # The first optimizer step stores 2 moments per non-frozen parameter (for Adam),
        # kept for the rest of training: change = 2 * non-frozen parameter size.
        # The optimizer updates the model parameters in place.
        optimizer.step()
        d = torch.cuda.memory_allocated(device) - start
        print(f"4: After optimizer step (moments stored at first time): {d:,} change: {d - c:,}")
        # zero_grad() resets the .grad tensors created in backward(); with set_to_none=True
        # (the default since PyTorch 2.0) their memory is released.
        model.zero_grad()
        e = torch.cuda.memory_allocated(device) - start
        print(f"5: After zero_grad step (grads released): {e:,} change: {e - d:,}")
        print("")


test_memory(optimizer_type=torch.optim.Adam, batch_size=64, freeze_start=0, freeze_end=0,
            add_relu=False)
```
Let’s have a look at the output of iteration 2, for example:
```text
Iteration 2
2: Memory consumed after forward pass (activations stored, depends batch size): 46,616,576 change: 5,171,200
3: After backward pass (activations released, grad stored) 49,580,544 change: 2,963,968
4: After optimizer step (moments stored at first time): 49,580,544 change: 0
5: After zero_grad step (grads released): 41,445,376 change: -8,135,168
```
What’s going on? The forward() pass allocated about 5MB of activation memory. You can vary batch_size and see how it affects the activation memory. If you set freeze_end=200, the activation memory drops; however, if you then also set add_relu=True, the footprint goes up again, since the layers are no longer linear, as argued above.
The backward() pass allocated about 8MB for the gradients (the size of the unfrozen model) but released about 5MB of activations, so we see a net increase of about 3MB.
The optimizer step allocated nothing in this iteration, because it had already allocated about 16MB in the first iteration, which is exactly two moments per parameter in the model.
Finally, zero_grad() released the 8MB of gradients.
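The byte counts in the log can be reproduced by hand. This sketch assumes that every tensor allocation is rounded up to a multiple of 512 bytes, which is how PyTorch’s CUDA caching allocator behaves and why the numbers are slightly larger than parameters × 4 bytes. It also assumes, for this run with no frozen layers and no ReLU, that each of the 202 linear layers saves its input of shape (64, 100).

```python
def allocated(num_elements, bytes_per_element=4, block=512):
    """Bytes taken by one tensor, assuming sizes are rounded up to a multiple of `block` bytes."""
    size = num_elements * bytes_per_element
    return -(-size // block) * block  # ceiling division, then back to bytes

in_size, hidden, out_size, num_layers, batch = 100, 100, 10, 200, 64

# One .grad tensor per parameter tensor: first layer (weight, bias),
# 200 bias-free hidden layers, last layer (weight, bias)
param_numels = ([in_size * hidden, hidden] + [hidden * hidden] * num_layers
                + [hidden * out_size, out_size])
grad_bytes = sum(allocated(n) for n in param_numels)

# Saved activations: the input of each of the 202 linear layers, shape (batch, 100)
activation_bytes = (num_layers + 2) * allocated(batch * hidden)

print(f"gradients {grad_bytes:,}, activations {activation_bytes:,}, "
      f"backward change {grad_bytes - activation_bytes:,}, Adam moments {2 * grad_bytes:,}")
```

The first three values match the log: 8,135,168, 5,171,200, and 2,963,968 bytes. Doubling the gradient figure gives the roughly 16MB of Adam state allocated in the first iteration.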
Peak Memory Consumption
So what is the peak memory consumption of a network? Once the optimizer state has been allocated (after the first iteration), there are two points where memory consumption can peak:
- After forward(): model + 2 x trainable params (Adam moments) + activations (batch dependent)
- After backward(): model + 2 x trainable params (Adam moments) + gradients (trainable params)
Which one is higher depends on which component dominates in the specific network: the activations or the gradients. For example, CNNs can have large activations even with relatively few parameters, and in Transformers the activations grow with the sequence length. Activations scale with the batch size, while the gradient footprint depends only on the size of the trainable weights.
So we can write the peak memory consumption as:
model + 2 x trainable params (for Adam) + MAX( gradients (trainable params; can be multiplied by two when training on more than one GPU, for accumulation), activations [batch_size * activations_per_sample * bytes_per_activation] )
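The formula translates directly into a rough estimator. A sketch under stated assumptions: `peak_memory_bytes` is a hypothetical helper, Adam is assumed, and the example treats a single 3x3 convolution (64 input and 64 output channels, 224x224 feature map, batch 32) as the whole network and counts only its output feature map as activations, to illustrate how activations can dwarf gradients in CNNs.

```python
def peak_memory_bytes(total_params, trainable_params, activations_per_sample, batch_size,
                      bytes_per_param=4, bytes_per_activation=4, grad_copies=1):
    """Rough peak-memory estimate following the formula above, assuming Adam."""
    model = total_params * bytes_per_param
    adam_moments = 2 * trainable_params * bytes_per_param         # exp_avg and exp_avg_sq
    gradients = grad_copies * trainable_params * bytes_per_param  # grad_copies=2 for the multi-GPU case
    activations = batch_size * activations_per_sample * bytes_per_activation
    return model + adam_moments + max(gradients, activations)

conv_params = 64 * 64 * 3 * 3 + 64  # weights + biases
conv_activations = 64 * 224 * 224   # elements of one output feature map, per sample
peak = peak_memory_bytes(conv_params, conv_params, conv_activations, batch_size=32)
print(f"{peak:,}")
```

Here the gradients take about 148KB while the activations take about 411MB, so the MAX term is decided entirely by the activations; setting `bytes_per_activation=2` models mixed-precision activations.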
Methods to reduce memory footprint
Gradient Accumulation
We can call optimizer.step() and optimizer.zero_grad() only once every few sub-batches, effectively splitting a batch into sub-batches whose gradients accumulate in .grad. As seen above, this reduces the activation memory, which scales with the sub-batch size, but not the gradient memory.
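A minimal sketch of gradient accumulation, assuming PyTorch; the model, the data, and `accum_steps=4` are illustrative. Each sub-batch loss is divided by the number of steps so that the accumulated .grad equals the gradient of the full-batch mean loss, while only one sub-batch’s activations are alive at a time.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Linear(10, 1)
x, y = torch.randn(32, 10), torch.randn(32, 1)
loss_fn = nn.MSELoss()  # mean over the samples it is given

# Reference: one full batch of 32 samples
model.zero_grad()
loss_fn(model(x), y).backward()
full_grad = model.weight.grad.clone()

# Accumulation: 4 sub-batches of 8; each backward() frees that sub-batch's activations
model.zero_grad()
accum_steps = 4
for x_mb, y_mb in zip(x.chunk(accum_steps), y.chunk(accum_steps)):
    loss = loss_fn(model(x_mb), y_mb) / accum_steps  # rescale so the sum matches the full-batch mean
    loss.backward()                                  # adds into .grad in place
accum_grad = model.weight.grad.clone()
# optimizer.step() and optimizer.zero_grad() would be called here, once per 4 sub-batches

print(torch.allclose(full_grad, accum_grad, atol=1e-5))
```

The .grad tensor has the same size in both runs, which is why accumulation does not shrink gradient memory.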
Gradient Checkpointing
This would more accurately be called activation checkpointing. Instead of storing all activations computed in the forward() pass, we store only a subset of them, reducing the memory footprint, and recompute the missing activations on the fly when they are needed during backward().
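A sketch of activation checkpointing using `torch.utils.checkpoint.checkpoint` with `use_reentrant=False`, assuming a recent PyTorch release. `CountingBlock` is a hypothetical module that counts its forward() calls, which makes the recomputation visible: the checkpointed block runs forward once during the forward pass and again during backward, and it should produce the same gradients as the plain run.

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

class CountingBlock(nn.Module):
    """A small MLP block that counts how many times its forward() runs."""
    def __init__(self, dim):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(dim, dim), nn.ReLU(), nn.Linear(dim, dim))
        self.calls = 0

    def forward(self, x):
        self.calls += 1
        return self.net(x)

torch.manual_seed(0)
block = CountingBlock(16)
x = torch.randn(8, 16)

# Plain run: the block's inner activations are stored during forward
block(x).sum().backward()
plain_grad = block.net[0].weight.grad.clone()
plain_calls = block.calls

# Checkpointed run: inner activations are dropped and recomputed during backward
block.zero_grad()
block.calls = 0
checkpoint(block, x, use_reentrant=False).sum().backward()
ckpt_grad = block.net[0].weight.grad.clone()

print("forward calls:", plain_calls, "plain vs", block.calls, "checkpointed")
print("same gradients:", torch.allclose(plain_grad, ckpt_grad))
```

The trade-off is extra compute (one more forward per checkpointed segment) in exchange for not keeping that segment’s activations between forward() and backward().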
\[\square\]