文章目录
0. 概述
Gradient checkpointing 的核心思想是不保存所有层的激活值,而是只保存一部分关键点的激活值。当需要计算某个特定层的梯度时,如果该层的激活值没有被直接保存,那么可以通过重新计算从最近的关键点到该层的前向传播来获得这些激活值。这样做的代价是增加了计算量,因为部分前向传播过程需要重复执行,但可以显著降低内存使用。
下图是一个具有n层的简单前馈神经网络计算图:
其中:
- f f f 表示前向传播的激活计算节点
- b b b 表示反向传播的梯度计算节点
1. 简单反向传播
1.1 整体流程
简单反向传播(Vanilla backpropagation)中:
- 为了在反向传播阶段能够高效地计算梯度,前向传播时所有 f f f 节点都保存在内存中;
- 只有当反向传播进行到足以计算出 f f f 节点的所有依赖项或子节点时,才能将其从内存中删除。
这种方式意味着简单的反向传播所需的内存随着神经网络层数n成线性增长。
执行顺序和使用内存如下:
其中:
- 紫色的圆圈表示保留在内存的计算节点;
- 箭头的指向表示节点的依赖关系,例如:A --> B 表示 B 节点的计算依赖 A 的数据。
1.2 详细说明
为了