文章目录
本篇文章主要介绍 torch.nn.Module
、torch.nn.Sequential()
、torch.nn.ModuleList
的使用方法与区别。
1. nn.Module
1.1 基本使用
在PyTorch中,nn.Module
类扮演着核心角色,它是构建任何自定义神经网络层、复杂模块或完整神经网络架构的基础构建块。通过继承 nn.Module
并在其子类中定义模型结构和前向传播逻辑(forward() 方法),开发者能够方便地搭建并训练深度学习模型。
在自定义一个新的模型类时,通常需要:
- 继承
nn.Module
类 - 重新实现
__init__
构造函数 - 重新实现
forward
方法
实现代码如下:
import torch.nn as nn
import torch.nn.functional as F
class Model(nn.Module):
# nn.Module的子类函数必须在构造函数中执行父类的构造函数
def __init__(self):
super(Model, self).__init__() # 等价与nn.Module.__init__()
self.conv1 = nn.Conv2d(1, 20, 5)
self.conv2 = nn.Conv2d(20, 20, 5)
def forward(self, x):
x = F.relu(self.conv1(x))
return F.relu(self.conv2(x))
model=Model()
print(model)
输出如下:
注意:
- 一般把网络中具有可学习参数的层(如全连接层、卷积层)放在构造函数
__init__()
中 - forward() 方法必须重写,它是实现模型的功能,实现各个层之间连接关系的核心
nn.Module
类中的关键属性和方法包括:
-
初始化 (init):在类的初始化方法中定义并实例化所有需要的层、参数和其他组件。
在实现自己的MyModel类时继承了nn.Module,在构造函数中要调用Module的构造函数super(MyModel,self).init()
-
前向传播 (forward):实现前向传播函数来描述输入数据如何通过网络产生输出结果。
因为parameters是自动求导,所以调用forward()后,不用自己写和调用backward()函数。而且一般不是显式的调用forward(layer.farword),而是layer(input),会自执行forward()。 -
管理参数和模块:
- 使用 .parameters() 访问模型的所有可学习参数。
- 使用 add_module() 添加子模块,并给它们命名以便于访问。
- 使用 register_buffer() 为模型注册非可学习的缓冲区变量。
- 训练与评估模式切换:
- 使用 model.train() 将模型设置为训练模式,这会影响某些层的行为,如批量归一化层和丢弃层。
- 使用 model.eval() 将模型设置为评估模式,此时会禁用这些依赖于训练阶段的行为。
- 保存和加载模型状态:
- 调用 model.state_dict() 获取模型权重和优化器状态的字典形式。
- 使用 torch.save() 和 torch.load() 来保存和恢复整个模型或者仅其状态字典。
- 通过 model.load_state_dict(state_dict) 加载先前保存的状态字典到模型中。
此外,nn.Module 还提供了诸如移动模型至不同设备(CPU或GPU)、零化梯度等实用功能,这些功能在整个模型训练过程中起到重要作用。
1.2 常用函数
torch.nn.Module 这个类的内部有多达 48 个函数,下面就一些比较常用的函数进行讲解。
1.2.1 核心函数
-
__init__
函数 和forward()
函数
__init__中主要是初始化一些内部需要用到的state;forward在这里没有具体实现,是需要在各个子类中实现的,如果子类中没有实现就会报错raise NotImplementedError。 -
apply(fn)
函数
将Module及其所有的SubModule传进给定的fn函数操作一遍。我们可以用这个函数来对Module的网络模型参数用指定的方法初始化。下边这个例子就是将网络模型net中的子模型Linear的参数全部赋值为 1 。
def init_weight(m):
if type(m) == nn.Linear:
m.weight.data.fill_(1.0)
m.bias.data.fill_(0)
net = nn.Sequential(nn.Linear(2, 2))
net.apply(init_weight)
输出如下:
state_dict()
函数
返回一个包含module的所有state的dictionary,而这个字典的Keys对应的就是parameter和buffer的名字names。该函数的源码部分有一个循环可以递归遍历Module中所有的SubModule。
net = torch.nn.Linear(2, 2)
print(net.state_dict())
输出如下:
print(net.state_dict().keys())
add_module()
函数
将子模块加入当前模块中,被添加的模块可以用name来获取
1.2.2 查看函数
使用 nn.Module