In computer vision, performance gains for convolutional neural networks (CNNs) often come from architectural innovation. GoogleNet achieves efficient feature extraction through its Inception modules, while the SE (Squeeze-and-Excitation) attention mechanism further improves performance by emphasizing the most informative feature channels. This post walks through a GoogleNet model augmented with SE attention for butterfly species classification, with complete runnable code.
[A reference for graduation projects in the computer vision track]
[Full source code is attached at the end]
[If anything is unclear, see my earlier posts]
[I share learning results regularly, feel free to follow along!]
[Demo results]
1. Project Background and Core Techniques
This project builds a high-accuracy butterfly species classifier by combining two core techniques:
- GoogleNet's Inception module: parallel multi-scale convolutions (1x1, 3x3, 5x5) plus pooling extract rich features while staying computationally efficient.
- SE attention: learned per-channel weights amplify informative channels and suppress irrelevant ones, improving the model's discriminative power.
The implementation is explained module by module below.
[Model architecture diagram]
2. Getting the Dataset
[The dataset site is below; download it yourself, or pick any other dataset you prefer]
butterfly Classification Dataset by weather
https://universe.roboflow.com/weather-mxebd/butterfly-k1vda
3. Core Module Implementation
3.1 SE Attention Module (SEBlock)
The SE module uses a Squeeze step (compressing spatial information) and an Excitation step (learning per-channel weights) to make the model focus on the most informative feature channels.
### 1. Attention mechanism: SE block (Squeeze-and-Excitation)
class SEBlock(nn.Module):
    def __init__(self, channel, reduction=16):
        super(SEBlock, self).__init__()
        # Squeeze: global average pooling compresses each channel's spatial map to a single value
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        # Excitation: fully connected layers learn per-channel weights
        self.fc = nn.Sequential(
            nn.Linear(channel, channel // reduction, bias=False),  # reduce dimensionality to cut computation
            nn.ReLU(inplace=True),
            nn.Linear(channel // reduction, channel, bias=False),  # restore the channel dimension
            nn.Sigmoid()  # weights in (0, 1)
        )

    def forward(self, x):
        b, c, _, _ = x.size()  # b: batch size, c: channels
        # Squeeze: (b, c, h, w) -> (b, c, 1, 1) -> (b, c)
        y = self.avg_pool(x).view(b, c)
        # Excitation: (b, c) -> (b, c) -> (b, c, 1, 1)
        y = self.fc(y).view(b, c, 1, 1)
        # Scale: multiply each channel by its learned weight
        return x * y
Core logic: global average pooling compresses spatial information into per-channel statistics, fully connected layers learn channel-importance weights, and the input features are then scaled by those weights.
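As a quick sanity check, the block preserves the input shape and only rescales channels. A minimal sketch (the tensor sizes here are illustrative):
# Minimal shape check for SEBlock (illustrative sizes)
import torch
se = SEBlock(channel=64, reduction=16)
x = torch.randn(2, 64, 32, 32)  # (batch, channels, height, width)
out = se(x)
print(out.shape)  # torch.Size([2, 64, 32, 32]): same shape, channels re-weighted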
3.2 Inception Module (with SE)
The Inception module is the heart of GoogleNet: a multi-branch parallel structure extracts features at several scales. Here an SE block is added after the branch concatenation for further refinement.
### 2. Inception module (the core of GoogleNet)
class Inception(nn.Module):
    def __init__(self, in_channels, ch1x1, ch3x3red, ch3x3, ch5x5red, ch5x5, pool_proj, use_se=True):
        super(Inception, self).__init__()
        # 1x1 convolution branch: reduces channels and computation
        self.branch1 = nn.Sequential(
            nn.Conv2d(in_channels, ch1x1, kernel_size=1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(ch1x1),
            nn.ReLU(inplace=True)
        )
        # 3x3 convolution branch (1x1 reduction first)
        self.branch2 = nn.Sequential(
            nn.Conv2d(in_channels, ch3x3red, kernel_size=1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(ch3x3red),
            nn.ReLU(inplace=True),
            nn.Conv2d(ch3x3red, ch3x3, kernel_size=3, stride=1, padding=1, bias=False),  # spatial size preserved
            nn.BatchNorm2d(ch3x3),
            nn.ReLU(inplace=True)
        )
        # 5x5 convolution branch (1x1 reduction first)
        self.branch3 = nn.Sequential(
            nn.Conv2d(in_channels, ch5x5red, kernel_size=1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(ch5x5red),
            nn.ReLU(inplace=True),
            nn.Conv2d(ch5x5red, ch5x5, kernel_size=5, stride=1, padding=2, bias=False),  # spatial size preserved
            nn.BatchNorm2d(ch5x5),
            nn.ReLU(inplace=True)
        )
        # Pooling branch (pooling followed by a 1x1 conv to set the channel count)
        self.branch4 = nn.Sequential(
            nn.MaxPool2d(kernel_size=3, stride=1, padding=1),  # spatial size preserved
            nn.Conv2d(in_channels, pool_proj, kernel_size=1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(pool_proj),
            nn.ReLU(inplace=True)
        )
        # Channel attention
        self.use_se = use_se
        if self.use_se:
            # channels after concatenation = sum of the branch channel counts
            self.se = SEBlock(ch1x1 + ch3x3 + ch5x5 + pool_proj)

    def forward(self, x):
        branch1 = self.branch1(x)
        branch2 = self.branch2(x)
        branch3 = self.branch3(x)
        branch4 = self.branch4(x)
        # Concatenate the branch outputs along the channel dimension
        x = torch.cat([branch1, branch2, branch3, branch4], 1)
        # Apply SE attention
        if self.use_se:
            x = self.se(x)
        return x
Core logic: four parallel branches extract features at different scales, 1x1 convolutions keep the computational cost down, and the concatenated features pass through the SE block so that the most informative channels are emphasized.
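One detail worth verifying is the channel arithmetic: the output has ch1x1 + ch3x3 + ch5x5 + pool_proj channels. The sketch below (sizes illustrative) checks the first Inception stage used later in the model:
# inception3a's configuration: 64 + 128 + 32 + 32 = 256 output channels,
# exactly the in_channels that inception3b expects
import torch
block = Inception(192, 64, 96, 128, 16, 32, 32, use_se=True)
y = block(torch.randn(1, 192, 28, 28))  # 28x28 is the map size at this stage for 224x224 input
print(y.shape)  # torch.Size([1, 256, 28, 28])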
4. Building the Full Model (GoogleNetSE)
Combining the Inception modules with SE attention gives the full classifier, including two auxiliary classifiers that provide extra gradient signal and regularization during training.
### 3. GoogleNet with SE attention
class GoogleNetSE(nn.Module):
    def __init__(self, num_classes=75, use_se=True):
        super(GoogleNetSE, self).__init__()
        self.use_se = use_se
        # Stem: shrinks the feature map and extracts low-level features
        self.pre_layers = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),  # halves the spatial size
            nn.Conv2d(64, 64, kernel_size=1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 192, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(192),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1)  # halves the spatial size
        )
        # Inception stages (stacked following the GoogleNet layout)
        self.inception3a = Inception(192, 64, 96, 128, 16, 32, 32, use_se)
        self.inception3b = Inception(256, 128, 128, 192, 32, 96, 64, use_se)
        self.maxpool3 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)  # halves the spatial size
        self.inception4a = Inception(480, 192, 96, 208, 16, 48, 64, use_se)
        self.inception4b = Inception(512, 160, 112, 224, 24, 64, 64, use_se)
        self.inception4c = Inception(512, 128, 128, 256, 24, 64, 64, use_se)
        self.inception4d = Inception(512, 112, 144, 288, 32, 64, 64, use_se)
        self.inception4e = Inception(528, 256, 160, 320, 32, 128, 128, use_se)
        self.maxpool4 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)  # halves the spatial size
        self.inception5a = Inception(832, 256, 160, 320, 32, 128, 128, use_se)
        self.inception5b = Inception(832, 384, 192, 384, 48, 128, 128, use_se)
        # Adaptive average pooling: fixes the feature map at 7x7 regardless of input size
        self.avgpool = nn.AdaptiveAvgPool2d((7, 7))
        # Main classifier
        self.classifier = nn.Sequential(
            nn.Dropout(0.4),  # regularization against overfitting
            nn.Linear(1024 * 7 * 7, 1024),  # fully connected layer
            nn.ReLU(inplace=True),
            nn.Linear(1024, num_classes)  # one logit per class
        )
        # Auxiliary classifiers (training only, to mitigate vanishing gradients)
        self.aux1 = self._make_aux_classifier(512, num_classes)
        self.aux2 = self._make_aux_classifier(528, num_classes)

    def _make_aux_classifier(self, in_channels, num_classes):
        """Build an auxiliary classifier."""
        return nn.Sequential(
            nn.AvgPool2d(kernel_size=5, stride=3),  # shrinks the feature map (14x14 -> 4x4 here)
            nn.Conv2d(in_channels, 128, kernel_size=1, stride=1, padding=0),
            nn.ReLU(inplace=True),
            nn.Flatten(),
            nn.Linear(128 * 4 * 4, 1024),
            nn.ReLU(inplace=True),
            nn.Dropout(0.7),
            nn.Linear(1024, num_classes)
        )

    def forward(self, x):
        x = self.pre_layers(x)
        x = self.inception3a(x)
        x = self.inception3b(x)
        x = self.maxpool3(x)
        x = self.inception4a(x)
        aux1 = self.aux1(x) if self.training else None  # auxiliary classifiers run only during training
        x = self.inception4b(x)
        x = self.inception4c(x)
        x = self.inception4d(x)
        aux2 = self.aux2(x) if self.training else None
        x = self.inception4e(x)
        x = self.maxpool4(x)
        x = self.inception5a(x)
        x = self.inception5b(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        if self.training:
            return x, aux1, aux2  # training: main output plus two auxiliary outputs
        else:
            return x  # inference: main output only
Core logic: the model has three parts, the stem convolutions, the stacked Inception modules, and the classifier. During training the auxiliary classifiers inject extra gradient, mitigating vanishing gradients in the deep network, while SE attention emphasizes the most informative channels and improves classification accuracy.
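A small sketch (shapes illustrative) makes the train/eval contract of the forward pass explicit, since the training loop later unpacks three tensors while inference expects one:
# Quick check of the model's two output modes
import torch
model = GoogleNetSE(num_classes=75, use_se=True)
dummy = torch.randn(2, 3, 224, 224)
model.train()
main_out, aux1, aux2 = model(dummy)  # three logit tensors, each of shape (2, 75)
model.eval()
with torch.no_grad():
    logits = model(dummy)            # a single (2, 75) tensor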
5. Data Loading and Preprocessing
Sensible data augmentation improves generalization. Below is a data-loading function that keeps PIL transforms and Tensor transforms separate for compatibility:
### 4. Fixed data loading function (ensures transform compatibility)
def build_data(data_dir, batch_size=8):
    if not os.path.exists(data_dir):
        raise FileNotFoundError(f"Dataset directory does not exist: {data_dir}")
    required_subdirs = ['train', 'valid']
    for subdir in required_subdirs:
        subdir_path = os.path.join(data_dir, subdir)
        if not os.path.exists(subdir_path):
            raise FileNotFoundError(f"Data subdirectory does not exist: {subdir_path}")
    print(f"Found dataset directory structure: {data_dir}")
    # Separate PIL-image transforms from Tensor transforms to avoid compatibility issues
    pil_transforms = transforms.Compose([
        transforms.RandomResizedCrop(224, scale=(0.7, 1.0), ratio=(0.8, 1.2)),  # random crop
        transforms.RandomHorizontalFlip(p=0.5),  # horizontal flip
        transforms.RandomVerticalFlip(p=0.3),  # vertical flip
        transforms.RandomRotation(30),  # random rotation
        transforms.RandomPerspective(distortion_scale=0.2, p=0.3)  # random perspective warp
    ])
    # Pixel-level transforms applied after conversion to Tensor
    tensor_transforms = transforms.Compose([
        transforms.ToTensor(),
        transforms.ColorJitter(  # color jitter
            brightness=0.3,
            contrast=0.3,
            saturation=0.3,
            hue=0.1
        ),
        transforms.RandomGrayscale(p=0.2),  # random grayscale conversion
        transforms.RandomApply([transforms.GaussianBlur(kernel_size=3, sigma=(0.1, 2.0))], p=0.3),  # Gaussian blur
        transforms.RandomErasing(p=0.2, scale=(0.02, 0.25)),  # random occlusion
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # normalize (ImageNet statistics)
    ])
    # Combined training transform
    train_transform = transforms.Compose([
        pil_transforms,
        tensor_transforms
    ])
    # Validation transform (Resize and normalization only, no augmentation)
    val_test_transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
    # Load the datasets (ImageFolder expects one sub-folder per class)
    train_dataset = datasets.ImageFolder(
        root=os.path.join(data_dir, 'train'),
        transform=train_transform
    )
    val_dataset = datasets.ImageFolder(
        root=os.path.join(data_dir, 'valid'),
        transform=val_test_transform
    )
    # Build the data loaders
    train_loader = DataLoader(
        dataset=train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=2,
        pin_memory=True  # speeds up host-to-GPU transfer
    )
    val_loader = DataLoader(
        dataset=val_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=2,
        pin_memory=True
    )
    return train_loader, val_loader, train_dataset.classes
Core logic: the training set uses a range of augmentations (crops, flips, rotations, and so on) to increase data diversity, while the validation set gets only Resize and normalization so that evaluation stays faithful. Separating PIL and Tensor transforms avoids data-type conflicts.
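For reference, datasets.ImageFolder expects one sub-folder per class under train/ and valid/. A minimal usage sketch (the class folder names below are hypothetical; the model code defaults to 75 classes):
# Expected layout (illustrative folder names):
# data/bfly_data/
#     train/
#         monarch/img001.jpg ...
#         swallowtail/img002.jpg ...
#     valid/
#         monarch/...
#         swallowtail/...
train_loader, val_loader, classes = build_data('./data/bfly_data', batch_size=8)
print(f"{len(classes)} classes, e.g. {classes[:3]}")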
6. Transfer Learning and Model Initialization
Pre-trained weights speed up convergence, and freezing part of the network reduces the risk of overfitting:
### New: pre-trained model initialization (transfer learning)
def init_pretrained_model(num_classes=75, pretrained_path=None, freeze_layers=True):
    """
    Initialize the model with pre-trained weights for transfer learning.
    Args:
        num_classes: number of target classes (butterfly species)
        pretrained_path: path to a pre-trained weight file (e.g. ImageNet-pretrained GoogleNet weights)
        freeze_layers: whether to freeze the lower layers (train only the top)
    """
    model = GoogleNetSE(num_classes=num_classes, use_se=True)
    if pretrained_path and os.path.exists(pretrained_path):
        print(f"Loading pre-trained weights: {pretrained_path}")
        # load the pre-trained state dict
        pretrained_dict = torch.load(pretrained_path, map_location=torch.device('cpu'))
        model_dict = model.state_dict()
        # filter the weights: keep only the feature-extraction part (skip the
        # classifiers so the target task's output layers are not overwritten)
        pretrained_dict = {
            k: v for k, v in pretrained_dict.items()
            if k in model_dict and
            not any(keyword in k for keyword in ['classifier', 'aux1', 'aux2'])
        }
        # merge into the model's weights
        model_dict.update(pretrained_dict)
        model.load_state_dict(model_dict)
        print(f"Loaded {len(pretrained_dict)} pre-trained weight tensors")
    # freeze the lower layers (common transfer-learning strategy: train only the top)
    if freeze_layers:
        print("Freezing lower layers; only the top layers will be trained...")
        # freeze the stem and the first few Inception stages
        freeze_modules = [
            model.pre_layers,
            model.inception3a, model.inception3b,
            model.inception4a, model.inception4b
        ]
        for module in freeze_modules:
            for param in module.parameters():
                param.requires_grad = False  # frozen parameters are not updated
    return model
Core logic: when loading pre-trained weights, the classifier layers are filtered out (so the target task's output layers survive), the lower feature-extraction layers are frozen (preserving generic features), and only the top Inception stages and classifier are trained, which speeds up convergence and reduces overfitting.
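Since the frozen parameters have requires_grad=False, one option is to hand the optimizer only the trainable parameters. The sketch below is optional and assumes the imports from the full source (torch, optim); the train_model function later passes model.parameters(), which also works because frozen parameters never receive gradients:
# Optional: build the optimizer over trainable parameters only
model = init_pretrained_model(num_classes=75, pretrained_path=None, freeze_layers=True)
trainable = [p for p in model.parameters() if p.requires_grad]
optimizer = optim.Adam(trainable, lr=5e-4, weight_decay=1e-4)
print(f"trainable parameter tensors: {len(trainable)}")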
7. Model Training and Evaluation
7.1 Training Function
The training loop supports TensorBoard visualization and saves the best-performing model:
### 5. Training function
def train_model(model, train_loader, val_loader, epochs=30, lr=0.0005, writer=None):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")
    model = model.to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=1e-4)  # with L2 regularization
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, 'min', patience=3, factor=0.5, verbose=True  # decay the LR when validation loss plateaus
    )
    best_val_acc = 0.0
    for epoch in range(epochs):
        model.train()
        train_loss = 0.0
        train_correct = 0
        train_total = 0
        for inputs, targets in train_loader:
            inputs, targets = inputs.to(device), targets.to(device)
            optimizer.zero_grad()
            outputs, aux1, aux2 = model(inputs)
            # main loss + 0.3 * each auxiliary loss (the auxiliary weight is tunable)
            loss_main = criterion(outputs, targets)
            loss_aux1 = criterion(aux1, targets)
            loss_aux2 = criterion(aux2, targets)
            loss = loss_main + 0.3 * loss_aux1 + 0.3 * loss_aux2
            loss.backward()
            optimizer.step()
            train_loss += loss.item()
            _, predicted = outputs.max(1)
            train_total += targets.size(0)
            train_correct += predicted.eq(targets).sum().item()
            # free memory to avoid OOM
            del inputs, targets, outputs, aux1, aux2, predicted, loss
            torch.cuda.empty_cache()
        train_acc = 100.0 * train_correct / train_total
        avg_train_loss = train_loss / len(train_loader)
        # validation phase
        model.eval()
        val_loss = 0.0
        val_correct = 0
        val_total = 0
        with torch.no_grad():  # disable gradients to speed up validation
            for inputs, targets in val_loader:
                inputs, targets = inputs.to(device), targets.to(device)
                outputs = model(inputs)
                loss = criterion(outputs, targets)
                val_loss += loss.item()
                _, predicted = outputs.max(1)
                val_total += targets.size(0)
                val_correct += predicted.eq(targets).sum().item()
                del inputs, targets, outputs, predicted, loss
                torch.cuda.empty_cache()
        val_acc = 100.0 * val_correct / val_total
        val_loss_avg = val_loss / len(val_loader)
        print(f'Epoch [{epoch + 1}/{epochs}] '
              f'Train Loss: {avg_train_loss:.4f} '
              f'Train Acc: {train_acc:.2f}% '
              f'Val Loss: {val_loss_avg:.4f} '
              f'Val Acc: {val_acc:.2f}%')
        # log metrics to TensorBoard
        writer.add_scalars('Loss', {'train': avg_train_loss, 'val': val_loss_avg}, epoch)
        writer.add_scalars('Accuracy', {'train': train_acc/100, 'val': val_acc/100}, epoch)
        scheduler.step(val_loss_avg)  # adjust the LR based on validation loss
        # save the best model
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            torch.save(model.state_dict(), 'best_model.pth')
            print(f'Model saved (validation accuracy: {val_acc:.2f}%)')
    print(f'Best validation accuracy: {best_val_acc:.2f}%')
return model
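A typical invocation, reusing the helpers defined in this post (the paths are illustrative):
# Typical usage (assumes the imports and definitions from the full source below)
writer = SummaryWriter('./tensorboard')
train_loader, val_loader, classes = build_data('./data/bfly_data', batch_size=12)
model = GoogleNetSE(num_classes=len(classes), use_se=True)
model = train_model(model, train_loader, val_loader, epochs=30, lr=0.0005, writer=writer)
writer.close()
# view the logged curves with: tensorboard --logdir ./tensorboard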
7.2 Resume-Training Function
Supports resuming from a saved model, which suits staged fine-tuning:
### 6. Resume-training function
def retrain_model(model, train_loader, val_loader, pretrained_path, epochs=30, lr=0.0001, writer=None):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")
    if os.path.exists(pretrained_path):
        model.load_state_dict(torch.load(pretrained_path, map_location=device))
        print(f"Loaded pre-trained model: {pretrained_path}")
    else:
        raise FileNotFoundError(f"Pre-trained model file does not exist: {pretrained_path}")
    model = model.to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=1e-4)  # LR usually smaller than on the first run
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, 'min', patience=3, factor=0.5, verbose=True
    )
    # initial validation pass
    model.eval()
    val_correct = 0
    val_total = 0
    with torch.no_grad():
        for inputs, targets in val_loader:
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)
            _, predicted = outputs.max(1)
            val_total += targets.size(0)
            val_correct += predicted.eq(targets).sum().item()
            del inputs, targets, outputs, predicted
            torch.cuda.empty_cache()
    best_val_acc = 100.0 * val_correct / val_total
    print(f"Initial validation accuracy of the loaded model: {best_val_acc:.2f}%")
    # resume training (same loop as train_model; the repeated code is omitted here)
    # ... (see the full source in Section 10)
return model
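Note that retrain_model reloads only the model weights; optimizer and scheduler state start fresh. If a resumed run should pick up exactly where it stopped, a richer checkpoint is one option. This is an optional sketch, not part of the code above; the dictionary keys are my own, and model/optimizer/epoch/best_val_acc refer to the variables inside the training loop:
# Optional sketch: a fuller checkpoint saved inside the training loop
checkpoint = {
    'model_state': model.state_dict(),
    'optimizer_state': optimizer.state_dict(),
    'epoch': epoch,
    'best_val_acc': best_val_acc,
}
torch.save(checkpoint, 'checkpoint.pth')
# ...and to resume later:
state = torch.load('checkpoint.pth', map_location=device)
model.load_state_dict(state['model_state'])
optimizer.load_state_dict(state['optimizer_state'])
start_epoch = state['epoch'] + 1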
8. Model Testing and Result Analysis
8.1 Model Testing
### 7. Test function
def test_model(pretrained_path='best_model.pth', test_dir='./data/test/'):
    if not os.path.exists(pretrained_path):
        raise FileNotFoundError(f"Model file does not exist: {pretrained_path}")
    if not os.path.exists(test_dir):
        raise FileNotFoundError(f"Test directory does not exist: {test_dir}")
    # recover the class names
    _, _, classes = build_data('./data/bfly_data', batch_size=1)
    num_classes = len(classes)
    # initialize the model
    model = GoogleNetSE(num_classes=num_classes, use_se=True)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.load_state_dict(torch.load(pretrained_path, map_location=device))
    model = model.to(device)
    model.eval()
    # image preprocessing
    test_transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
    # collect the test image paths
    image_extensions = ['.jpg', '.jpeg', '.png', '.bmp', '.gif']
    test_images = []
    for filename in os.listdir(test_dir):
        if any(filename.lower().endswith(ext) for ext in image_extensions):
            test_images.append(os.path.join(test_dir, filename))
    if not test_images:
        print(f"No image files found in test directory {test_dir}")
        return
    print(f"\nStarting test, found {len(test_images)} images...\n")
    # test image by image
    for img_path in test_images:
        try:
            img = Image.open(img_path).convert('RGB')  # force RGB
            img_tensor = test_transform(img).unsqueeze(0)  # add a batch dimension
            img_tensor = img_tensor.to(device)
            with torch.no_grad():
                outputs = model(img_tensor)
                probabilities = F.softmax(outputs, dim=1)  # convert logits to probabilities
                max_prob, predicted_idx = torch.max(probabilities, 1)
                max_prob = max_prob.item() * 100
                predicted_class = classes[predicted_idx.item()]
            print(f"Image: {os.path.basename(img_path)}")
            print(f"Predicted class: {predicted_class}")
            print(f"Confidence: {max_prob:.2f}%")
            print("-" * 50)
        except Exception as e:
            print(f"Error while processing image {os.path.basename(img_path)}: {str(e)}")
            print("-" * 50)
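If a single top-1 prediction is too terse, torch.topk gives the best k candidates. A sketch meant to sit inside test_model's per-image loop, right after probabilities is computed:
# Sketch: print the top-3 candidate species instead of only the argmax
topk_probs, topk_idx = torch.topk(probabilities, k=3, dim=1)
for prob, idx in zip(topk_probs[0], topk_idx[0]):
    print(f"  {classes[idx.item()]}: {prob.item() * 100:.2f}%")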
8.2 Batch Validation with CSV Export
### 8. New: validate and save results to CSV
def validate_and_save_csv(model_path='best_model.pth', data_dir='./data/bfly_data', result_dir='./validation_results'):
    """Evaluate the model on the validation set and save per-sample results to a CSV file."""
    os.makedirs(result_dir, exist_ok=True)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    _, val_loader, classes = build_data(data_dir, batch_size=16)
    num_classes = len(classes)
    # initialize the model and load the weights
    model = GoogleNetSE(num_classes=num_classes, use_se=True)
    model.load_state_dict(torch.load(model_path, map_location=device))
    model = model.to(device)
    model.eval()
    # collect the results
    all_results = []
    total_correct = 0
    total_samples = 0
    with torch.no_grad():
        current_idx = 0
        for inputs, labels in val_loader:
            # forward pass and probability computation
            # ... (see the full source in Section 10)
    # save as CSV
    results_df = pd.DataFrame(all_results, columns=columns)
    csv_path = os.path.join(result_dir, 'validation_results.csv')
    results_df.to_csv(csv_path, index=False)
    print(f"Validation results saved to: {csv_path}")
    return accuracy
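Once the CSV exists, per-class accuracy is a short pandas groupby away (a sketch; the column names follow the layout built in the full function):
# Sketch: per-class accuracy from the saved results file
import pandas as pd
df = pd.read_csv('./validation_results/validation_results.csv')
per_class_acc = df.groupby('true_class')['is_correct'].mean().sort_values()
print(per_class_acc.head(10))  # the ten weakest classes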
9. Training Control and Main Program
### 9. Training control function
def run_training(mode="retrain"):
    print(f"Current working directory: {os.getcwd()}")
    data_dir = './data/bfly_data'
    if not os.path.exists(data_dir):
        print(f"Error: dataset path does not exist - {data_dir}")
        exit(1)
    # load the dataset
    train_loader, val_loader, classes = build_data(data_dir, batch_size=12)
    print(f'Number of classes: {len(classes)}, class list: {classes[:5]}...')
    # set up TensorBoard
    writer = SummaryWriter('./tensorboard')
    print("TensorBoard log path: ./tensorboard")
    # visualize training samples
    # ... (see the full source in Section 10)
    # initialize the model
    model = init_pretrained_model(
        num_classes=len(classes),
        pretrained_path='pretrained_googlenet.pth',
        freeze_layers=True
    )
    # visualize the model graph
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    dummy_input = torch.randn(1, 3, 224, 224).to(device)
    writer.add_graph(model.to(device), dummy_input)
    # choose the training mode
    if mode == "train":
        model = train_model(model, train_loader, val_loader, epochs=30, lr=0.001, writer=writer)
    else:
        model = retrain_model(
            model, train_loader, val_loader,
            pretrained_path='best_model.pth',
            epochs=50, lr=0.0001, writer=writer
        )
    writer.close()
# program entry point
if __name__ == '__main__':
    # run_training(mode="retrain")  # run training
    test_model()  # run testing
    # validate_and_save_csv()  # save validation results
10. Full Source Code
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import optim
from torch.nn import Module
from torchvision import datasets, transforms, utils
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
import os
from PIL import Image
from torchviz import make_dot
import numpy as np
# Configure the GPU memory allocation strategy
torch.cuda.empty_cache()
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:128"
### 1. Attention mechanism: SE block (Squeeze-and-Excitation)
class SEBlock(nn.Module):
def __init__(self, channel, reduction=16):
super(SEBlock, self).__init__()
self.avg_pool = nn.AdaptiveAvgPool2d(1)
self.fc = nn.Sequential(
nn.Linear(channel, channel // reduction, bias=False),
nn.ReLU(inplace=True),
nn.Linear(channel // reduction, channel, bias=False),
nn.Sigmoid()
)
def forward(self, x):
b, c, _, _ = x.size()
y = self.avg_pool(x).view(b, c)
y = self.fc(y).view(b, c, 1, 1)
return x * y
### 2. Inception module (the core of GoogleNet)
class Inception(nn.Module):
def __init__(self, in_channels, ch1x1, ch3x3red, ch3x3, ch5x5red, ch5x5, pool_proj, use_se=True):
super(Inception, self).__init__()
        # 1x1 convolution branch
self.branch1 = nn.Sequential(
nn.Conv2d(in_channels, ch1x1, kernel_size=1, stride=1, padding=0, bias=False),
nn.BatchNorm2d(ch1x1),
nn.ReLU(inplace=True)
)
        # 3x3 convolution branch (1x1 reduction first)
self.branch2 = nn.Sequential(
nn.Conv2d(in_channels, ch3x3red, kernel_size=1, stride=1, padding=0, bias=False),
nn.BatchNorm2d(ch3x3red),
nn.ReLU(inplace=True),
nn.Conv2d(ch3x3red, ch3x3, kernel_size=3, stride=1, padding=1, bias=False),
nn.BatchNorm2d(ch3x3),
nn.ReLU(inplace=True)
)
        # 5x5 convolution branch (1x1 reduction first)
self.branch3 = nn.Sequential(
nn.Conv2d(in_channels, ch5x5red, kernel_size=1, stride=1, padding=0, bias=False),
nn.BatchNorm2d(ch5x5red),
nn.ReLU(inplace=True),
nn.Conv2d(ch5x5red, ch5x5, kernel_size=5, stride=1, padding=2, bias=False),
nn.BatchNorm2d(ch5x5),
nn.ReLU(inplace=True)
)
        # pooling branch
self.branch4 = nn.Sequential(
nn.MaxPool2d(kernel_size=3, stride=1, padding=1),
nn.Conv2d(in_channels, pool_proj, kernel_size=1, stride=1, padding=0, bias=False),
nn.BatchNorm2d(pool_proj),
nn.ReLU(inplace=True)
)
        # channel attention module
self.use_se = use_se
if self.use_se:
self.se = SEBlock(ch1x1 + ch3x3 + ch5x5 + pool_proj)
def forward(self, x):
branch1 = self.branch1(x)
branch2 = self.branch2(x)
branch3 = self.branch3(x)
branch4 = self.branch4(x)
x = torch.cat([branch1, branch2, branch3, branch4], 1)
if self.use_se:
x = self.se(x)
return x
### 3. GoogleNet with SE attention
class GoogleNetSE(nn.Module):
def __init__(self, num_classes=75, use_se=True):
super(GoogleNetSE, self).__init__()
self.use_se = use_se
        # stem convolutions
self.pre_layers = nn.Sequential(
nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False),
nn.BatchNorm2d(64),
nn.ReLU(inplace=True),
nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
nn.Conv2d(64, 64, kernel_size=1, stride=1, padding=0, bias=False),
nn.BatchNorm2d(64),
nn.ReLU(inplace=True),
nn.Conv2d(64, 192, kernel_size=3, stride=1, padding=1, bias=False),
nn.BatchNorm2d(192),
nn.ReLU(inplace=True),
nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
)
        # Inception stages
self.inception3a = Inception(192, 64, 96, 128, 16, 32, 32, use_se)
self.inception3b = Inception(256, 128, 128, 192, 32, 96, 64, use_se)
self.maxpool3 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
self.inception4a = Inception(480, 192, 96, 208, 16, 48, 64, use_se)
self.inception4b = Inception(512, 160, 112, 224, 24, 64, 64, use_se)
self.inception4c = Inception(512, 128, 128, 256, 24, 64, 64, use_se)
self.inception4d = Inception(512, 112, 144, 288, 32, 64, 64, use_se)
self.inception4e = Inception(528, 256, 160, 320, 32, 128, 128, use_se)
self.maxpool4 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
self.inception5a = Inception(832, 256, 160, 320, 32, 128, 128, use_se)
self.inception5b = Inception(832, 384, 192, 384, 48, 128, 128, use_se)
        # adaptive average pooling to a fixed 7x7 map
self.avgpool = nn.AdaptiveAvgPool2d((7, 7))
        # main classifier
self.classifier = nn.Sequential(
nn.Dropout(0.4),
nn.Linear(1024 * 7 * 7, 1024),
nn.ReLU(inplace=True),
nn.Linear(1024, num_classes)
)
        # auxiliary classifiers
self.aux1 = self._make_aux_classifier(512, num_classes)
self.aux2 = self._make_aux_classifier(528, num_classes)
def _make_aux_classifier(self, in_channels, num_classes):
return nn.Sequential(
nn.AvgPool2d(kernel_size=5, stride=3),
nn.Conv2d(in_channels, 128, kernel_size=1, stride=1, padding=0),
nn.ReLU(inplace=True),
nn.Flatten(),
nn.Linear(128 * 4 * 4, 1024),
nn.ReLU(inplace=True),
nn.Dropout(0.7),
nn.Linear(1024, num_classes)
)
def forward(self, x):
x = self.pre_layers(x)
x = self.inception3a(x)
x = self.inception3b(x)
x = self.maxpool3(x)
x = self.inception4a(x)
aux1 = self.aux1(x) if self.training else None
x = self.inception4b(x)
x = self.inception4c(x)
x = self.inception4d(x)
aux2 = self.aux2(x) if self.training else None
x = self.inception4e(x)
x = self.maxpool4(x)
x = self.inception5a(x)
x = self.inception5b(x)
x = self.avgpool(x)
x = torch.flatten(x, 1)
x = self.classifier(x)
if self.training:
return x, aux1, aux2
else:
return x
### 4. Fixed data loading function (ensures transform compatibility)
def build_data(data_dir, batch_size=8):
    if not os.path.exists(data_dir):
        raise FileNotFoundError(f"Dataset directory does not exist: {data_dir}")
    required_subdirs = ['train', 'valid']
    for subdir in required_subdirs:
        subdir_path = os.path.join(data_dir, subdir)
        if not os.path.exists(subdir_path):
            raise FileNotFoundError(f"Data subdirectory does not exist: {subdir_path}")
    print(f"Found dataset directory structure: {data_dir}")
    # separate PIL-image transforms from Tensor transforms for compatibility
pil_transforms = transforms.Compose([
transforms.RandomResizedCrop(224, scale=(0.7, 1.0), ratio=(0.8, 1.2)),
transforms.RandomHorizontalFlip(p=0.5),
transforms.RandomVerticalFlip(p=0.3),
transforms.RandomRotation(30),
transforms.RandomPerspective(distortion_scale=0.2, p=0.3)
])
    # pixel-level transforms applied after conversion to Tensor
tensor_transforms = transforms.Compose([
transforms.ToTensor(),
transforms.ColorJitter(
brightness=0.3,
contrast=0.3,
saturation=0.3,
hue=0.1
),
transforms.RandomGrayscale(p=0.2),
transforms.RandomApply([transforms.GaussianBlur(kernel_size=3, sigma=(0.1, 2.0))], p=0.3),
transforms.RandomErasing(p=0.2, scale=(0.02, 0.25)),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
    # combined training transform
train_transform = transforms.Compose([
pil_transforms,
tensor_transforms
])
    # validation transform
val_test_transform = transforms.Compose([
transforms.Resize((224, 224)),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
    # load the datasets
train_dataset = datasets.ImageFolder(
root=os.path.join(data_dir, 'train'),
transform=train_transform
)
val_dataset = datasets.ImageFolder(
root=os.path.join(data_dir, 'valid'),
transform=val_test_transform
)
    # build the data loaders
train_loader = DataLoader(
dataset=train_dataset,
batch_size=batch_size,
shuffle=True,
num_workers=2,
pin_memory=True
)
val_loader = DataLoader(
dataset=val_dataset,
batch_size=batch_size,
shuffle=False,
num_workers=2,
pin_memory=True
)
return train_loader, val_loader, train_dataset.classes
### New: pre-trained model initialization (transfer learning)
def init_pretrained_model(num_classes=75, pretrained_path=None, freeze_layers=True):
    """
    Initialize the model with pre-trained weights for transfer learning.
    Args:
        num_classes: number of target classes (butterfly species)
        pretrained_path: path to a pre-trained weight file (e.g. ImageNet-pretrained GoogleNet weights)
        freeze_layers: whether to freeze the lower layers (train only the top)
    """
    model = GoogleNetSE(num_classes=num_classes, use_se=True)
    if pretrained_path and os.path.exists(pretrained_path):
        print(f"Loading pre-trained weights: {pretrained_path}")
        # load the pre-trained state dict
        pretrained_dict = torch.load(pretrained_path, map_location=torch.device('cpu'))
        model_dict = model.state_dict()
        # filter the weights: keep only the feature-extraction part (skip the
        # classifiers so the target task's output layers are not overwritten)
        # excludes classifier-related keys (classifier, aux1, aux2)
        pretrained_dict = {
            k: v for k, v in pretrained_dict.items()
            if k in model_dict and
            not any(keyword in k for keyword in ['classifier', 'aux1', 'aux2'])
        }
        # merge into the model's weights
        model_dict.update(pretrained_dict)
        model.load_state_dict(model_dict)
        print(f"Loaded {len(pretrained_dict)} pre-trained weight tensors")
    # freeze the lower layers (common transfer-learning strategy: train only the top)
    if freeze_layers:
        print("Freezing lower layers; only the top layers will be trained...")
        # freeze the stem and the first few Inception stages
        # (adjust how many stages to freeze as needed)
        freeze_modules = [
            model.pre_layers,
            model.inception3a, model.inception3b,
            model.inception4a, model.inception4b
        ]
        for module in freeze_modules:
            for param in module.parameters():
                param.requires_grad = False  # frozen parameters are not updated
    return model
### 5. Training function
def train_model(model, train_loader, val_loader, epochs=30, lr=0.0005, writer=None):
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"使用设备: {device}")
model = model.to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=1e-4)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
optimizer, 'min', patience=3, factor=0.5, verbose=True
)
best_val_acc = 0.0
for epoch in range(epochs):
model.train()
train_loss = 0.0
train_correct = 0
train_total = 0
for inputs, targets in train_loader:
inputs, targets = inputs.to(device), targets.to(device)
optimizer.zero_grad()
outputs, aux1, aux2 = model(inputs)
loss_main = criterion(outputs, targets)
loss_aux1 = criterion(aux1, targets)
loss_aux2 = criterion(aux2, targets)
loss = loss_main + 0.3 * loss_aux1 + 0.3 * loss_aux2
loss.backward()
optimizer.step()
train_loss += loss.item()
_, predicted = outputs.max(1)
train_total += targets.size(0)
train_correct += predicted.eq(targets).sum().item()
del inputs, targets, outputs, aux1, aux2, predicted, loss
torch.cuda.empty_cache()
train_acc = 100.0 * train_correct / train_total
avg_train_loss = train_loss / len(train_loader)
        # validation phase
model.eval()
val_loss = 0.0
val_correct = 0
val_total = 0
with torch.no_grad():
for inputs, targets in val_loader:
inputs, targets = inputs.to(device), targets.to(device)
outputs = model(inputs)
loss = criterion(outputs, targets)
val_loss += loss.item()
_, predicted = outputs.max(1)
val_total += targets.size(0)
val_correct += predicted.eq(targets).sum().item()
del inputs, targets, outputs, predicted, loss
torch.cuda.empty_cache()
val_acc = 100.0 * val_correct / val_total
val_loss_avg = val_loss / len(val_loader)
print(f'Epoch [{epoch + 1}/{epochs}] '
f'Train Loss: {avg_train_loss:.4f} '
f'Train Acc: {train_acc:.2f}% '
f'Val Loss: {val_loss_avg:.4f} '
f'Val Acc: {val_acc:.2f}%')
        # log metrics to TensorBoard
writer.add_scalars('Loss', {'train': avg_train_loss, 'val': val_loss_avg}, epoch)
writer.add_scalars('Accuracy', {'train': train_acc/100, 'val': val_acc/100}, epoch)
scheduler.step(val_loss_avg)
if val_acc > best_val_acc:
best_val_acc = val_acc
torch.save(model.state_dict(), 'best_model.pth')
            print(f'Model saved (validation accuracy: {val_acc:.2f}%)')
    print(f'Best validation accuracy: {best_val_acc:.2f}%')
return model
### 6. Resume-training function
def retrain_model(model, train_loader, val_loader, pretrained_path, epochs=30, lr=0.0001, writer=None):
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"使用设备: {device}")
    if os.path.exists(pretrained_path):
        model.load_state_dict(torch.load(pretrained_path, map_location=device))
        print(f"Loaded pre-trained model: {pretrained_path}")
    else:
        raise FileNotFoundError(f"Pre-trained model file does not exist: {pretrained_path}")
model = model.to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=1e-4)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
optimizer, 'min', patience=3, factor=0.5, verbose=True
)
    # initial validation pass
model.eval()
val_correct = 0
val_total = 0
with torch.no_grad():
for inputs, targets in val_loader:
inputs, targets = inputs.to(device), targets.to(device)
outputs = model(inputs)
_, predicted = outputs.max(1)
val_total += targets.size(0)
val_correct += predicted.eq(targets).sum().item()
del inputs, targets, outputs, predicted
torch.cuda.empty_cache()
best_val_acc = 100.0 * val_correct / val_total
print(f"预训练模型初始验证准确率: {best_val_acc:.2f}%")
# 继续训练
print(f"开始继续训练,共{epochs}轮...")
for epoch in range(epochs):
model.train()
train_loss = 0.0
train_correct = 0
train_total = 0
for inputs, targets in train_loader:
inputs, targets = inputs.to(device), targets.to(device)
optimizer.zero_grad()
outputs, aux1, aux2 = model(inputs)
loss_main = criterion(outputs, targets)
loss_aux1 = criterion(aux1, targets)
loss_aux2 = criterion(aux2, targets)
loss = loss_main + 0.3 * loss_aux1 + 0.3 * loss_aux2
loss.backward()
optimizer.step()
train_loss += loss.item()
_, predicted = outputs.max(1)
train_total += targets.size(0)
train_correct += predicted.eq(targets).sum().item()
del inputs, targets, outputs, aux1, aux2, predicted, loss
torch.cuda.empty_cache()
train_acc = 100.0 * train_correct / train_total
avg_train_loss = train_loss / len(train_loader)
        # validation
model.eval()
val_loss = 0.0
val_correct = 0
val_total = 0
with torch.no_grad():
for inputs, targets in val_loader:
inputs, targets = inputs.to(device), targets.to(device)
outputs = model(inputs)
loss = criterion(outputs, targets)
val_loss += loss.item()
_, predicted = outputs.max(1)
val_total += targets.size(0)
val_correct += predicted.eq(targets).sum().item()
del inputs, targets, outputs, predicted, loss
torch.cuda.empty_cache()
val_acc = 100.0 * val_correct / val_total
val_loss_avg = val_loss / len(val_loader)
        print(f'Resume training Epoch [{epoch + 1}/{epochs}] '
f'Train Loss: {avg_train_loss:.4f} '
f'Train Acc: {train_acc:.2f}% '
f'Val Loss: {val_loss_avg:.4f} '
f'Val Acc: {val_acc:.2f}%')
        # log metrics to TensorBoard
writer.add_scalars('Loss', {'train': avg_train_loss, 'val': val_loss_avg}, epoch)
writer.add_scalars('Accuracy', {'train': train_acc/100, 'val': val_acc/100}, epoch)
scheduler.step(val_loss_avg)
if val_acc > best_val_acc:
best_val_acc = val_acc
torch.save(model.state_dict(), 'best_model.pth')
            print(f'Resume training: model updated and saved (validation accuracy: {val_acc:.2f}%)')
    print(f'Resume training finished, best validation accuracy: {best_val_acc:.2f}%')
return model
### 7. Test function
def test_model(pretrained_path='best_model.pth', test_dir='./data/test/'):
    if not os.path.exists(pretrained_path):
        raise FileNotFoundError(f"Model file does not exist: {pretrained_path}")
    if not os.path.exists(test_dir):
        raise FileNotFoundError(f"Test directory does not exist: {test_dir}")
    # recover the class names
    _, _, classes = build_data('./data/bfly_data', batch_size=1)
    num_classes = len(classes)
    # initialize the model
model = GoogleNetSE(num_classes=num_classes, use_se=True)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.load_state_dict(torch.load(pretrained_path, map_location=device))
model = model.to(device)
model.eval()
    # image preprocessing
test_transform = transforms.Compose([
transforms.Resize((224, 224)),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
    # collect the test image paths
image_extensions = ['.jpg', '.jpeg', '.png', '.bmp', '.gif']
test_images = []
for filename in os.listdir(test_dir):
if any(filename.lower().endswith(ext) for ext in image_extensions):
test_images.append(os.path.join(test_dir, filename))
    if not test_images:
        print(f"No image files found in test directory {test_dir}")
        return
    print(f"\nStarting test, found {len(test_images)} images...\n")
    # test image by image
for img_path in test_images:
try:
img = Image.open(img_path).convert('RGB')
img_tensor = test_transform(img).unsqueeze(0)
img_tensor = img_tensor.to(device)
with torch.no_grad():
outputs = model(img_tensor)
probabilities = F.softmax(outputs, dim=1)
max_prob, predicted_idx = torch.max(probabilities, 1)
max_prob = max_prob.item() * 100
predicted_class = classes[predicted_idx.item()]
print(f"图片: {os.path.basename(img_path)}")
print(f"预测类别: {predicted_class}")
print(f"置信度: {max_prob:.2f}%")
print("-" * 50)
except Exception as e:
print(f"处理图片 {os.path.basename(img_path)} 时出错: {str(e)}")
print("-" * 50)
### 8. New: validate and save results to CSV
def validate_and_save_csv(model_path='best_model.pth', data_dir='./data/bfly_data', result_dir='./validation_results'):
    """
    Evaluate the model on the validation set and save per-sample results to a CSV file.
    Args:
        model_path: path to the model weight file
        data_dir: dataset root containing train/valid subdirectories
        result_dir: directory for the output files
    """
    # make sure the result directory exists
    os.makedirs(result_dir, exist_ok=True)
    # device configuration
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")
    # check that the model file exists
    if not os.path.exists(model_path):
        raise FileNotFoundError(f"Model file does not exist: {model_path}")
    # load the validation data and class names
    _, val_loader, classes = build_data(data_dir, batch_size=16)
    num_classes = len(classes)
    print(f"Validation set loaded: {len(val_loader.dataset)} samples, {num_classes} classes")
    # get the validation image paths (ImageFolder's imgs attribute holds (path, label) pairs)
    val_dataset = val_loader.dataset
    image_paths = [img_path for img_path, _ in val_dataset.imgs]
    # initialize the model and load the weights
    model = GoogleNetSE(num_classes=num_classes, use_se=True)
    model.load_state_dict(torch.load(model_path, map_location=device))
    model = model.to(device)
    model.eval()  # switch to evaluation mode
    # result accumulators
    all_results = []
    total_correct = 0
    total_samples = 0
    # validation loop
    with torch.no_grad():
        current_idx = 0  # tracks the index of the current sample
for inputs, labels in val_loader:
batch_size = inputs.size(0)
inputs, labels = inputs.to(device), labels.to(device)
            # forward pass
            outputs = model(inputs)
            probabilities = F.softmax(outputs, dim=1)  # per-class probabilities
            _, predicted = torch.max(outputs, 1)  # predicted class indices
            # accuracy bookkeeping
            correct = (predicted == labels).sum().item()
            total_correct += correct
            total_samples += batch_size
            # move the results to the CPU as numpy arrays
            probs_np = probabilities.cpu().numpy()
            predicted_np = predicted.cpu().numpy()
            labels_np = labels.cpu().numpy()
            # collect results for every sample in this batch
            for i in range(batch_size):
                sample_idx = current_idx + i
                if sample_idx >= len(image_paths):
                    break  # guard against running past the path list
                # current sample's information
                img_path = image_paths[sample_idx]
                true_label = classes[labels_np[i]]
                pred_label = classes[predicted_np[i]]
                is_correct = (predicted_np[i] == labels_np[i])
                # one row: image path + all class probabilities + prediction + ground truth + correctness
                result_row = [img_path]
                result_row.extend(probs_np[i])  # probabilities for every class
                result_row.extend([pred_label, true_label, is_correct])
                all_results.append(result_row)
            current_idx += batch_size
            print(f"Processed {current_idx}/{len(val_dataset)} samples...", end='\r')
    # overall accuracy
    accuracy = total_correct / total_samples if total_samples > 0 else 0
    print(f"\nOverall validation accuracy: {accuracy:.4f} ({total_correct}/{total_samples})")
    # build the CSV column names
    columns = ['image_path']  # image path column
    columns.extend(classes)  # one probability column per class
    columns.extend(['predicted_class', 'true_class', 'is_correct'])  # prediction info columns
    # convert to a DataFrame and save as CSV
    results_df = pd.DataFrame(all_results, columns=columns)
    csv_path = os.path.join(result_dir, 'validation_results.csv')
    results_df.to_csv(csv_path, index=False)
    print(f"Validation results saved to: {csv_path}")
return accuracy
### 9. Training control function
def run_training(mode="retrain"):
    print(f"Current working directory: {os.getcwd()}")
    data_dir = './data/bfly_data'
    print(f"Checking dataset path: {data_dir}")
    if not os.path.exists(data_dir):
        print(f"Error: dataset path does not exist - {data_dir}")
        exit(1)
    try:
        print("Loading dataset...")
        train_loader, val_loader, classes = build_data(data_dir, batch_size=12)
        print(f'Number of classes: {len(classes)}, class list: {classes[:5]}...')
    except Exception as e:
        print(f"Error while loading data: {e}")
        exit(1)
    # set up TensorBoard
    writer = SummaryWriter('./tensorboard')
    print("TensorBoard log path: ./tensorboard")
    # visualize a batch of training samples
    try:
        dataiter = iter(train_loader)
        inputs, _ = next(dataiter)
        # undo normalization for visualization
        inverse_norm = transforms.Normalize(
            mean=[-0.485/0.229, -0.456/0.224, -0.406/0.225],
            std=[1/0.229, 1/0.224, 1/0.225]
        )
        vis_inputs = inverse_norm(inputs)
        vis_inputs = torch.clamp(vis_inputs, 0, 1)
        img_grid = utils.make_grid(vis_inputs[:8], nrow=4, padding=2)
        writer.add_image('Training Samples (Augmented)', img_grid, 0)
    except Exception as e:
        print(f"Failed to visualize training samples: {e}")
    # initialize the model
    model = init_pretrained_model(
        num_classes=len(classes),
        pretrained_path='pretrained_googlenet.pth',  # path to pre-trained weights
        freeze_layers=True  # freeze the lower layers
    )
    total_params = sum(p.numel() for p in model.parameters())
    print(f"Number of model parameters: {total_params:,}")
    # visualize the model graph
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    dummy_input = torch.randn(1, 3, 224, 224).to(device)
    writer.add_graph(model.to(device), dummy_input)
    # choose the training mode
    train_mode = mode
    if train_mode == "train":
        print('Starting initial training...')
        model = train_model(model, train_loader, val_loader, epochs=30, lr=0.001, writer=writer)
        print('Best model from initial training saved as: best_model.pth')
    else:
        print('Resuming training...')
        model = retrain_model(
            model,
            train_loader,
            val_loader,
            pretrained_path='best_model.pth',
            epochs=50,
            lr=0.0001,
            writer=writer
        )
        print('Best model from resumed training saved as: best_model.pth')
    # close the TensorBoard writer
    writer.close()
# program entry point
if __name__ == '__main__':
    # run training first (mode "train" or "retrain")
    run_training(mode="train")
    # run testing after training finishes
    test_model()
    # save validation results to CSV
    # validate_and_save_csv()