【图像分类案例】(8) ResNet50 鸟类图像4分类,附Pytorch完整代码

本文介绍了如何使用Pytorch搭建ResNet50模型,通过迁移学习进行鸟类图像的四分类任务。详细讲解了模型构建、训练过程和预测阶段,包括残差块的构建、网络结构、数据预处理、模型训练与验证,以及预测时的处理步骤。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

大家好,今天和大家分享一些如何使用 Pytorch 搭建 ResNet50 卷积神经网络模型,并使用迁移学习的思想训练网络,完成鸟类图片的预测。

ResNet 的原理TensorFlow2实现方式可以看我之前的两篇博文,这里就不详细说明原理了。

ResNet18、34:https://blog.csdn.net/dgvv4/article/details/122396424

ResNet50:https://blog.csdn.net/dgvv4/article/details/121878494


1. 模型构建

首先导入网络构建过程中所有需要用到的工具包,本小节的所有代码写在 ResNet.py 文件中

import torch
from torch import nn
from torchstat import stat  # 查看网络参数
from torchsummary import summary  # 查看网络结构

1.1 构建单个残差块

残差单元的结构如下图所示,一种是基本模块,即输入特征图的尺寸和输出特征层的尺寸相同,两个特征图可以直接叠加;一种是下采样模块,即主干部分对输入特征图使用 stride=2 的下采样卷积,使得输入特征图的尺寸变成原来的一半,那么残差边部分也需要对输入特征图进行下采样操作,使得输入特征图经过残差边处理后的 shape 能和主干部分处理后的特征图 shape 相同,从而能够将残差边输出和主干输出直接叠加。

以下图基本残差块为例,先对输入图像使用 1*1 卷积下降通道数,在低维空间下使用 3*3 卷积提取特征,然后再使用 1*1 卷积上升通道数,残差连接输入和输出,将叠加后的结果进过 relu 激活函数。

代码如下:

# -------------------------------------------- #
#(1)残差单元
# x--> 卷积 --> bn --> relu --> 卷积 --> bn --> 输出 
# |---------------Identity(短接)----------------|
'''
in_channel   输入特征图的通道数 
out_channel  第一次卷积输出特征图的通道数
stride=1     卷积块中3*3卷积的步长
downsample   是否下采样
'''
# -------------------------------------------- #
class Bottleneck(nn.Module):
    # 最后一个1*1卷积下降的通道数
    expansion = 4
    
    # 初始化
    def __init__(self, in_channel, out_channel, stride=1, downsample=None):
        # 继承父类初始化方法
        super(Bottleneck, self).__init__()
        
        # 属性分配
        # 1*1卷积下降通道,padding='same',若stride=1,则[b,in_channel,h,w]==>[b,out_channel,h,w]
        self.conv1 = nn.Conv2d(in_channels=in_channel, out_channels=out_channel,
                               kernel_size=1, stride=1, padding=0, bias=False)

        # BN层是计算特征图在每个channel上面的均值和方差,需要给出输出通道数
        self.bn1 = nn.BatchNorm2d(out_channel)
        
        # relu激活, inplace=True节约内存
        self.relu = nn.ReLU(inplace=True)
        
        # 3*3卷积提取特征,[b,out_channel,h,w]==>[b,out_channel,h,w]
        self.conv2 = nn.Conv2d(in_channels=out_channel, out_channels=out_channel,
                               kernel_size=3, stride=stride, padding=1, bias=False)
        
        # BN层, 有bn层就不需要bias偏置
        self.bn2 = nn.BatchNorm2d(out_channel)
        
        # 1*1卷积上升通道 [b,out_channel,h,w]==>[b,out_channel*expansion,h,w]
        self.conv3 = nn.Conv2d(in_channels=out_channel, out_channels=out_channel*self.expansion,
                               kernel_size=1, stride=1, padding=0, bias=False)
        
        # BN层,对out_channel*expansion标准化
        self.bn3 = nn.BatchNorm2d(out_channel*self.expansion)
        
        # 记录是否需要下采样, 下采样就是第一个卷积层的步长=2,输入和输出的图像的尺寸不一致
        self.downsample = downsample
    
    # 前向传播
    def forward(self, x):
        
        # 残差边
        identity = x
        
        # 如果第一个卷积层stride=2下采样了,那么残差边也需要下采样
        if self.downsample is not None:
            # 下采样方法
            identity = self.downsample(x)
        
        # 主干部分
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        
        x = self.conv2(x)
        x = self.bn2(x)
        x = self.relu(x)
        
        x = self.conv3(x)
        x = self.bn3(x)
        
        # 残差连接
        x = x + identity
        # relu激活
        x = self.relu(x)
        
        return x  # 输出残差单元的结果

1.2 构建网络

我们已经成功构建完单个残差单元的类,而残差结构就是由多个残差单元堆叠而来的,ResnNet50 中有 4 组残差结构,每个残差结构分别堆叠了 3,4,6,3 个残差单元,如下图所示。

第一个残差结构中

评论 4
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值