1. Understand the model structure
Any PyTorch model (including YOLOv8/v5) can be loaded and printed to inspect its structure.
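Loading a generic .pt model:
import torch
# On PyTorch 2.6 and later, a fully pickled model needs weights_only=False (trusted files only)
model = torch.load('model.pt')
Loading a YOLOv8 model:
from ultralytics import YOLO
model = YOLO('model.pt')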
import torch.nn as nn

class YourModel(nn.Module):
    def __init__(self):
        super().__init__()
        # 3 input channels, so the model accepts the RGB inputs used later
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1)
        # ReLU is stateless, so one instance is reused after both convolutions
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        x = self.relu(self.conv2(self.relu(self.conv1(x))))
        return x

# Instantiate the model
model = YourModel()
print(model)
# Printed output:
# YourModel(
#   (conv1): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
#   (relu): ReLU()
#   (conv2): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
# )
print(model.conv1)
# Printed output:
# Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
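For nested models, the printed structure can get long, and attribute access such as model.conv1 only reaches direct children. A minimal sketch, assuming PyTorch's named_modules() and get_submodule(), of listing every layer by its dotted name and resolving a name back to the layer object. TinyNet is a hypothetical model used only for illustration.

```python
import torch.nn as nn

class TinyNet(nn.Module):
    """Hypothetical nested model, used only to illustrate layer names."""
    def __init__(self):
        super().__init__()
        self.stem = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.SiLU())
        self.head = nn.Conv2d(8, 4, 1)

    def forward(self, x):
        return self.head(self.stem(x))

model = TinyNet()

# named_modules() walks the whole module tree; the root has the empty name ''
layer_names = [name for name, _ in model.named_modules() if name]
for name in layer_names:
    print(name, '->', type(model.get_submodule(name)).__name__)

# A dotted name can be resolved back to the layer object
first_conv = model.get_submodule('stem.0')
```

The dotted names are a convenient way to refer to a layer when deciding which one to hook.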
2. How to visualize feature maps
Continuing with YourModel from above:
import torch
import matplotlib.pyplot as plt
from pathlib import Path

features = []

def hook_fn(module, input, output):
    # Store the output of the hooked layer on each forward pass
    features.append(output)
    print("feature shape:", output.shape)

model = YourModel()
# Using the model structure printed earlier, pick the layer to visualize
layer = model.conv1
handle = layer.register_forward_hook(hook_fn)
# Create a dummy input, or load an image instead
x = torch.randn(1, 3, 640, 640)
# Run inference
output = model(x)
handle.remove()

# Visualize the feature maps
save_dir = Path('save_dir')  # folder where the feature maps are saved
save_dir.mkdir(parents=True, exist_ok=True)
batch, channels, height, width = features[0].shape
# Split along the channel dimension (dim=1): one [1, 1, H, W] block per channel
feature_blocks = torch.chunk(features[0].cpu(), channels, dim=1)
for i in range(channels):
    fig, ax = plt.subplots()
    feature_map_np = feature_blocks[i].squeeze().detach().numpy()
    ax.imshow(feature_map_np)
    ax.axis('off')
    file_name = save_dir / f'feature_{i}.png'
    plt.savefig(file_name, dpi=600, bbox_inches='tight', pad_inches=0)
    plt.close(fig)
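Code for loading an image:
import cv2
import torch

def read_img(img):
    image = cv2.imread(img)  # BGR, shape [H, W, 3], dtype uint8
    if image is None:
        raise FileNotFoundError(f'Could not read image: {img}')
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    image = image.transpose((2, 0, 1))  # HWC -> CHW
    image_tensor = torch.from_numpy(image)
    image_tensor = image_tensor.float() / 255.0  # scale to [0, 1]
    image_tensor = image_tensor.unsqueeze(0)  # add the batch dimension
    return image_tensor

# Usage: use this in place of the random input x defined earlier
x = read_img('a.jpg')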
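This is the most basic example: it visualizes only a single layer. Visualizing several layers at once, or handling a complex model, takes considerably more work.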
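One possible way to extend the single-layer example to several layers is to register one forward hook per layer and store the outputs in a dictionary keyed by layer name. The sketch below is an illustration under that assumption, not the author's implementation. capture_features and demo_model are hypothetical names introduced here.

```python
import torch
import torch.nn as nn

def capture_features(model, layer_names, x):
    """Run one forward pass and return {layer_name: output tensor}."""
    features = {}
    handles = []

    def make_hook(name):
        # A closure, so each hook knows which layer it belongs to
        def hook(module, inputs, output):
            features[name] = output.detach().cpu()
        return hook

    for name in layer_names:
        layer = model.get_submodule(name)
        handles.append(layer.register_forward_hook(make_hook(name)))
    try:
        with torch.no_grad():
            model(x)
    finally:
        # Always remove the hooks so they do not fire on later forward passes
        for handle in handles:
            handle.remove()
    return features

# Hypothetical stand-in model; any nn.Module is handled the same way
demo_model = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.Conv2d(8, 16, kernel_size=3, padding=1),
)
feats = capture_features(demo_model, ['0', '2'], torch.randn(1, 3, 32, 32))
for name, tensor in feats.items():
    print(name, tuple(tensor.shape))
```

Each tensor in the returned dictionary can then be split per channel and saved exactly as in the single-layer example. Note that if one module instance is called several times in a forward pass (a shared activation, for instance), this sketch keeps only its last output.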
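3. Visualizing any layer of any model
I have written code that works with any PyTorch model. It is a single standalone .py file with no dependency on any project files, it can visualize several layers in one run, and it automatically saves each layer's output to its own folder.
YOLOv8's official code also includes feature-map visualization, but it only shows top-level blocks (such as Conv and C2f); it cannot reach the lowest-level layers such as nn.Conv2d and nn.SiLU. My code can visualize any layer, from a whole block down to the lowest-level activation function, and automatically writes the output to separate folders.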
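To make the difference between top-level blocks and the lowest-level layers concrete, here is a small sketch. It is neither the author's script nor Ultralytics code. ConvBlock is a hypothetical composite block that only loosely mimics a convolution-plus-SiLU unit.

```python
import torch.nn as nn

class ConvBlock(nn.Module):
    """Hypothetical composite block: convolution followed by SiLU."""
    def __init__(self, c_in, c_out):
        super().__init__()
        self.conv = nn.Conv2d(c_in, c_out, kernel_size=3, padding=1)
        self.act = nn.SiLU()

    def forward(self, x):
        return self.act(self.conv(x))

net = nn.Sequential(ConvBlock(3, 8), ConvBlock(8, 16))

# Top-level blocks only: the granularity of a block-level visualizer
top_level = [name for name, _ in net.named_children()]

# Leaf modules: those with no child modules of their own
leaves = [name for name, m in net.named_modules()
          if name and next(m.children(), None) is None]

print(top_level)
print(leaves)
```

Hooking the names in leaves rather than those in top_level is what gives access to individual nn.Conv2d and nn.SiLU outputs.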
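If you need my code, send me a private message. It took real effort to build, so it is not free. Thank you for understanding.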