深度学习训练的扩展与管理技巧
立即解锁
发布时间: 2025-08-30 00:38:27 阅读量: 11 订阅数: 11 AIGC 


PyTorch Lightning实战指南
### 深度学习训练的扩展与管理技巧
#### 1. 混合精度训练
在深度学习模型训练中,像卷积神经网络(CNNs)这类模型会将高维对象(如图像)转换为低维对象(如张量)。实际上,我们不需要高精度的计算,如果牺牲一点精度,就能大幅提升训练速度。TPU 提升性能的方法之一就是利用了这一概念,但在非 TPU 环境中,这种功能通常不可用。
不过,使用像 GPU 或多 CPU 这样更好的处理单元能显著提升训练性能。此外,我们还能在框架层面启用低精度训练。PyTorch Lightning 允许通过给 `Trainer` 模块传递一个名为 `precision` 的额外参数来使用混合精度训练。
混合精度训练使用两种不同位数的浮点数——32 位和 16 位。这样可以减少模型训练期间的内存占用,进而提高训练性能。PyTorch Lightning 支持在 CPU、GPU 和 TPU 上进行混合精度训练。示例代码如下:
```python
import pytorch_lightning as pl
...
trainer = pl.Trainer(gpus=-1, precision=16)
```
这种方式能让模型训练得更快,训练一个 epoch 的时间甚至能减半。有实验表明,使用 16 位精度可使 CNN 模型的训练速度提高约 30 - 40%。
#### 2. 训练过程中的控制与管理
在深度学习模型的训练过程中,常常需要审计、平衡和控制机制。例如,当训练一个需要 1000 个 epoch 的模型时,在训练到 500 个 epoch 时网络故障导致训练中断,如何从特定点恢复训练,同时确保不丢失之前的进度,或者从云环境中保存模型检查点,都是工程师们常见的实际挑战。
##### 2.1 在云环境中保存模型检查点
像 Google Colab 这样的云环境中的笔记本有资源限制和空闲超时时间。如果在模型开发过程中超过这些限制,笔记本会被停用。由于云环境具有弹性,当笔记本停用时,其底层的计算和存储资源会被释放。刷新停用笔记本的浏览器窗口时,通常会使用全新的计算/存储资源重新启动。
云环境中的笔记本因资源限制或空闲超时停用后,检查点目录和文件将无法访问。解决这个问题的一种方法是使用挂载的驱动器。
PyTorch Lightning 会自动将最后一个训练 epoch 的状态保存到检查点中,默认保存在当前工作目录。但由于云笔记本底层基础设施的临时性,当前工作目录不是保存检查点的好选择。在这种环境下,云提供商通常会提供可从笔记本访问的持久存储。
以 Google Colab 为例,使用 Google Drive 持久存储的步骤如下:
1. 导入 Google Drive 到笔记本:
```python
from google.colab import drive
drive.mount('/content/gdrive')
```
2. 在执行 `drive.mount()` 语句时,会提示进行身份验证,按照提示操作。
3. 点击“Connect to Google Drive”按钮,在弹出窗口中选择你的账户,然后点击“Allow”。
4. 使用 PyTorch Lightning `Trainer` 模块的 `default_root_dir` 参数将检查点路径更改为 Google Drive。示例代码如下:
```python
import pytorch_lightning as pl
...
ckpt_dir = "/content/gdrive/MyDrive/Colab Notebooks/cnn"
trainer = pl.Trainer(default_root_dir=ckpt_dir,
max_epochs=100)
```
检查点将存储在 `/content/gdrive/MyDrive/Colab Notebooks/cnn/lightning_logs/version_<number>/checkpoints` 目录结构下。
##### 2.2 更改检查点功能的默认行为
PyTorch Lightning 框架默认将最后一个训练 epoch 的状态保存到当前工作目录。为了让用户更改这种默认行为,框架在 `pytorch_lightning.callbacks` 中提供了 `ModelCheckpoint` 类。以下是一些自定义示例:
- 保存每第 n 个 epoch 的状态:
```python
import pytorch_lightning as pl
...
ckpt_callback = pl.callbacks.ModelCheckpoint(every_n_epochs=10)
trainer = pl.Trainer(callbacks=[ckpt_callback],
max_epochs=100)
``
```
0
0
复制全文
相关推荐









