大家好,今天和大家分享一些如何使用 Pytorch 搭建 ResNet50 卷积神经网络模型,并使用迁移学习的思想训练网络,完成鸟类图片的预测。
ResNet 的原理和 TensorFlow2实现方式可以看我之前的两篇博文,这里就不详细说明原理了。
ResNet18、34:https://blog.csdn.net/dgvv4/article/details/122396424
1. 模型构建
首先导入网络构建过程中所有需要用到的工具包,本小节的所有代码写在 ResNet.py 文件中
import torch
from torch import nn
from torchstat import stat # 查看网络参数
from torchsummary import summary # 查看网络结构
1.1 构建单个残差块
残差单元的结构如下图所示,一种是基本模块,即输入特征图的尺寸和输出特征层的尺寸相同,两个特征图可以直接叠加;一种是下采样模块,即主干部分对输入特征图使用 stride=2 的下采样卷积,使得输入特征图的尺寸变成原来的一半,那么残差边部分也需要对输入特征图进行下采样操作,使得输入特征图经过残差边处理后的 shape 能和主干部分处理后的特征图 shape 相同,从而能够将残差边输出和主干输出直接叠加。
以下图基本残差块为例,先对输入图像使用 1*1 卷积下降通道数,在低维空间下使用 3*3 卷积提取特征,然后再使用 1*1 卷积上升通道数,残差连接输入和输出,将叠加后的结果进过 relu 激活函数。
代码如下:
# -------------------------------------------- #
#(1)残差单元
# x--> 卷积 --> bn --> relu --> 卷积 --> bn --> 输出
# |---------------Identity(短接)----------------|
'''
in_channel 输入特征图的通道数
out_channel 第一次卷积输出特征图的通道数
stride=1 卷积块中3*3卷积的步长
downsample 是否下采样
'''
# -------------------------------------------- #
class Bottleneck(nn.Module):
# 最后一个1*1卷积下降的通道数
expansion = 4
# 初始化
def __init__(self, in_channel, out_channel, stride=1, downsample=None):
# 继承父类初始化方法
super(Bottleneck, self).__init__()
# 属性分配
# 1*1卷积下降通道,padding='same',若stride=1,则[b,in_channel,h,w]==>[b,out_channel,h,w]
self.conv1 = nn.Conv2d(in_channels=in_channel, out_channels=out_channel,
kernel_size=1, stride=1, padding=0, bias=False)
# BN层是计算特征图在每个channel上面的均值和方差,需要给出输出通道数
self.bn1 = nn.BatchNorm2d(out_channel)
# relu激活, inplace=True节约内存
self.relu = nn.ReLU(inplace=True)
# 3*3卷积提取特征,[b,out_channel,h,w]==>[b,out_channel,h,w]
self.conv2 = nn.Conv2d(in_channels=out_channel, out_channels=out_channel,
kernel_size=3, stride=stride, padding=1, bias=False)
# BN层, 有bn层就不需要bias偏置
self.bn2 = nn.BatchNorm2d(out_channel)
# 1*1卷积上升通道 [b,out_channel,h,w]==>[b,out_channel*expansion,h,w]
self.conv3 = nn.Conv2d(in_channels=out_channel, out_channels=out_channel*self.expansion,
kernel_size=1, stride=1, padding=0, bias=False)
# BN层,对out_channel*expansion标准化
self.bn3 = nn.BatchNorm2d(out_channel*self.expansion)
# 记录是否需要下采样, 下采样就是第一个卷积层的步长=2,输入和输出的图像的尺寸不一致
self.downsample = downsample
# 前向传播
def forward(self, x):
# 残差边
identity = x
# 如果第一个卷积层stride=2下采样了,那么残差边也需要下采样
if self.downsample is not None:
# 下采样方法
identity = self.downsample(x)
# 主干部分
x = self.conv1(x)
x = self.bn1(x)
x = self.relu(x)
x = self.conv2(x)
x = self.bn2(x)
x = self.relu(x)
x = self.conv3(x)
x = self.bn3(x)
# 残差连接
x = x + identity
# relu激活
x = self.relu(x)
return x # 输出残差单元的结果
1.2 构建网络
我们已经成功构建完单个残差单元的类,而残差结构就是由多个残差单元堆叠而来的,ResnNet50 中有 4 组残差结构,每个残差结构分别堆叠了 3,4,6,3 个残差单元,如下图所示。
第一个残差结构中