【大作业收官-61】使用timm训练自己的分类模型-花卉分类
前言
环境配置和软件安装:Python项目配置前的准备工作-CSDN博客
从左向右依次是图像分类,目标检测,语义分割和实例分割。
图像分类是指为输入图像分配类别标签。自 2012 年采用深度卷积网络方法设计的 AlexNet 夺得 ImageNet 竞赛冠军后,图像分类开始全面采用深度卷积网络。2015 年,微软提出的 ResNet 采用残差思想,将输入中的一部分数据不经过神经网络而直接进入到输出中,解决了反向传播时的梯度弥散问题,从而使得网络深度达到 152 层,将错误率降低到 3.57%,远低于 5.1%的人眼识别错误率,夺得了ImageNet 大赛的冠军。
目标检测指用框标出物体的位置并给出物体的类别。2013 年加州大学伯克利分校的 Ross B. Girshick 提出 RCNN 算法之后,基于卷积神经网络的目标检测成为主流。之后的检测算法主要分为两类,一是基于区域建议的目标检测算法,通过提取候选区域,对相应区域进行以深度学习方法为主的分类,如 RCNN、Fast-RCNN、Faster-RCNN、SPP-net 和 Mask R-CNN 等系列方法。二是基于回归的目标检测算法,如 YOLO、SSD 和 DenseBox 等。
图像分割指将图像细分为多个图像子区域。2015 年开始,以全卷积神经网络(FCN)为代表的一系列基于卷积神经网络的语义分割方法相继提出,不断提高图像语义分割精度,成为目前主流的图像语义分割方法。实例分割则是实例级别的语义分割。
基础知识
我们本次讲解的是图像分类问题。图像分类指的是给定一张图像,来判断这个图像是什么类别。传统的方法有支持向量机一类的算法,现在普遍大家使用的比较多的是深度学习的方法。目前深度学习中使用的比较多的网络结构一种是卷积神经网络的结构,另外一种则是transformer的结构。
卷积神经网络
传统的神经网络在处理图像时存在参数过多、学习效率低的问题。而CNN则通过局部感受野、共享权重等机制有效降低了参数数量,并提升了在图像数据上的表现。
一个典型的CNN包括以下几种层:
① 卷积层(Convolutional Layer)
- 用多个卷积核(filter)扫描输入图像
- 提取局部特征,如边缘、角点、纹理等
- 输出称为“特征图”(feature map)
② 激活函数层(ReLU)
- 非线性处理
- 常用的如 ReLU(Rectified Linear Unit):
f(x) = max(0, x)
③ 池化层(Pooling Layer)
- 降采样特征图,减少计算量
- 常用的是最大池化(Max Pooling)
④ 全连接层(Fully Connected Layer)
- 类似传统神经网络的结构
- 综合所有局部特征进行分类
⑤ 输出层(Softmax)
- 输出概率分布,预测图像属于每个类别的概率
transformer结构
Transformer 完全基于 Self-Attention(自注意力)机制,其结构由两个主要部分组成:
- 编码器(Encoder):输入处理器
- 解码器(Decoder):输出生成器(如用于翻译、文本生成)
1. 编码器结构(Encoder)
由多个相同的模块堆叠,每个模块包括:
① 多头自注意力机制(Multi-Head Self Attention)
-
通过多个“注意力头”并行计算序列中各个位置之间的关系
-
用于建模单词与其他单词的依赖关系
-
核心计算:
Attention ( Q , K , V ) = softmax ( Q K T d k ) V \text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V Attention(Q,K,V)=softmax(dkQKT)V
② 前馈神经网络(Feed Forward Network)
- 每个位置的向量独立通过一个小型 MLP(两层线性层 + 激活)
- 提升模型的非线性表达能力
③ 残差连接和 LayerNorm
- 防止深层网络的梯度消失,提升训练稳定性
2. 解码器结构(Decoder)
与编码器类似,但每个模块中还多了一步:
- Masked Multi-Head Attention:防止看到未来的信息(自回归建模)
- Encoder-Decoder Attention:允许解码器访问编码器输出的信息
3. 位置编码(Positional Encoding)
由于 Transformer 不使用循环结构,所以它对输入序列没有天然的“顺序感知能力”。
为了解决这个问题,加入了 位置编码(Positional Encoding),通常是正余弦函数生成的向量,加入到输入中,让模型感知单词之间的相对或绝对位置。
项目实战
环境配置和软件安装:Python项目配置前的准备工作-CSDN博客
环境安装
国内镜像更新
清华的配置如下。
conda config --remove-key channels
conda config --add channels https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/free/
conda config --add channels https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main/
conda config --add channels https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud/pytorch/
conda config --set show_channel_urls yes
pip config set global.index-url https://pypi.tuna.tsinghua.edu.cn/simple
北外的配置如下。
conda config --remove-key channels
conda config --add channels https://mirrors.ustc.edu.cn/anaconda/pkgs/main/
conda config --add channels https://mirrors.ustc.edu.cn/anaconda/pkgs/free/
conda config --add channels https://mirrors.bfsu.edu.cn/anaconda/cloud/pytorch/
conda config --set show_channel_urls yes
pip config set global.index-url https://mirrors.ustc.edu.cn/pypi/web/simple
环境创建和激活
创建3.8的环境
conda create -n project python==3.8.5
y
conda activate project
PyTorch安装
安装pytorch
conda install pytorch==1.8.0 torchvision torchaudio cudatoolkit=10.2 # 注意这条命令指定Pytorch的版本和cuda的版本
conda install pytorch==1.10.0 torchvision torchaudio cudatoolkit=11.3 # 30系列以上显卡gpu版本pytorch安装指令
conda install pytorch==2.1.0 torchvision==0.16.0 torchaudio==2.1.0 pytorch-cuda=12.1 -c pytorch -c nvidia
########################## pip 快速安装
pip install torch==2.1.0 torchvision==0.16.0 torchaudio==2.1.0
pip install torch==1.8.0 torchvision==0.9.0 torchaudio==0.8.0
之后在项目目录下执行pip install 安装其他的依赖库。
pip install -r requirements.txt
数据的准备
数据集可以来源于你自己日常生活的拍照,或者是项目提供,或者是你直接从网络中进行下载获取,这里的数据最好进行清洗,保证你的数据集的名称中不包含中文以及你的数据集中的图片的通道数目为3,一方面是为了防止opecv在读取的过程中不会出现错误,另一方面是可以保证你的模型模型在训练的过程中不会因为通道数量的不匹配导致奇怪的中断。你所搜集的数据需要按照他们的类别放在不同的目录下面,比如玫瑰花的目录下面就需要全部是玫瑰花的图像,向日葵的目录下面则需要全部是向日葵的图像这样。
如果你的数据集中包含了训练集和测试集,则你不需要进行这一步,如果你的数据集中是一整个,则需要进行这一步。
在正式训练之前,你还需要准备好训练集、验证集和测试集,一般是需要将所有的数据按照6:2:2的比例进行划分。这里也为大家准备了data_split.py
的脚本帮助大家做数据集的划分。划分之前,请先按照类别将对应的图片放在对应的子目录下,如下图所示:
然后使用data_split.py
脚本即可,这里你只需要修改原始数据集的路径,直接运行就可以,之后将会产生一个后缀为split的目录,里面是按照6:2:2比例划分之后的数据集。
如果你想要修改划分的比例,可以在这个位置进行修改,保证他们的和等于1即可。
记住这里的路径,我们将会在后续的训练和测试中经常用到。
模型的训练
为了保证大家在之后看视频的过程中也能够复现出来,我们提前对每个库的版本进行了固定,我们本次使用的pytorch的模型库为timm。timm
(PyTorch Image Models)是一个由 Ross Wightman 开发的高质量 PyTorch 图像模型库,广泛用于图像分类、特征提取和模型迁移学习任务。它包含了大量经典与最新的预训练模型,如 ResNet、EfficientNet、MobileNet、RegNet、Vision Transformer(ViT)、Swin Transformer、ConvNeXt、DeiT、CoAtNet 等,几乎涵盖了主流论文中的架构。timm
提供统一且简洁的接口,使得模型加载和使用变得非常高效。通过 timm.create_model()
函数,可以快速实例化任何模型并选择是否加载预训练权重。模型的输入通道数、输出类别数等参数也支持灵活配置,方便用户自定义任务。除了模型本身,timm
还集成了多种实用工具,比如优化器、学习率调度器、训练脚本模板、混合精度训练支持、数据增强策略(如 RandAugment 和 AutoAugment)等,适合进行快速实验和大规模训练。该库对导出模型、迁移学习以及与 torchvision 数据加载器的结合使用也非常友好,并支持高效部署(如 ONNX 导出)。
注意,模型下载使用到了hugging face,请在执行的代码前面添加这个指令,保证你可以下载下来
import os
os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com
训练之前,大家需要设置一些基本信息,标着todo字样的都是大家需要进行修改的,这里我们以resnet50模型的训练为例。
一开始执行之前会有一个会需要下载预训练模型到指定目录,由于众所周知的原因,大家需要提前先把模型下载下来放置到这个目录,这个大家自行探索。
右键直接运行train.py
就可以开始训练模型,代码首先会输出模型的基本信息(模型有几个卷积层、池化层、全连接层构成)和运行的记录,如下图所示:
运行结束之后,训练好的模型将会保存在一开始你指定的目录下,保存的逻辑是:如果当前轮的准确率比目前最好的准确率高就会保存。同时,模型训练过程中acc和loss的变化曲线也会保存在指定的保存目录下,如下所示。
上面的图表示准确率的变化曲线,下面的图表示loss的变化曲线,其中蓝色的为为训练集,橘黄色的为验证集,从图中可以看出,随机训练过程的进行,模型在慢慢收敛直到稳定。
模型的验证
test.py
是模型测试的主程序,在进行模型测试之前,你需要对标注了todo文字所在行的内容进行修改,如下图所示,并且你的测试一定是在训练之后进行的。
修改完毕之后,右键执行test.py
就可以对你指定的模型在测试集上进行测试,测试结束,将会输出模型在数据集上的F1、Recall和ACC等指标,并且会生成由每类分别的准确率构成的热力图保存在record目录下,如下图所示。
测试脚本的开发
为了方便大家对任意的图片进行预测,我这里封装了两个函数,predict_batch
函数负责对一个文件夹下的所有图片进行预测,predict_single
函数负责对单张图片进行预测。通过这两个函数进行扩展,你可以完成对视频的实时预测,或者是通过HTTP接口应用在你网站上的后端程序等,具体用途需要靠大家发挥想象力了。
模型预测的代码是predict.py
文件,该文件中包含了两个函数,分别是文件夹预测和单张图片预测的函数,在实际运行的时候,你需要在main函数中指定具体要执行的是哪个函数。执行之前,需要设置一些基本参数,需要设置的参数我用todo进行了标记,大家按照标记进行修改就可以。
对文件夹下的图片进行预测
对文件夹的图片进行预测的时候,需要传入三个参数,分别是模型地址、需要预测的文件夹的地址、预测结果保存的地址(模型会按照预测结果放在对应的类别文件夹中)
比如对于下面的预测目录D:\upppppppppp\cls\cls_torch_tem\images\test_imgs\mini
,我希望结果保存在目录D:\upppppppppp\cls\cls_torch_tem\images\test_imgs\mini_result
下。
打开mini_result文件夹,可以看到每个图片都按照预测结果放在了不同的文件夹下。
对单张图片进行预测
比如我要对这样的一张图片进行预测,首先请标注这张图片的路径,后面将会使用到。
接着将图片路径写入predict_single
函数,右键直接运行即可,运行效果如下图所示。
图形化界面构建
图形化界面的构建依然通过pyqt5来编写,主要功能还是上传图片和对模型进行推理。
启动界面之前需要设置几个关键的参数,如下图所示:
之前有小伙伴比较好奇上传和推理的逻辑在哪,这部分主要是通过pyqt中信号与槽的机制来进行实现的,比如对于推理按钮,我们首先生成一个button,然后通过clicked.connect和具体事件绑定起来即可实现调用模型推理的功能。
之后,直接启动window.py即可运行。
模型其他信息查看
模型结构查看
一般在论文中,为了让我们的网络看起来方便读一些,我们需要查看网络具体的模型结构,这个时候需要借助netron这个软件,在进行查看之后,需要将我们的模型转换为onnx格式的模型,onnx格式的模型也是后续实际在windows上做模型部署用的形式,这个我们会单独开一个章节来讲。
比如以我们训练好的resnet50模型为例,代码在‘工具类-onnx模型导出.py’,和之前一样,还是需要修改几个固定的参数之后直接右键执行即可,即可得到转换为了onnx格式的模型。
然后将模型拖动到netron软件中即可查看具体的网络结构。
模型参数量查看
另外,有的时候我们需要查看模型的参数量和计算量,这里的计算量一般用flops来进行表示,但是实际上这里的计算量和fps不是严格意义上相关的,有的时候不一定的是flops低,模型的推理速度就快。模型实际的运行快慢还是需要看fps的,这部分的代码我没有写(一个简单的思路:就是算推理一个文件夹的图片花费了多少s,然后用图片数量除以时间, 就可以算出来fps)。
计算参数量的代码在get_flops.py
文件,同样将todo的几个参数修改之后直接运行即可查看。