一、实验内容:
1. 课题选择与数据准备
图像分割:医学影像分割:DRIVE视网膜血管分割
数据集:
来源:Gitee 镜像 DRIVE-dataset [1]
结构:
训练集:20张视网膜图像 + 血管标注 + 视网膜掩膜
测试集:20张视网膜图像 + 血管标注 + 视网膜掩膜
预处理:
尺寸调整至128×128(降低CPU计算负担)
应用视网膜掩膜排除背景区域的标注(掩膜只作用于标签,输入图像本身未做掩膜处理)
归一化像素值至[0,1]
标注二值化(血管=1,背景=0)
2. 模型构建与训练
算法选择:U-Net
网络架构:
编码器路径(下采样):
共4个分辨率层级(1个输入卷积块 + 3个下采样块,即3次最大池化;原U-Net为5个层级、4次下采样),减少计算量
每个块包含两个3×3卷积+BN+ReLU
最大池化降采样(2×2)[6]
解码器路径(上采样):
转置卷积上采样(2×2)
跳跃连接融合编码器特征
双卷积块特征融合
输出层:1×1卷积,输出每个像素的血管 logits(网络本身不含 Sigmoid);Sigmoid 在 Dice+BCE 损失(BCEWithLogitsLoss)内部施加,推理时需先经 Sigmoid 得到血管概率图,再以 0.5 为阈值二值化[6]
创新点:
减少网络深度平衡精度与速度
参数量约7.7M(按网络结构逐层统计约7,702,977个参数;原U-Net约31M)
训练策略:
损失函数:Dice+BCE混合损失
解决血管-背景类别不平衡问题
Dice系数优化分割重叠度
BCE提供稳定梯度[5]
优化器:RMSprop
学习率0.0001,权重衰减1e-5
Momentum=0.9加速收敛[8]
学习率调度:
ReduceLROnPlateau(耐心=3,因子=0.5)
验证损失停滞时自动降低学习率
正则化:
权重衰减(L2正则化)
批量归一化
CPU适配:
批量大小=4(内存优化,与代码中 batch_size=4 一致)
图像尺寸128×128
轻量化网络结构
3. 性能评估与优化
评估指标:
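主要用 Dice 系数评估分割效果(即 F1 分数,与 IoU 满足 $\mathrm{Dice} = \frac{2\,\mathrm{IoU}}{1+\mathrm{IoU}}$)。公式为 $\mathrm{Dice} = \frac{2|A \cap B|}{|A| + |B|}$,其中 $A$ 为预测的血管像素集合,$B$ 为标注的血管像素集合。同时通过损失值(Dice + BCE 混合损失)辅助观察训练过程。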
可视化分析:
训练曲线:左图 Loss 曲线,右图Dice 曲线
消融实验:
(1)模块对比:
对比 “有无跳跃连接(U-Net 核心特性)”:去掉跳跃连接后,观察 Dice 系数和分割细节,验证其对医学图像(需保留细节)的必要性。
替换下采样方式(如用步幅卷积替代 MaxPool):测试是否能缓解小数据集下的信息丢失问题。
(2)参数 / 策略调整:
改变批量大小(batch_size):当前 batch_size=4,尝试 2/8 等,观察 Loss 收敛速度和 Dice 稳定性。
调整学习率调度:比如换用 StepLR 对比 ReduceLROnPlateau,看对训练曲线波动的影响。
(3)数据相关:
数据增强(旋转、亮度变换等):验证小数据集下增强对模型泛化性(Dice 稳定性)的提升。
标签处理:尝试不同二值化阈值(当前用 label > 0),观察对分割结果的影响。
二、实验步骤:
数据准备
使用本地DRIVE视网膜血管数据集,数据相对于脚本所在目录的路径为:data/DRIVE
# Locate the dataset relative to this script's own directory, so the script
# works no matter which working directory it is launched from.
BASE_DIR = os.path.split(os.path.abspath(__file__))[0]
DATA_ROOT = os.path.join(os.path.join(BASE_DIR, "data"), "DRIVE")
print(f"数据集路径: {DATA_ROOT}")
数据清洗与标注
在__getitem__里通过 mask > 0 对标签进行筛选(label = label * (mask > 0) ),利用视网膜区域掩膜去除了掩膜外的无效标注
数据集划分
分别加载 training 目录(训练集)和 test 目录(测试集)的数据
对数据的处理代码如下:
class DRIVEDataset(Dataset):
    """DRIVE retinal-vessel dataset yielding ``(image, label)`` tensor pairs.

    Expected layout under ``root_dir/<training|test>/``:
    ``images/*.tif``, ``1st_manual/<NN>_manual1.gif`` and ``mask/<stem>_mask.gif``.
    ``image`` is a float32 ``[3, H, W]`` tensor in [0, 1]; ``label`` is a float32
    ``[1, H, W]`` tensor that is 1 on vessel pixels inside the retina mask, else 0
    (before ``transform`` is applied).
    """

    def __init__(self, root_dir, train=True, transform=None, *, label_threshold=None):
        """
        Args:
            root_dir (str): dataset root directory.
            train (bool): use the ``training`` split if True, otherwise ``test``.
            transform (callable, optional): applied separately to the image tensor
                and to the label tensor.
            label_threshold (float, optional): if given, the label is re-binarized
                as ``label > label_threshold`` after ``transform``. Interpolating
                transforms such as ``Resize`` otherwise turn the 0/1 label into
                fractional values. ``None`` (default) keeps the label untouched.
        """
        self.root_dir = root_dir
        self.transform = transform
        self.label_threshold = label_threshold
        self.sub_dir = "training" if train else "test"
        # Image, retina-mask and manual-annotation directories
        self.image_dir = os.path.join(root_dir, self.sub_dir, "images")
        self.mask_dir = os.path.join(root_dir, self.sub_dir, "mask")
        self.label_dir = os.path.join(root_dir, self.sub_dir, "1st_manual")
        # os.listdir order is filesystem-dependent; sort so sample indices (and
        # therefore seeded runs) are reproducible across machines
        self.image_files = sorted(
            f for f in os.listdir(self.image_dir) if f.endswith('.tif')
        )
        print(f"找到 {len(self.image_files)} 个图像文件在 {self.image_dir}")

    def __len__(self):
        return len(self.image_files)

    @staticmethod
    def _load(path, mode):
        """Open ``path`` with PIL, convert to ``mode`` and return a NumPy array (file is closed)."""
        with Image.open(path) as img:
            return np.array(img.convert(mode))

    def __getitem__(self, idx):
        img_name = self.image_files[idx]
        image = self._load(os.path.join(self.image_dir, img_name), 'RGB')

        # Annotation name is derived from the numeric prefix:
        # "21_training.tif" -> "21" -> "21_manual1.gif"
        file_number = img_name.split('_')[0]
        label = self._load(os.path.join(self.label_dir, f"{file_number}_manual1.gif"), 'L')

        # Retina mask: "21_training.tif" -> "21_training_mask.gif"
        mask_name = img_name.replace('.tif', '_mask.gif')
        mask = self._load(os.path.join(self.mask_dir, mask_name), 'L')

        # Drop annotations outside the retina field of view
        label = label * (mask > 0)

        # Normalize image to [0, 1]; binarize label (vessel=1, background=0)
        image = image.astype(np.float32) / 255.0
        label = (label > 0).astype(np.float32)

        image = torch.tensor(image).permute(2, 0, 1)  # [C, H, W]
        label = torch.tensor(label).unsqueeze(0)      # [1, H, W]

        if self.transform:
            image = self.transform(image)
            label = self.transform(label)
        if self.label_threshold is not None:
            label = (label > self.label_threshold).float()
        return image, label
模型构建
轻量化U-Net模型实现,代码如下:
class DoubleConv(nn.Module):
    """Two stacked 3x3 Conv -> BatchNorm -> ReLU stages.

    ``padding=1`` keeps the spatial size unchanged; only the channel count
    changes from ``in_channels`` to ``out_channels``.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        stages = []
        # The first stage maps in->out channels, the second keeps out channels
        for stage_in in (in_channels, out_channels):
            stages.extend([
                nn.Conv2d(stage_in, out_channels, kernel_size=3, padding=1),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
            ])
        self.double_conv = nn.Sequential(*stages)

    def forward(self, x):
        return self.double_conv(x)
class UNetLight(nn.Module):
    """Lightweight U-Net with 4 resolution levels (input block + 3 max-pool downsamplings).

    ``forward`` returns raw logits of shape ``[N, n_classes, H, W]``. No sigmoid
    is applied here, so callers must apply ``torch.sigmoid`` before thresholding.
    Inputs whose H/W are not divisible by 8 are supported: upsampled decoder
    features are zero-padded to match their skip connection before concatenation.
    """

    def __init__(self, n_channels=3, n_classes=1):
        super().__init__()
        # Encoder (downsampling path)
        self.inc = DoubleConv(n_channels, 64)
        self.down1 = nn.Sequential(
            nn.MaxPool2d(2),
            DoubleConv(64, 128)
        )
        self.down2 = nn.Sequential(
            nn.MaxPool2d(2),
            DoubleConv(128, 256)
        )
        self.down3 = nn.Sequential(
            nn.MaxPool2d(2),
            DoubleConv(256, 512)
        )
        # Decoder (upsampling path); each conv fuses upsampled + skip channels
        self.up1 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
        self.conv1 = DoubleConv(512, 256)
        self.up2 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.conv2 = DoubleConv(256, 128)
        self.up3 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.conv3 = DoubleConv(128, 64)
        # 1x1 output projection to per-pixel logits
        self.outc = nn.Conv2d(64, n_classes, kernel_size=1)

    @staticmethod
    def _match_size(x, ref):
        """Zero-pad ``x`` (roughly centered) so its H and W equal those of ``ref``.

        Max-pooling floors odd sizes, so the upsampled tensor can be up to one
        pixel smaller per level than the skip tensor; never larger.
        """
        diff_h = ref.size(2) - x.size(2)
        diff_w = ref.size(3) - x.size(3)
        if diff_h == 0 and diff_w == 0:
            return x
        return nn.functional.pad(
            x, [diff_w // 2, diff_w - diff_w // 2, diff_h // 2, diff_h - diff_h // 2]
        )

    def forward(self, x):
        # Encoder
        x1 = self.inc(x)      # [64, H, W]
        x2 = self.down1(x1)   # [128, H/2, W/2]
        x3 = self.down2(x2)   # [256, H/4, W/4]
        x4 = self.down3(x3)   # [512, H/8, W/8]
        # Decoder with skip connections
        x = self._match_size(self.up1(x4), x3)
        x = self.conv1(torch.cat([x, x3], dim=1))   # [256, H/4, W/4]
        x = self._match_size(self.up2(x), x2)
        x = self.conv2(torch.cat([x, x2], dim=1))   # [128, H/2, W/2]
        x = self._match_size(self.up3(x), x1)
        x = self.conv3(torch.cat([x, x1], dim=1))   # [64, H, W]
        # Logits (no sigmoid)
        return self.outc(x)
训练与可视化验证
class DiceBCELoss(nn.Module):
    """Combined soft-Dice loss + BCE loss, both computed from raw logits.

    ``loss = (1 - soft_dice) + BCEWithLogits``. The soft Dice uses
    ``sigmoid(logits)`` and is pooled over every element of the batch (not
    averaged per sample). ``smooth`` stabilizes the Dice ratio when both the
    prediction and target are (near) empty.
    """

    def __init__(self, smooth=1e-6):
        super().__init__()
        self.smooth = smooth
        # Numerically stable BCE that applies the sigmoid internally
        self.bce_loss = nn.BCEWithLogitsLoss()

    def forward(self, preds, targets):
        """
        Args:
            preds: raw logits, any shape.
            targets: binary (or soft) targets with the same number of elements.
        Returns:
            Scalar tensor ``dice_loss + bce``.
        """
        # reshape (not view) so non-contiguous inputs are accepted
        logits = preds.reshape(-1)
        # BCEWithLogitsLoss requires floating-point targets
        targets = targets.reshape(-1).to(logits.dtype)

        probs = torch.sigmoid(logits)  # computed once, reused for both Dice terms
        intersection = (probs * targets).sum()
        dice_loss = 1 - (2.0 * intersection + self.smooth) / (
            probs.sum() + targets.sum() + self.smooth
        )

        bce = self.bce_loss(logits, targets)
        return dice_loss + bce
def visualize_predictions(model, dataloader, num_samples=3):
    """Plot input / ground truth / prediction rows for up to ``num_samples`` samples.

    Samples are gathered across batches, so a loader with ``batch_size=1`` still
    yields ``num_samples`` rows when enough data exists. The model outputs logits
    (no final sigmoid), so predictions are shown as ``sigmoid(logits) > 0.5``.
    The figure is saved to ``retina_seg_results.png`` and then shown.
    """
    model.eval()
    # Run on whatever device the model's parameters live on
    model_device = next(model.parameters()).device

    image_parts, mask_parts, prob_parts = [], [], []
    collected = 0
    with torch.no_grad():
        for images, masks in dataloader:
            if collected >= num_samples:
                break
            logits = model(images.to(model_device))
            image_parts.append(images.cpu())
            mask_parts.append(masks.cpu())
            prob_parts.append(torch.sigmoid(logits).cpu())
            collected += images.shape[0]

    if collected == 0:
        print("没有可显示的样本,跳过可视化")
        return

    images = torch.cat(image_parts)[:num_samples]
    masks = torch.cat(mask_parts)[:num_samples]
    probs = torch.cat(prob_parts)[:num_samples]
    display_samples = images.shape[0]

    # squeeze=False keeps `axes` 2-D even when there is a single row
    fig, axes = plt.subplots(display_samples, 3, figsize=(15, display_samples * 4),
                             squeeze=False)
    for row, (image, mask, prob) in enumerate(zip(images, masks, probs)):
        panels = (
            (image.permute(1, 2, 0), None, "Input Image"),
            (mask.squeeze(), 'gray', "Ground Truth"),
            (prob.squeeze() > 0.5, 'gray', "Prediction"),  # binarize at p=0.5
        )
        for col, (data, cmap, title) in enumerate(panels):
            axes[row, col].imshow(data, cmap=cmap)
            axes[row, col].set_title(title)
            axes[row, col].axis('off')

    fig.tight_layout()
    fig.savefig("retina_seg_results.png")
    print("分割结果已保存: retina_seg_results.png")
    plt.show()
def plot_metrics(train_losses, val_losses, val_dices):
    """Draw per-epoch loss curves (left) and validation Dice (right).

    The figure is written to ``training_metrics.png`` and then displayed.
    """
    fig, (loss_ax, dice_ax) = plt.subplots(1, 2, figsize=(15, 5))

    # Left panel: train vs. validation loss
    loss_ax.plot(train_losses, label='Train Loss')
    loss_ax.plot(val_losses, label='Val Loss')
    loss_ax.set_title('Training & Validation Loss')
    loss_ax.set_xlabel('Epochs')
    loss_ax.set_ylabel('Loss')

    # Right panel: validation Dice, y-axis pinned to [0, 1]
    dice_ax.plot(val_dices, label='Val Dice', color='green')
    dice_ax.set_title('Validation Dice Coefficient')
    dice_ax.set_xlabel('Epochs')
    dice_ax.set_ylabel('Dice')
    dice_ax.set_ylim(0, 1)

    for ax in (loss_ax, dice_ax):
        ax.legend()
        ax.grid(True)

    fig.tight_layout()
    fig.savefig("training_metrics.png")
    print("训练指标已保存: training_metrics.png")
    plt.show()
训练函数
def _dice_from_logits(logits, masks, threshold=0.5, eps=1e-6):
    """Hard Dice between ``sigmoid(logits) > threshold`` and binary ``masks``.

    All elements of the batch are pooled into a single score. Returns a float.
    """
    # The model emits logits, so map to probabilities before thresholding
    preds = (torch.sigmoid(logits) > threshold).float()
    intersection = (preds * masks).sum()
    return ((2.0 * intersection) / (preds.sum() + masks.sum() + eps)).item()


def _evaluate(model, loader, criterion=None):
    """Run ``model`` over ``loader`` without gradients.

    Returns ``(mean_loss, mean_dice)``. The loss is averaged per sample (``None``
    when no ``criterion`` is given); the Dice is averaged per batch.
    """
    model.eval()
    loss_sum = 0.0
    dice_sum = 0.0
    with torch.no_grad():
        for images, masks in loader:
            images = images.to(device)
            masks = masks.to(device)
            outputs = model(images)
            if criterion is not None:
                loss_sum += criterion(outputs, masks).item() * images.size(0)
            dice_sum += _dice_from_logits(outputs, masks)
    mean_loss = loss_sum / len(loader.dataset) if criterion is not None else None
    return mean_loss, dice_sum / len(loader)


def train_model():
    """Train UNetLight on DRIVE, plot curves, reload the best checkpoint and report test Dice.

    Returns the trained model, or ``None`` if the dataset is missing or empty.

    NOTE: the DRIVE ``test`` split doubles as the validation set here (it drives
    LR scheduling and checkpoint selection), so the final test Dice is optimistic.
    """
    # Verify the dataset directory exists
    if not os.path.exists(DATA_ROOT):
        print(f"错误: 数据集路径 {DATA_ROOT} 不存在!")
        print("请确保数据集已下载并放置在正确位置")
        return None
    print(f"数据集验证: {DATA_ROOT}")
    print(f"训练目录: {os.path.join(DATA_ROOT, 'training')}")
    print(f"测试目录: {os.path.join(DATA_ROOT, 'test')}")

    # Preprocessing: downscale to 128x128 to keep CPU training tractable
    transform = transforms.Compose([
        transforms.Resize((128, 128)),
    ])

    train_dataset = DRIVEDataset(DATA_ROOT, train=True, transform=transform)
    test_dataset = DRIVEDataset(DATA_ROOT, train=False, transform=transform)

    if len(train_dataset) == 0 or len(test_dataset) == 0:
        print("错误: 未找到数据集文件!")
        print(f"训练集路径: {os.path.join(DATA_ROOT, 'training')}")
        print(f"测试集路径: {os.path.join(DATA_ROOT, 'test')}")
        print(f"训练集文件数: {len(train_dataset)}, 测试集文件数: {len(test_dataset)}")
        return None
    print(f"训练集大小: {len(train_dataset)}")
    print(f"测试集大小: {len(test_dataset)}")

    # Small batches to fit CPU memory
    batch_size = 4
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)

    model = UNetLight(n_channels=3, n_classes=1).to(device)
    print("模型初始化完成")

    num_epochs = 20
    learning_rate = 0.0001
    weight_decay = 1e-5

    criterion = DiceBCELoss()
    optimizer = RMSprop(model.parameters(), lr=learning_rate,
                        weight_decay=weight_decay, momentum=0.9)
    # `verbose` is deprecated for LR schedulers; the current LR is printed every epoch below
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, 'min', patience=3, factor=0.5
    )

    train_losses = []
    val_losses = []
    val_dices = []
    # -inf (not 0.0) guarantees a checkpoint is written in the first epoch, so the
    # torch.load after training cannot fail even if Dice never rises above 0
    best_dice = float('-inf')
    checkpoint_path = "unet_retina_best.pth"

    print(f"开始训练 (CPU环境)...")
    print(f"训练样本: {len(train_dataset)}, 测试样本: {len(test_dataset)}")
    print(f"批量大小: {batch_size}, 训练轮次: {num_epochs}")

    for epoch in range(num_epochs):
        start_time = time.time()

        # Training phase
        model.train()
        epoch_train_loss = 0.0
        for images, masks in train_loader:
            images = images.to(device)
            masks = masks.to(device)
            outputs = model(images)
            loss = criterion(outputs, masks)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            epoch_train_loss += loss.item() * images.size(0)
        epoch_train_loss /= len(train_loader.dataset)
        train_losses.append(epoch_train_loss)

        # Validation phase (on the test split, see NOTE in the docstring)
        val_loss, avg_dice = _evaluate(model, test_loader, criterion)
        val_losses.append(val_loss)
        val_dices.append(avg_dice)
        scheduler.step(val_loss)

        # Keep the checkpoint with the best validation Dice
        if avg_dice > best_dice:
            best_dice = avg_dice
            torch.save(model.state_dict(), checkpoint_path)
            print(f"保存最佳模型,Dice系数: {best_dice:.4f}")

        epoch_time = time.time() - start_time
        print(f"Epoch [{epoch + 1}/{num_epochs}] - "
              f"Train Loss: {epoch_train_loss:.4f}, Val Loss: {val_loss:.4f}, "
              f"Val Dice: {avg_dice:.4f}, Time: {epoch_time:.1f}s, "
              f"LR: {optimizer.param_groups[0]['lr']:.2e}")

    plot_metrics(train_losses, val_losses, val_dices)

    # Reload the best checkpoint, visualize it and report its test Dice
    model.load_state_dict(torch.load(checkpoint_path, map_location=device, weights_only=True))
    visualize_predictions(model, test_loader, num_samples=3)
    _, test_dice = _evaluate(model, test_loader)
    print(f"测试集平均Dice系数: {test_dice:.4f}")
    return model
完整代码
import os
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torch.optim import RMSprop
import matplotlib.pyplot as plt
from PIL import Image
import random
import time
# Seed every RNG the script relies on (Python, NumPy, PyTorch) for reproducibility.
seed = 42
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed):
    _seed_fn(seed)
del _seed_fn
# Device: the script is configured to run on the CPU only.
device = torch.device('cpu')
print(f"使用设备: {device}")
# ---------------------
# 1. 数据准备 (使用本地DRIVE视网膜血管数据集)
# ---------------------
# 数据集路径 - 使用相对路径
# Locate the dataset relative to this script's own directory, so the script
# works no matter which working directory it is launched from.
BASE_DIR = os.path.split(os.path.abspath(__file__))[0]
DATA_ROOT = os.path.join(os.path.join(BASE_DIR, "data"), "DRIVE")
print(f"数据集路径: {DATA_ROOT}")
# 数据集类
class DRIVEDataset(Dataset):
    """DRIVE retinal-vessel dataset yielding ``(image, label)`` tensor pairs.

    Expected layout under ``root_dir/<training|test>/``:
    ``images/*.tif``, ``1st_manual/<NN>_manual1.gif`` and ``mask/<stem>_mask.gif``.
    ``image`` is a float32 ``[3, H, W]`` tensor in [0, 1]; ``label`` is a float32
    ``[1, H, W]`` tensor that is 1 on vessel pixels inside the retina mask, else 0
    (before ``transform`` is applied).
    """

    def __init__(self, root_dir, train=True, transform=None, *, label_threshold=None):
        """
        Args:
            root_dir (str): dataset root directory.
            train (bool): use the ``training`` split if True, otherwise ``test``.
            transform (callable, optional): applied separately to the image tensor
                and to the label tensor.
            label_threshold (float, optional): if given, the label is re-binarized
                as ``label > label_threshold`` after ``transform``. Interpolating
                transforms such as ``Resize`` otherwise turn the 0/1 label into
                fractional values. ``None`` (default) keeps the label untouched.
        """
        self.root_dir = root_dir
        self.transform = transform
        self.label_threshold = label_threshold
        self.sub_dir = "training" if train else "test"
        # Image, retina-mask and manual-annotation directories
        self.image_dir = os.path.join(root_dir, self.sub_dir, "images")
        self.mask_dir = os.path.join(root_dir, self.sub_dir, "mask")
        self.label_dir = os.path.join(root_dir, self.sub_dir, "1st_manual")
        # os.listdir order is filesystem-dependent; sort so sample indices (and
        # therefore seeded runs) are reproducible across machines
        self.image_files = sorted(
            f for f in os.listdir(self.image_dir) if f.endswith('.tif')
        )
        print(f"找到 {len(self.image_files)} 个图像文件在 {self.image_dir}")

    def __len__(self):
        return len(self.image_files)

    @staticmethod
    def _load(path, mode):
        """Open ``path`` with PIL, convert to ``mode`` and return a NumPy array (file is closed)."""
        with Image.open(path) as img:
            return np.array(img.convert(mode))

    def __getitem__(self, idx):
        img_name = self.image_files[idx]
        image = self._load(os.path.join(self.image_dir, img_name), 'RGB')

        # Annotation name is derived from the numeric prefix:
        # "21_training.tif" -> "21" -> "21_manual1.gif"
        file_number = img_name.split('_')[0]
        label = self._load(os.path.join(self.label_dir, f"{file_number}_manual1.gif"), 'L')

        # Retina mask: "21_training.tif" -> "21_training_mask.gif"
        mask_name = img_name.replace('.tif', '_mask.gif')
        mask = self._load(os.path.join(self.mask_dir, mask_name), 'L')

        # Drop annotations outside the retina field of view
        label = label * (mask > 0)

        # Normalize image to [0, 1]; binarize label (vessel=1, background=0)
        image = image.astype(np.float32) / 255.0
        label = (label > 0).astype(np.float32)

        image = torch.tensor(image).permute(2, 0, 1)  # [C, H, W]
        label = torch.tensor(label).unsqueeze(0)      # [1, H, W]

        if self.transform:
            image = self.transform(image)
            label = self.transform(label)
        if self.label_threshold is not None:
            label = (label > self.label_threshold).float()
        return image, label
# ---------------------
# 2. 轻量化U-Net模型实现 (CPU优化)
# ---------------------
class DoubleConv(nn.Module):
    """Two stacked 3x3 Conv -> BatchNorm -> ReLU stages.

    ``padding=1`` keeps the spatial size unchanged; only the channel count
    changes from ``in_channels`` to ``out_channels``.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        stages = []
        # The first stage maps in->out channels, the second keeps out channels
        for stage_in in (in_channels, out_channels):
            stages.extend([
                nn.Conv2d(stage_in, out_channels, kernel_size=3, padding=1),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
            ])
        self.double_conv = nn.Sequential(*stages)

    def forward(self, x):
        return self.double_conv(x)
class UNetLight(nn.Module):
    """Lightweight U-Net with 4 resolution levels (input block + 3 max-pool downsamplings).

    ``forward`` returns raw logits of shape ``[N, n_classes, H, W]``. No sigmoid
    is applied here, so callers must apply ``torch.sigmoid`` before thresholding.
    Inputs whose H/W are not divisible by 8 are supported: upsampled decoder
    features are zero-padded to match their skip connection before concatenation.
    """

    def __init__(self, n_channels=3, n_classes=1):
        super().__init__()
        # Encoder (downsampling path)
        self.inc = DoubleConv(n_channels, 64)
        self.down1 = nn.Sequential(
            nn.MaxPool2d(2),
            DoubleConv(64, 128)
        )
        self.down2 = nn.Sequential(
            nn.MaxPool2d(2),
            DoubleConv(128, 256)
        )
        self.down3 = nn.Sequential(
            nn.MaxPool2d(2),
            DoubleConv(256, 512)
        )
        # Decoder (upsampling path); each conv fuses upsampled + skip channels
        self.up1 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
        self.conv1 = DoubleConv(512, 256)
        self.up2 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.conv2 = DoubleConv(256, 128)
        self.up3 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.conv3 = DoubleConv(128, 64)
        # 1x1 output projection to per-pixel logits
        self.outc = nn.Conv2d(64, n_classes, kernel_size=1)

    @staticmethod
    def _match_size(x, ref):
        """Zero-pad ``x`` (roughly centered) so its H and W equal those of ``ref``.

        Max-pooling floors odd sizes, so the upsampled tensor can be up to one
        pixel smaller per level than the skip tensor; never larger.
        """
        diff_h = ref.size(2) - x.size(2)
        diff_w = ref.size(3) - x.size(3)
        if diff_h == 0 and diff_w == 0:
            return x
        return nn.functional.pad(
            x, [diff_w // 2, diff_w - diff_w // 2, diff_h // 2, diff_h - diff_h // 2]
        )

    def forward(self, x):
        # Encoder
        x1 = self.inc(x)      # [64, H, W]
        x2 = self.down1(x1)   # [128, H/2, W/2]
        x3 = self.down2(x2)   # [256, H/4, W/4]
        x4 = self.down3(x3)   # [512, H/8, W/8]
        # Decoder with skip connections
        x = self._match_size(self.up1(x4), x3)
        x = self.conv1(torch.cat([x, x3], dim=1))   # [256, H/4, W/4]
        x = self._match_size(self.up2(x), x2)
        x = self.conv2(torch.cat([x, x2], dim=1))   # [128, H/2, W/2]
        x = self._match_size(self.up3(x), x1)
        x = self.conv3(torch.cat([x, x1], dim=1))   # [64, H, W]
        # Logits (no sigmoid)
        return self.outc(x)
# ---------------------
# 3. 训练策略与可视化工具
# ---------------------
class DiceBCELoss(nn.Module):
    """Combined soft-Dice loss + BCE loss, both computed from raw logits.

    ``loss = (1 - soft_dice) + BCEWithLogits``. The soft Dice uses
    ``sigmoid(logits)`` and is pooled over every element of the batch (not
    averaged per sample). ``smooth`` stabilizes the Dice ratio when both the
    prediction and target are (near) empty.
    """

    def __init__(self, smooth=1e-6):
        super().__init__()
        self.smooth = smooth
        # Numerically stable BCE that applies the sigmoid internally
        self.bce_loss = nn.BCEWithLogitsLoss()

    def forward(self, preds, targets):
        """
        Args:
            preds: raw logits, any shape.
            targets: binary (or soft) targets with the same number of elements.
        Returns:
            Scalar tensor ``dice_loss + bce``.
        """
        # reshape (not view) so non-contiguous inputs are accepted
        logits = preds.reshape(-1)
        # BCEWithLogitsLoss requires floating-point targets
        targets = targets.reshape(-1).to(logits.dtype)

        probs = torch.sigmoid(logits)  # computed once, reused for both Dice terms
        intersection = (probs * targets).sum()
        dice_loss = 1 - (2.0 * intersection + self.smooth) / (
            probs.sum() + targets.sum() + self.smooth
        )

        bce = self.bce_loss(logits, targets)
        return dice_loss + bce
def visualize_predictions(model, dataloader, num_samples=3):
    """Plot input / ground truth / prediction rows for up to ``num_samples`` samples.

    Samples are gathered across batches, so a loader with ``batch_size=1`` still
    yields ``num_samples`` rows when enough data exists. The model outputs logits
    (no final sigmoid), so predictions are shown as ``sigmoid(logits) > 0.5``.
    The figure is saved to ``retina_seg_results.png`` and then shown.
    """
    model.eval()
    # Run on whatever device the model's parameters live on
    model_device = next(model.parameters()).device

    image_parts, mask_parts, prob_parts = [], [], []
    collected = 0
    with torch.no_grad():
        for images, masks in dataloader:
            if collected >= num_samples:
                break
            logits = model(images.to(model_device))
            image_parts.append(images.cpu())
            mask_parts.append(masks.cpu())
            prob_parts.append(torch.sigmoid(logits).cpu())
            collected += images.shape[0]

    if collected == 0:
        print("没有可显示的样本,跳过可视化")
        return

    images = torch.cat(image_parts)[:num_samples]
    masks = torch.cat(mask_parts)[:num_samples]
    probs = torch.cat(prob_parts)[:num_samples]
    display_samples = images.shape[0]

    # squeeze=False keeps `axes` 2-D even when there is a single row
    fig, axes = plt.subplots(display_samples, 3, figsize=(15, display_samples * 4),
                             squeeze=False)
    for row, (image, mask, prob) in enumerate(zip(images, masks, probs)):
        panels = (
            (image.permute(1, 2, 0), None, "Input Image"),
            (mask.squeeze(), 'gray', "Ground Truth"),
            (prob.squeeze() > 0.5, 'gray', "Prediction"),  # binarize at p=0.5
        )
        for col, (data, cmap, title) in enumerate(panels):
            axes[row, col].imshow(data, cmap=cmap)
            axes[row, col].set_title(title)
            axes[row, col].axis('off')

    fig.tight_layout()
    fig.savefig("retina_seg_results.png")
    print("分割结果已保存: retina_seg_results.png")
    plt.show()
def plot_metrics(train_losses, val_losses, val_dices):
    """Draw per-epoch loss curves (left) and validation Dice (right).

    The figure is written to ``training_metrics.png`` and then displayed.
    """
    fig, (loss_ax, dice_ax) = plt.subplots(1, 2, figsize=(15, 5))

    # Left panel: train vs. validation loss
    loss_ax.plot(train_losses, label='Train Loss')
    loss_ax.plot(val_losses, label='Val Loss')
    loss_ax.set_title('Training & Validation Loss')
    loss_ax.set_xlabel('Epochs')
    loss_ax.set_ylabel('Loss')

    # Right panel: validation Dice, y-axis pinned to [0, 1]
    dice_ax.plot(val_dices, label='Val Dice', color='green')
    dice_ax.set_title('Validation Dice Coefficient')
    dice_ax.set_xlabel('Epochs')
    dice_ax.set_ylabel('Dice')
    dice_ax.set_ylim(0, 1)

    for ax in (loss_ax, dice_ax):
        ax.legend()
        ax.grid(True)

    fig.tight_layout()
    fig.savefig("training_metrics.png")
    print("训练指标已保存: training_metrics.png")
    plt.show()
# ---------------------
# 4. 训练函数
# ---------------------
def _dice_from_logits(logits, masks, threshold=0.5, eps=1e-6):
    """Hard Dice between ``sigmoid(logits) > threshold`` and binary ``masks``.

    All elements of the batch are pooled into a single score. Returns a float.
    """
    # The model emits logits, so map to probabilities before thresholding
    preds = (torch.sigmoid(logits) > threshold).float()
    intersection = (preds * masks).sum()
    return ((2.0 * intersection) / (preds.sum() + masks.sum() + eps)).item()


def _evaluate(model, loader, criterion=None):
    """Run ``model`` over ``loader`` without gradients.

    Returns ``(mean_loss, mean_dice)``. The loss is averaged per sample (``None``
    when no ``criterion`` is given); the Dice is averaged per batch.
    """
    model.eval()
    loss_sum = 0.0
    dice_sum = 0.0
    with torch.no_grad():
        for images, masks in loader:
            images = images.to(device)
            masks = masks.to(device)
            outputs = model(images)
            if criterion is not None:
                loss_sum += criterion(outputs, masks).item() * images.size(0)
            dice_sum += _dice_from_logits(outputs, masks)
    mean_loss = loss_sum / len(loader.dataset) if criterion is not None else None
    return mean_loss, dice_sum / len(loader)


def train_model():
    """Train UNetLight on DRIVE, plot curves, reload the best checkpoint and report test Dice.

    Returns the trained model, or ``None`` if the dataset is missing or empty.

    NOTE: the DRIVE ``test`` split doubles as the validation set here (it drives
    LR scheduling and checkpoint selection), so the final test Dice is optimistic.
    """
    # Verify the dataset directory exists
    if not os.path.exists(DATA_ROOT):
        print(f"错误: 数据集路径 {DATA_ROOT} 不存在!")
        print("请确保数据集已下载并放置在正确位置")
        return None
    print(f"数据集验证: {DATA_ROOT}")
    print(f"训练目录: {os.path.join(DATA_ROOT, 'training')}")
    print(f"测试目录: {os.path.join(DATA_ROOT, 'test')}")

    # Preprocessing: downscale to 128x128 to keep CPU training tractable
    transform = transforms.Compose([
        transforms.Resize((128, 128)),
    ])

    train_dataset = DRIVEDataset(DATA_ROOT, train=True, transform=transform)
    test_dataset = DRIVEDataset(DATA_ROOT, train=False, transform=transform)

    if len(train_dataset) == 0 or len(test_dataset) == 0:
        print("错误: 未找到数据集文件!")
        print(f"训练集路径: {os.path.join(DATA_ROOT, 'training')}")
        print(f"测试集路径: {os.path.join(DATA_ROOT, 'test')}")
        print(f"训练集文件数: {len(train_dataset)}, 测试集文件数: {len(test_dataset)}")
        return None
    print(f"训练集大小: {len(train_dataset)}")
    print(f"测试集大小: {len(test_dataset)}")

    # Small batches to fit CPU memory
    batch_size = 4
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)

    model = UNetLight(n_channels=3, n_classes=1).to(device)
    print("模型初始化完成")

    num_epochs = 20
    learning_rate = 0.0001
    weight_decay = 1e-5

    criterion = DiceBCELoss()
    optimizer = RMSprop(model.parameters(), lr=learning_rate,
                        weight_decay=weight_decay, momentum=0.9)
    # `verbose` is deprecated for LR schedulers; the current LR is printed every epoch below
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, 'min', patience=3, factor=0.5
    )

    train_losses = []
    val_losses = []
    val_dices = []
    # -inf (not 0.0) guarantees a checkpoint is written in the first epoch, so the
    # torch.load after training cannot fail even if Dice never rises above 0
    best_dice = float('-inf')
    checkpoint_path = "unet_retina_best.pth"

    print(f"开始训练 (CPU环境)...")
    print(f"训练样本: {len(train_dataset)}, 测试样本: {len(test_dataset)}")
    print(f"批量大小: {batch_size}, 训练轮次: {num_epochs}")

    for epoch in range(num_epochs):
        start_time = time.time()

        # Training phase
        model.train()
        epoch_train_loss = 0.0
        for images, masks in train_loader:
            images = images.to(device)
            masks = masks.to(device)
            outputs = model(images)
            loss = criterion(outputs, masks)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            epoch_train_loss += loss.item() * images.size(0)
        epoch_train_loss /= len(train_loader.dataset)
        train_losses.append(epoch_train_loss)

        # Validation phase (on the test split, see NOTE in the docstring)
        val_loss, avg_dice = _evaluate(model, test_loader, criterion)
        val_losses.append(val_loss)
        val_dices.append(avg_dice)
        scheduler.step(val_loss)

        # Keep the checkpoint with the best validation Dice
        if avg_dice > best_dice:
            best_dice = avg_dice
            torch.save(model.state_dict(), checkpoint_path)
            print(f"保存最佳模型,Dice系数: {best_dice:.4f}")

        epoch_time = time.time() - start_time
        print(f"Epoch [{epoch + 1}/{num_epochs}] - "
              f"Train Loss: {epoch_train_loss:.4f}, Val Loss: {val_loss:.4f}, "
              f"Val Dice: {avg_dice:.4f}, Time: {epoch_time:.1f}s, "
              f"LR: {optimizer.param_groups[0]['lr']:.2e}")

    plot_metrics(train_losses, val_losses, val_dices)

    # Reload the best checkpoint, visualize it and report its test Dice
    model.load_state_dict(torch.load(checkpoint_path, map_location=device, weights_only=True))
    visualize_predictions(model, test_loader, num_samples=3)
    _, test_dice = _evaluate(model, test_loader)
    print(f"测试集平均Dice系数: {test_dice:.4f}")
    return model
# ---------------------
# 5. 主执行流程
# ---------------------
if __name__ == "__main__":
    # Script entry point: train, plot, visualize and evaluate in one run.
    # `model` holds the trained UNetLight, or None if the dataset was missing or empty.
    model = train_model()
三、实验结果
左图 Loss 曲线显示训练集损失逐步下降;验证集(本实验直接使用 DRIVE 测试集作为验证集)的 Loss 有波动,但整体趋势向好。右图 Dice 曲线随训练推进逐步上升,说明模型对血管分割的 “吻合度” 在提升;但曲线波动大,反映小数据集下泛化不稳定。
对比 Input Image、Ground Truth、Prediction 可知,模型能大致捕捉血管轮廓。