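Medical Image Segmentation: DRIVE Retinal Vessel Segmentation

I. Experiment Content

1. Topic Selection and Data Preparation

Image segmentation task: medical image segmentation, specifically retinal vessel segmentation on DRIVE.

Dataset:

Source: Gitee mirror of DRIVE-dataset

Structure:

Training set: 20 retinal images + vessel annotations + retinal (field-of-view) masks

Test set: 20 retinal images + vessel annotations + retinal (field-of-view) masks

Preprocessing:

Resize to 128×128 (to reduce the CPU compute load)

Apply the retinal mask to exclude the background region

Normalize pixel values to [0, 1]

Binarize the annotations (vessel = 1, background = 0)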
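
The masking, normalization, and binarization steps listed above can be sketched with NumPy on a tiny synthetic example. The helper name `preprocess` and the array values are illustrative assumptions, not DRIVE data; resizing is left out.

```python
import numpy as np

def preprocess(image, label, mask):
    """Apply the field-of-view mask, scale the image to [0, 1], binarize the label."""
    label = label * (mask > 0)                 # drop annotations outside the retina
    image = image.astype(np.float32) / 255.0   # uint8 -> float in [0, 1]
    label = (label > 0).astype(np.float32)     # vessel = 1, background = 0
    return image, label

# Tiny synthetic example (hypothetical values, not DRIVE data)
image = np.array([[0, 128], [255, 64]], dtype=np.uint8)
label = np.array([[255, 0], [255, 255]], dtype=np.uint8)
mask = np.array([[255, 255], [0, 255]], dtype=np.uint8)

img, lab = preprocess(image, label, mask)
print(lab)
```

The annotated pixel in the lower-left corner lies outside the mask, so it is cleared before binarization.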

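2. Model Construction and Training

Algorithm: U-Net

Network architecture:

Encoder path (downsampling):

4 encoder blocks (the original U-Net has 5), i.e. 3 max-pooling steps instead of 4, to reduce computation

Each block contains two 3×3 convolutions, each followed by BN and ReLU

Max pooling for downsampling (2×2)

Decoder path (upsampling):

Transposed-convolution upsampling (2×2)

Skip connections to fuse encoder features

Double-convolution blocks for feature fusion

Output layer: 1×1 convolution producing logits; a Sigmoid (applied inside the loss and at inference time) turns them into a vessel probability map

Innovations:

Reduced network depth to balance accuracy and speed

About 7.7M parameters (versus about 31M for the original U-Net)

Training strategy:

Loss function: combined Dice + BCE loss

Addresses the vessel/background class imbalance

The Dice term optimizes segmentation overlap

The BCE term provides stable gradients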
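
To illustrate how the two loss terms combine, here is a minimal NumPy sketch of a Dice + BCE loss computed from raw logits. The function name `dice_bce_loss` and the example values are assumptions made for illustration; the training code itself uses PyTorch.

```python
import numpy as np

def dice_bce_loss(logits, targets, smooth=1e-6):
    """Dice loss + binary cross-entropy, both computed from raw logits."""
    logits = logits.ravel().astype(np.float64)
    targets = targets.ravel().astype(np.float64)
    probs = 1.0 / (1.0 + np.exp(-logits))

    # Soft Dice loss: 1 - 2|A∩B| / (|A| + |B|)
    intersection = (probs * targets).sum()
    dice = 1.0 - (2.0 * intersection + smooth) / (probs.sum() + targets.sum() + smooth)

    # Numerically stable BCE on logits: max(x, 0) - x*t + log(1 + exp(-|x|))
    bce = np.mean(np.maximum(logits, 0) - logits * targets
                  + np.log1p(np.exp(-np.abs(logits))))
    return dice + bce

targets = np.array([1, 0, 0, 0, 0, 0, 0, 1])   # sparse "vessel" pixels
good = np.where(targets == 1, 8.0, -8.0)       # confident, correct logits
bad = -good                                    # confident, wrong logits

loss_good = dice_bce_loss(good, targets)
loss_bad = dice_bce_loss(bad, targets)
print(loss_good, loss_bad)
```

The Dice term depends only on the overlap with the vessel pixels, so it is not dominated by the many background pixels, while the BCE term contributes a per-pixel gradient everywhere.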

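Optimizer: RMSprop

Learning rate 0.0001, weight decay 1e-5

Momentum = 0.9 to speed up convergence

Learning-rate scheduling:

ReduceLROnPlateau (patience = 3, factor = 0.5)

Automatically lowers the learning rate when the validation loss stalls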
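
The plateau rule can be sketched in plain Python. This is a simplified model, not PyTorch's implementation: it treats any strict decrease as an improvement and ignores the scheduler's threshold and cooldown options. The function name and the loss values are hypothetical.

```python
def reduce_on_plateau(val_losses, lr=1e-4, patience=3, factor=0.5):
    """Return the learning rate in effect after each epoch's validation loss.

    Simplified: any strict decrease counts as an improvement, and the
    relative threshold and cooldown options of the real scheduler are ignored.
    """
    best = float("inf")
    bad_epochs = 0
    history = []
    for loss in val_losses:
        if loss < best:
            best = loss
            bad_epochs = 0
        else:
            bad_epochs += 1
        if bad_epochs > patience:      # more than `patience` epochs without improvement
            lr *= factor
            bad_epochs = 0
        history.append(lr)
    return history

# Hypothetical validation losses: improvement stalls after the second epoch
losses = [0.90, 0.80, 0.81, 0.82, 0.80, 0.83, 0.79]
lrs = reduce_on_plateau(losses)
print(lrs)
```

With patience = 3, the rate is halved only on the fourth consecutive epoch without improvement.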

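Regularization:

Weight decay (L2 regularization)

Batch normalization

CPU adaptations:

Batch size = 4 (to limit memory use)

Image size 128×128

Lightweight network structure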
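
As a sanity check on the "lightweight" claim, the trainable parameter count can be worked out by hand from the layer sizes used in the model code (channels 64, 128, 256, 512, three transposed convolutions, and a 1×1 output convolution). The sketch assumes bias terms in every convolution, which is the PyTorch default; the helper names are hypothetical.

```python
def conv_params(c_in, c_out, k):
    """Weights plus biases of a k×k convolution (same count for a transposed one)."""
    return c_in * c_out * k * k + c_out

def double_conv_params(c_in, c_out):
    """Two 3×3 convolutions, each followed by BatchNorm (weight + bias per channel)."""
    bn = 2 * c_out
    return conv_params(c_in, c_out, 3) + bn + conv_params(c_out, c_out, 3) + bn

encoder = (double_conv_params(3, 64) + double_conv_params(64, 128)
           + double_conv_params(128, 256) + double_conv_params(256, 512))
decoder = (conv_params(512, 256, 2) + double_conv_params(512, 256)
           + conv_params(256, 128, 2) + double_conv_params(256, 128)
           + conv_params(128, 64, 2) + double_conv_params(128, 64))
head = conv_params(64, 1, 1)

total = encoder + decoder + head
print(f"{total:,} trainable parameters (~{total / 1e6:.1f}M)")
```

BatchNorm running statistics are buffers rather than trainable parameters, so they are not counted here.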

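3. Performance Evaluation and Optimization

Evaluation metrics:

The Dice coefficient is the main measure of segmentation quality. It is closely related to IoU (Dice = 2·IoU / (1 + IoU)) and is defined as Dice = 2|A∩B| / (|A| + |B|), where A is the predicted vessel region and B is the annotated one. The loss value (combined Dice + BCE loss) is used as an auxiliary signal for monitoring training.

Visualization:

Training curves: loss curves on the left, Dice curve on the right

Ablation experiments:

1) Module comparison:

With vs. without skip connections (the core feature of U-Net): remove the skip connections, then compare the Dice coefficient and the segmentation detail to verify that they are necessary for medical images, where fine detail must be preserved.

Alternative downsampling (for example, strided convolution instead of MaxPool): test whether this mitigates information loss on a small dataset.

2) Parameter / strategy adjustments:

Batch size (batch_size): currently batch_size=4; try values such as 2 or 8 and observe the loss convergence speed and Dice stability.

Learning-rate scheduling: for example, replace ReduceLROnPlateau with StepLR and compare the effect on the fluctuation of the training curves.

3) Data-related:

Data augmentation (rotation, brightness changes, etc.): verify how much augmentation improves generalization (Dice stability) on a small dataset.

Label processing: try different binarization thresholds (currently label > 0) and observe the effect on the segmentation results.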
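
For the evaluation metric used in this section, the following NumPy sketch computes Dice and IoU for two binary masks and checks the relation Dice = 2·IoU / (1 + IoU). The helper name and the masks are hypothetical, and both masks are assumed to be non-empty so that no division by zero occurs.

```python
import numpy as np

def dice_and_iou(pred, target):
    """Dice and IoU for two binary masks of the same shape."""
    pred = pred.astype(bool)
    target = target.astype(bool)
    inter = np.logical_and(pred, target).sum()
    union = np.logical_or(pred, target).sum()
    dice = 2.0 * inter / (pred.sum() + target.sum())
    iou = inter / union
    return dice, iou

# Hypothetical 1-D masks: 3 predicted pixels, 3 annotated pixels, 2 in common
pred = np.array([1, 1, 1, 0, 0, 0])
target = np.array([0, 1, 1, 1, 0, 0])
dice, iou = dice_and_iou(pred, target)
print(dice, iou)
```

Because the two metrics are monotonically related, ranking models by Dice or by IoU gives the same order on a single image.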

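II. Experiment Steps

Data preparation

The local DRIVE retinal vessel dataset is used; its path relative to the script is data/DRIVE.

BASE_DIR = os.path.dirname(os.path.abspath(__file__))  # directory containing the current script
DATA_ROOT = os.path.join(BASE_DIR, "data", "DRIVE")
print(f"Dataset path: {DATA_ROOT}")

Data cleaning and annotation

In __getitem__, the label is filtered with the retinal mask (label = label * (mask > 0)), which removes invalid annotations outside the field of view.

Dataset split

The training directory (training set) and the test directory (test set) are loaded separately. Note that the test set also serves as the validation set during training.

The data-handling code is as follows: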

class DRIVEDataset(Dataset):
    def __init__(self, root_dir, train=True, transform=None):
        """
        Args:
            root_dir (string): dataset root directory
            train (bool): whether to load the training split
            transform (callable, optional): optional transform applied to image and label
        """
        self.root_dir = root_dir
        self.transform = transform
        self.sub_dir = "training" if train else "test"

        # Image, mask, and label directories
        self.image_dir = os.path.join(root_dir, self.sub_dir, "images")
        self.mask_dir = os.path.join(root_dir, self.sub_dir, "mask")
        self.label_dir = os.path.join(root_dir, self.sub_dir, "1st_manual")

        # Collect the image files (sorted, since os.listdir order is arbitrary)
        self.image_files = sorted(f for f in os.listdir(self.image_dir) if f.endswith('.tif'))
        print(f"Found {len(self.image_files)} image files in {self.image_dir}")

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        # Load the image
        img_name = self.image_files[idx]
        img_path = os.path.join(self.image_dir, img_name)
        image = Image.open(img_path).convert('RGB')

        # Load the vessel annotation (label).
        # The file name is assumed to look like "31_training.tif";
        # the leading number "31" identifies the matching label file.
        file_number = img_name.split('_')[0]
        label_name = f"{file_number}_manual1.gif"
        label_path = os.path.join(self.label_dir, label_name)
        label = Image.open(label_path).convert('L')  # convert to grayscale

        # Load the retinal field-of-view mask
        mask_name = img_name.replace('.tif', '_mask.gif')
        mask_path = os.path.join(self.mask_dir, mask_name)
        mask = Image.open(mask_path).convert('L')

        # Convert to arrays
        image = np.array(image)
        label = np.array(label)
        mask = np.array(mask)

        # Apply the retinal mask
        label = label * (mask > 0)

        # Normalize
        image = image.astype(np.float32) / 255.0
        label = (label > 0).astype(np.float32)  # binarize

        # Convert to PyTorch tensors
        image = torch.tensor(image).permute(2, 0, 1)  # [C, H, W]
        label = torch.tensor(label).unsqueeze(0)  # [1, H, W]

        # Apply the transform.
        # Note: Resize interpolates bilinearly by default, so the resized label
        # is no longer strictly binary (its values fall in [0, 1]).
        if self.transform:
            image = self.transform(image)
            label = self.transform(label)

        return image, label
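
The file-naming convention that the dataset class relies on can be checked in isolation. The sketch below uses a hypothetical helper, `drive_paths`, that mirrors the string operations in `__getitem__`; it only builds paths and does not touch the file system.

```python
import os

def drive_paths(root_dir, split, img_name):
    """Build the image, label, and mask paths for one DRIVE file name."""
    file_number = img_name.split('_')[0]                 # "21_training.tif" -> "21"
    base = os.path.join(root_dir, split)
    return {
        "image": os.path.join(base, "images", img_name),
        "label": os.path.join(base, "1st_manual", f"{file_number}_manual1.gif"),
        "mask": os.path.join(base, "mask", img_name.replace('.tif', '_mask.gif')),
    }

paths = drive_paths("data/DRIVE", "training", "21_training.tif")
for kind, path in paths.items():
    print(f"{kind}: {path}")
```

If a mirror of the dataset names its files differently, this mapping is the one place that would need to change.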

Model construction

The lightweight U-Net is implemented as follows:

class DoubleConv(nn.Module):
    """(Conv => BN => ReLU) × 2"""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.double_conv(x)


class UNetLight(nn.Module):
    """Lightweight U-Net (4 encoder blocks, 3 downsampling steps)"""

    def __init__(self, n_channels=3, n_classes=1):
        super().__init__()
        # Downsampling path (encoder)
        self.inc = DoubleConv(n_channels, 64)
        self.down1 = nn.Sequential(
            nn.MaxPool2d(2),
            DoubleConv(64, 128)
        )
        self.down2 = nn.Sequential(
            nn.MaxPool2d(2),
            DoubleConv(128, 256)
        )
        self.down3 = nn.Sequential(
            nn.MaxPool2d(2),
            DoubleConv(256, 512)
        )

        # Upsampling path (decoder)
        self.up1 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
        self.conv1 = DoubleConv(512, 256)
        self.up2 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.conv2 = DoubleConv(256, 128)
        self.up3 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.conv3 = DoubleConv(128, 64)

        # Output layer (returns logits; no sigmoid here)
        self.outc = nn.Conv2d(64, n_classes, kernel_size=1)

    def forward(self, x):
        # Encoder path
        x1 = self.inc(x)  # [64, H, W]
        x2 = self.down1(x1)  # [128, H/2, W/2]
        x3 = self.down2(x2)  # [256, H/4, W/4]
        x4 = self.down3(x3)  # [512, H/8, W/8]

        # Decoder path (with skip connections)
        x = self.up1(x4)  # [256, H/4, W/4]
        x = torch.cat([x, x3], dim=1)  # [512, H/4, W/4]
        x = self.conv1(x)  # [256, H/4, W/4]

        x = self.up2(x)  # [128, H/2, W/2]
        x = torch.cat([x, x2], dim=1)  # [256, H/2, W/2]
        x = self.conv2(x)  # [128, H/2, W/2]

        x = self.up3(x)  # [64, H, W]
        x = torch.cat([x, x1], dim=1)  # [128, H, W]
        x = self.conv3(x)  # [64, H, W]

        # Output
        return self.outc(x)
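
The shape comments in `forward` can be verified with a few lines of plain Python. This is a bookkeeping sketch only (the function `trace_shapes` is hypothetical and runs no network); it assumes a square input whose side is divisible by 8, which the concatenations with the skip features require.

```python
def trace_shapes(size=128, channels=(64, 128, 256, 512)):
    """List (stage, channels, height/width) through the encoder and decoder."""
    shapes = [("inc", channels[0], size)]
    # Encoder: each 2×2 max pool halves the spatial size
    for i, ch in enumerate(channels[1:], start=1):
        size //= 2
        shapes.append((f"down{i}", ch, size))
    # Decoder: each 2×2 transposed convolution (stride 2) doubles it again
    for i, ch in enumerate(reversed(channels[:-1]), start=1):
        size *= 2
        shapes.append((f"up{i}+conv{i}", ch, size))
    return shapes

shapes = trace_shapes()
for stage, ch, hw in shapes:
    print(f"{stage:>10}: {ch:>3} channels, {hw}x{hw}")
```

For a 128×128 input the bottleneck is 16×16 with 512 channels, and the decoder returns to 128×128 with 64 channels before the 1×1 output convolution.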

Training and visual validation

class DiceBCELoss(nn.Module):
    """Combined loss: Dice loss + BCE loss (via BCEWithLogitsLoss)"""

    def __init__(self, smooth=1e-6):
        super().__init__()
        self.smooth = smooth
        self.bce_loss = nn.BCEWithLogitsLoss()  # BCE that takes raw logits

    def forward(self, preds, targets):
        # Flatten predictions (logits) and labels
        preds = preds.reshape(-1)
        targets = targets.reshape(-1)

        # The Dice term is computed on probabilities, so apply sigmoid to the logits
        probs = torch.sigmoid(preds)
        intersection = (probs * targets).sum()
        dice_loss = 1 - (2. * intersection + self.smooth) / \
                    (probs.sum() + targets.sum() + self.smooth)

        # BCE term (BCEWithLogitsLoss applies sigmoid internally)
        bce = self.bce_loss(preds, targets)

        return dice_loss + bce


def visualize_predictions(model, dataloader, num_samples=3):
    """Visualize segmentation results"""
    model.eval()

    # Fetch one batch of samples
    images, masks = next(iter(dataloader))
    batch_size = images.shape[0]
    display_samples = min(batch_size, num_samples)

    # squeeze=False keeps axes 2-D even when only one row is shown
    fig, axes = plt.subplots(display_samples, 3, figsize=(15, display_samples * 4),
                             squeeze=False)

    images = images.to(device)

    with torch.no_grad():
        # The model outputs logits; convert them to probabilities
        preds = torch.sigmoid(model(images)).cpu()

    images = images.cpu()
    masks = masks.cpu()

    for i in range(display_samples):
        # Input image
        axes[i, 0].imshow(images[i].permute(1, 2, 0))
        axes[i, 0].set_title("Input Image")
        axes[i, 0].axis('off')

        # Ground-truth vessel annotation
        axes[i, 1].imshow(masks[i].squeeze(), cmap='gray')
        axes[i, 1].set_title("Ground Truth")
        axes[i, 1].axis('off')

        # Prediction, binarized at probability 0.5
        axes[i, 2].imshow(preds[i].squeeze() > 0.5, cmap='gray')
        axes[i, 2].set_title("Prediction")
        axes[i, 2].axis('off')

    plt.tight_layout()
    plt.savefig("retina_seg_results.png")
    print("Segmentation results saved: retina_seg_results.png")
    plt.show()


def plot_metrics(train_losses, val_losses, val_dices):
    """Plot the training metric curves"""
    plt.figure(figsize=(15, 5))

    # Loss curves
    plt.subplot(1, 2, 1)
    plt.plot(train_losses, label='Train Loss')
    plt.plot(val_losses, label='Val Loss')
    plt.title('Training & Validation Loss')
    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.legend()
    plt.grid(True)

    # Dice coefficient curve
    plt.subplot(1, 2, 2)
    plt.plot(val_dices, label='Val Dice', color='green')
    plt.title('Validation Dice Coefficient')
    plt.xlabel('Epochs')
    plt.ylabel('Dice')
    plt.ylim(0, 1)
    plt.legend()
    plt.grid(True)

    plt.tight_layout()
    plt.savefig("training_metrics.png")
    print("Training metrics saved: training_metrics.png")
    plt.show()

Training function

def train_model():
    """Main training routine"""
    # Check that the dataset exists
    if not os.path.exists(DATA_ROOT):
        print(f"Error: dataset path {DATA_ROOT} does not exist!")
        print("Make sure the dataset has been downloaded and placed in the right location")
        return None

    print(f"Dataset found: {DATA_ROOT}")
    print(f"Training directory: {os.path.join(DATA_ROOT, 'training')}")
    print(f"Test directory: {os.path.join(DATA_ROOT, 'test')}")

    # Preprocessing
    transform = transforms.Compose([
        transforms.Resize((128, 128)),  # shrink the images to suit CPU training
    ])

    # Create the datasets
    train_dataset = DRIVEDataset(DATA_ROOT, train=True, transform=transform)
    test_dataset = DRIVEDataset(DATA_ROOT, train=False, transform=transform)

    # Check that the datasets are not empty
    if len(train_dataset) == 0 or len(test_dataset) == 0:
        print("Error: no dataset files found!")
        print(f"Training set path: {os.path.join(DATA_ROOT, 'training')}")
        print(f"Test set path: {os.path.join(DATA_ROOT, 'test')}")
        print(f"Training files: {len(train_dataset)}, test files: {len(test_dataset)}")
        return None

    print(f"Training set size: {len(train_dataset)}")
    print(f"Test set size: {len(test_dataset)}")

    # Data loaders (small batches to suit the CPU)
    batch_size = 4  # a size that a Windows CPU can handle
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)

    # Initialize the model
    model = UNetLight(n_channels=3, n_classes=1).to(device)
    print("Model initialized")

    # Training hyperparameters
    num_epochs = 20
    learning_rate = 0.0001
    weight_decay = 1e-5

    # Loss function and optimizer
    criterion = DiceBCELoss()
    optimizer = RMSprop(model.parameters(), lr=learning_rate,
                        weight_decay=weight_decay, momentum=0.9)

    # Learning-rate scheduler (the current LR is printed every epoch,
    # so the deprecated verbose flag is not needed)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, 'min', patience=3, factor=0.5
    )

    # Metric history
    train_losses = []
    val_losses = []
    val_dices = []
    best_dice = 0.0

    print("Starting training (CPU)...")
    print(f"Training samples: {len(train_dataset)}, test samples: {len(test_dataset)}")
    print(f"Batch size: {batch_size}, epochs: {num_epochs}")

    for epoch in range(num_epochs):
        start_time = time.time()
        model.train()
        epoch_train_loss = 0.0

        # Training phase
        for images, masks in train_loader:
            images = images.to(device)
            masks = masks.to(device)

            # Forward pass
            outputs = model(images)
            loss = criterion(outputs, masks)

            # Backward pass and optimization
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            epoch_train_loss += loss.item() * images.size(0)

        # Average training loss
        epoch_train_loss /= len(train_loader.dataset)
        train_losses.append(epoch_train_loss)

        # Validation phase (run on the test set)
        model.eval()
        val_loss = 0.0
        total_dice = 0.0

        with torch.no_grad():
            for images, masks in test_loader:
                images = images.to(device)
                masks = masks.to(device)

                outputs = model(images)
                loss = criterion(outputs, masks)
                val_loss += loss.item() * images.size(0)

                # Dice coefficient: outputs are logits, so apply sigmoid
                # before thresholding the probability at 0.5
                preds = (torch.sigmoid(outputs) > 0.5).float()
                intersection = (preds * masks).sum()
                dice = (2. * intersection) / (preds.sum() + masks.sum() + 1e-6)
                total_dice += dice.item()

        # Average validation metrics
        val_loss /= len(test_loader.dataset)
        avg_dice = total_dice / len(test_loader)

        val_losses.append(val_loss)
        val_dices.append(avg_dice)
        scheduler.step(val_loss)

        # Save the best model
        if avg_dice > best_dice:
            best_dice = avg_dice
            torch.save(model.state_dict(), "unet_retina_best.pth")
            print(f"Saved best model, Dice: {best_dice:.4f}")

        epoch_time = time.time() - start_time
        print(f"Epoch [{epoch + 1}/{num_epochs}] - "
              f"Train Loss: {epoch_train_loss:.4f}, Val Loss: {val_loss:.4f}, "
              f"Val Dice: {avg_dice:.4f}, Time: {epoch_time:.1f}s, "
              f"LR: {optimizer.param_groups[0]['lr']:.2e}")

    # Plot the training metrics
    plot_metrics(train_losses, val_losses, val_dices)

    # Load the best model and visualize its predictions
    model.load_state_dict(torch.load("unet_retina_best.pth", map_location=device, weights_only=True))
    visualize_predictions(model, test_loader, num_samples=3)

    # Final evaluation
    model.eval()
    total_dice = 0.0

    with torch.no_grad():
        for images, masks in test_loader:
            images = images.to(device)
            masks = masks.to(device)

            outputs = model(images)
            # Outputs are logits: apply sigmoid, then threshold at 0.5
            preds = (torch.sigmoid(outputs) > 0.5).float()

            # Dice coefficient
            intersection = (preds * masks).sum()
            dice = (2. * intersection) / (preds.sum() + masks.sum() + 1e-6)
            total_dice += dice.item()

    print(f"Mean Dice on the test set: {total_dice / len(test_loader):.4f}")
    return model
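
Because the network returns raw logits (the sigmoid lives inside DiceBCELoss), a probability threshold of 0.5 corresponds to a logit threshold of 0, not 0.5. The NumPy check below illustrates the difference on a few made-up logit values.

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

logits = np.array([-2.0, -0.1, 0.2, 0.4, 0.6, 3.0])

by_probability = sigmoid(logits) > 0.5   # threshold the probability at 0.5
by_logit = logits > 0.0                  # equivalent threshold on the raw logit
naive = logits > 0.5                     # thresholding the logit itself at 0.5

print(by_probability.tolist())
print(by_logit.tolist())
print(naive.tolist())
```

Thresholding the logit at 0.5 amounts to a probability cutoff of about 0.62, so it marks fewer pixels as vessel and shifts the reported Dice.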

Complete code

import os
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torch.optim import RMSprop
import matplotlib.pyplot as plt
from PIL import Image
import random
import time

# Fix the random seeds for reproducibility
seed = 42
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

# Device configuration (CPU)
device = torch.device('cpu')
print(f"Using device: {device}")

# ---------------------
# 1. Data preparation (local DRIVE retinal vessel dataset)
# ---------------------
# Dataset path, relative to this script
BASE_DIR = os.path.dirname(os.path.abspath(__file__))  # directory containing the current script
DATA_ROOT = os.path.join(BASE_DIR, "data", "DRIVE")
print(f"Dataset path: {DATA_ROOT}")

# Dataset class
class DRIVEDataset(Dataset):
    def __init__(self, root_dir, train=True, transform=None):
        """
        Args:
            root_dir (string): dataset root directory
            train (bool): whether to load the training split
            transform (callable, optional): optional transform applied to image and label
        """
        self.root_dir = root_dir
        self.transform = transform
        self.sub_dir = "training" if train else "test"

        # Image, mask, and label directories
        self.image_dir = os.path.join(root_dir, self.sub_dir, "images")
        self.mask_dir = os.path.join(root_dir, self.sub_dir, "mask")
        self.label_dir = os.path.join(root_dir, self.sub_dir, "1st_manual")

        # Collect the image files (sorted, since os.listdir order is arbitrary)
        self.image_files = sorted(f for f in os.listdir(self.image_dir) if f.endswith('.tif'))
        print(f"Found {len(self.image_files)} image files in {self.image_dir}")

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        # Load the image
        img_name = self.image_files[idx]
        img_path = os.path.join(self.image_dir, img_name)
        image = Image.open(img_path).convert('RGB')

        # Load the vessel annotation (label).
        # The file name is assumed to look like "31_training.tif";
        # the leading number "31" identifies the matching label file.
        file_number = img_name.split('_')[0]
        label_name = f"{file_number}_manual1.gif"
        label_path = os.path.join(self.label_dir, label_name)
        label = Image.open(label_path).convert('L')  # convert to grayscale

        # Load the retinal field-of-view mask
        mask_name = img_name.replace('.tif', '_mask.gif')
        mask_path = os.path.join(self.mask_dir, mask_name)
        mask = Image.open(mask_path).convert('L')

        # Convert to arrays
        image = np.array(image)
        label = np.array(label)
        mask = np.array(mask)

        # Apply the retinal mask
        label = label * (mask > 0)

        # Normalize
        image = image.astype(np.float32) / 255.0
        label = (label > 0).astype(np.float32)  # binarize

        # Convert to PyTorch tensors
        image = torch.tensor(image).permute(2, 0, 1)  # [C, H, W]
        label = torch.tensor(label).unsqueeze(0)  # [1, H, W]

        # Apply the transform.
        # Note: Resize interpolates bilinearly by default, so the resized label
        # is no longer strictly binary (its values fall in [0, 1]).
        if self.transform:
            image = self.transform(image)
            label = self.transform(label)

        return image, label


# ---------------------
# 2. Lightweight U-Net implementation (CPU-friendly)
# ---------------------
class DoubleConv(nn.Module):
    """(Conv => BN => ReLU) × 2"""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.double_conv(x)


class UNetLight(nn.Module):
    """Lightweight U-Net (4 encoder blocks, 3 downsampling steps)"""

    def __init__(self, n_channels=3, n_classes=1):
        super().__init__()
        # Downsampling path (encoder)
        self.inc = DoubleConv(n_channels, 64)
        self.down1 = nn.Sequential(
            nn.MaxPool2d(2),
            DoubleConv(64, 128)
        )
        self.down2 = nn.Sequential(
            nn.MaxPool2d(2),
            DoubleConv(128, 256)
        )
        self.down3 = nn.Sequential(
            nn.MaxPool2d(2),
            DoubleConv(256, 512)
        )

        # Upsampling path (decoder)
        self.up1 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
        self.conv1 = DoubleConv(512, 256)
        self.up2 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.conv2 = DoubleConv(256, 128)
        self.up3 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.conv3 = DoubleConv(128, 64)

        # Output layer (returns logits; no sigmoid here)
        self.outc = nn.Conv2d(64, n_classes, kernel_size=1)

    def forward(self, x):
        # Encoder path
        x1 = self.inc(x)  # [64, H, W]
        x2 = self.down1(x1)  # [128, H/2, W/2]
        x3 = self.down2(x2)  # [256, H/4, W/4]
        x4 = self.down3(x3)  # [512, H/8, W/8]

        # Decoder path (with skip connections)
        x = self.up1(x4)  # [256, H/4, W/4]
        x = torch.cat([x, x3], dim=1)  # [512, H/4, W/4]
        x = self.conv1(x)  # [256, H/4, W/4]

        x = self.up2(x)  # [128, H/2, W/2]
        x = torch.cat([x, x2], dim=1)  # [256, H/2, W/2]
        x = self.conv2(x)  # [128, H/2, W/2]

        x = self.up3(x)  # [64, H, W]
        x = torch.cat([x, x1], dim=1)  # [128, H, W]
        x = self.conv3(x)  # [64, H, W]

        # Output
        return self.outc(x)


# ---------------------
# 3. Training strategy and visualization utilities
# ---------------------
class DiceBCELoss(nn.Module):
    """Combined loss: Dice loss + BCE loss (via BCEWithLogitsLoss)"""

    def __init__(self, smooth=1e-6):
        super().__init__()
        self.smooth = smooth
        self.bce_loss = nn.BCEWithLogitsLoss()  # BCE that takes raw logits

    def forward(self, preds, targets):
        # Flatten predictions (logits) and labels
        preds = preds.reshape(-1)
        targets = targets.reshape(-1)

        # The Dice term is computed on probabilities, so apply sigmoid to the logits
        probs = torch.sigmoid(preds)
        intersection = (probs * targets).sum()
        dice_loss = 1 - (2. * intersection + self.smooth) / \
                    (probs.sum() + targets.sum() + self.smooth)

        # BCE term (BCEWithLogitsLoss applies sigmoid internally)
        bce = self.bce_loss(preds, targets)

        return dice_loss + bce


def visualize_predictions(model, dataloader, num_samples=3):
    """Visualize segmentation results"""
    model.eval()

    # Fetch one batch of samples
    images, masks = next(iter(dataloader))
    batch_size = images.shape[0]
    display_samples = min(batch_size, num_samples)

    # squeeze=False keeps axes 2-D even when only one row is shown
    fig, axes = plt.subplots(display_samples, 3, figsize=(15, display_samples * 4),
                             squeeze=False)

    images = images.to(device)

    with torch.no_grad():
        # The model outputs logits; convert them to probabilities
        preds = torch.sigmoid(model(images)).cpu()

    images = images.cpu()
    masks = masks.cpu()

    for i in range(display_samples):
        # Input image
        axes[i, 0].imshow(images[i].permute(1, 2, 0))
        axes[i, 0].set_title("Input Image")
        axes[i, 0].axis('off')

        # Ground-truth vessel annotation
        axes[i, 1].imshow(masks[i].squeeze(), cmap='gray')
        axes[i, 1].set_title("Ground Truth")
        axes[i, 1].axis('off')

        # Prediction, binarized at probability 0.5
        axes[i, 2].imshow(preds[i].squeeze() > 0.5, cmap='gray')
        axes[i, 2].set_title("Prediction")
        axes[i, 2].axis('off')

    plt.tight_layout()
    plt.savefig("retina_seg_results.png")
    print("Segmentation results saved: retina_seg_results.png")
    plt.show()


def plot_metrics(train_losses, val_losses, val_dices):
    """Plot the training metric curves"""
    plt.figure(figsize=(15, 5))

    # Loss curves
    plt.subplot(1, 2, 1)
    plt.plot(train_losses, label='Train Loss')
    plt.plot(val_losses, label='Val Loss')
    plt.title('Training & Validation Loss')
    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.legend()
    plt.grid(True)

    # Dice coefficient curve
    plt.subplot(1, 2, 2)
    plt.plot(val_dices, label='Val Dice', color='green')
    plt.title('Validation Dice Coefficient')
    plt.xlabel('Epochs')
    plt.ylabel('Dice')
    plt.ylim(0, 1)
    plt.legend()
    plt.grid(True)

    plt.tight_layout()
    plt.savefig("training_metrics.png")
    print("Training metrics saved: training_metrics.png")
    plt.show()


# ---------------------
# 4. Training function
# ---------------------
def train_model():
    """Main training routine"""
    # Check that the dataset exists
    if not os.path.exists(DATA_ROOT):
        print(f"Error: dataset path {DATA_ROOT} does not exist!")
        print("Make sure the dataset has been downloaded and placed in the right location")
        return None

    print(f"Dataset found: {DATA_ROOT}")
    print(f"Training directory: {os.path.join(DATA_ROOT, 'training')}")
    print(f"Test directory: {os.path.join(DATA_ROOT, 'test')}")

    # Preprocessing
    transform = transforms.Compose([
        transforms.Resize((128, 128)),  # shrink the images to suit CPU training
    ])

    # Create the datasets
    train_dataset = DRIVEDataset(DATA_ROOT, train=True, transform=transform)
    test_dataset = DRIVEDataset(DATA_ROOT, train=False, transform=transform)

    # Check that the datasets are not empty
    if len(train_dataset) == 0 or len(test_dataset) == 0:
        print("Error: no dataset files found!")
        print(f"Training set path: {os.path.join(DATA_ROOT, 'training')}")
        print(f"Test set path: {os.path.join(DATA_ROOT, 'test')}")
        print(f"Training files: {len(train_dataset)}, test files: {len(test_dataset)}")
        return None

    print(f"Training set size: {len(train_dataset)}")
    print(f"Test set size: {len(test_dataset)}")

    # Data loaders (small batches to suit the CPU)
    batch_size = 4  # a size that a Windows CPU can handle
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)

    # Initialize the model
    model = UNetLight(n_channels=3, n_classes=1).to(device)
    print("Model initialized")

    # Training hyperparameters
    num_epochs = 20
    learning_rate = 0.0001
    weight_decay = 1e-5

    # Loss function and optimizer
    criterion = DiceBCELoss()
    optimizer = RMSprop(model.parameters(), lr=learning_rate,
                        weight_decay=weight_decay, momentum=0.9)

    # Learning-rate scheduler (the current LR is printed every epoch,
    # so the deprecated verbose flag is not needed)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, 'min', patience=3, factor=0.5
    )

    # Metric history
    train_losses = []
    val_losses = []
    val_dices = []
    best_dice = 0.0

    print("Starting training (CPU)...")
    print(f"Training samples: {len(train_dataset)}, test samples: {len(test_dataset)}")
    print(f"Batch size: {batch_size}, epochs: {num_epochs}")

    for epoch in range(num_epochs):
        start_time = time.time()
        model.train()
        epoch_train_loss = 0.0

        # Training phase
        for images, masks in train_loader:
            images = images.to(device)
            masks = masks.to(device)

            # Forward pass
            outputs = model(images)
            loss = criterion(outputs, masks)

            # Backward pass and optimization
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            epoch_train_loss += loss.item() * images.size(0)

        # Average training loss
        epoch_train_loss /= len(train_loader.dataset)
        train_losses.append(epoch_train_loss)

        # Validation phase (run on the test set)
        model.eval()
        val_loss = 0.0
        total_dice = 0.0

        with torch.no_grad():
            for images, masks in test_loader:
                images = images.to(device)
                masks = masks.to(device)

                outputs = model(images)
                loss = criterion(outputs, masks)
                val_loss += loss.item() * images.size(0)

                # Dice coefficient: outputs are logits, so apply sigmoid
                # before thresholding the probability at 0.5
                preds = (torch.sigmoid(outputs) > 0.5).float()
                intersection = (preds * masks).sum()
                dice = (2. * intersection) / (preds.sum() + masks.sum() + 1e-6)
                total_dice += dice.item()

        # Average validation metrics
        val_loss /= len(test_loader.dataset)
        avg_dice = total_dice / len(test_loader)

        val_losses.append(val_loss)
        val_dices.append(avg_dice)
        scheduler.step(val_loss)

        # Save the best model
        if avg_dice > best_dice:
            best_dice = avg_dice
            torch.save(model.state_dict(), "unet_retina_best.pth")
            print(f"Saved best model, Dice: {best_dice:.4f}")

        epoch_time = time.time() - start_time
        print(f"Epoch [{epoch + 1}/{num_epochs}] - "
              f"Train Loss: {epoch_train_loss:.4f}, Val Loss: {val_loss:.4f}, "
              f"Val Dice: {avg_dice:.4f}, Time: {epoch_time:.1f}s, "
              f"LR: {optimizer.param_groups[0]['lr']:.2e}")

    # Plot the training metrics
    plot_metrics(train_losses, val_losses, val_dices)

    # Load the best model and visualize its predictions
    model.load_state_dict(torch.load("unet_retina_best.pth", map_location=device, weights_only=True))
    visualize_predictions(model, test_loader, num_samples=3)

    # Final evaluation
    model.eval()
    total_dice = 0.0

    with torch.no_grad():
        for images, masks in test_loader:
            images = images.to(device)
            masks = masks.to(device)

            outputs = model(images)
            # Outputs are logits: apply sigmoid, then threshold at 0.5
            preds = (torch.sigmoid(outputs) > 0.5).float()

            # Dice coefficient
            intersection = (preds * masks).sum()
            dice = (2. * intersection) / (preds.sum() + masks.sum() + 1e-6)
            total_dice += dice.item()

    print(f"Mean Dice on the test set: {total_dice / len(test_loader):.4f}")
    return model


# ---------------------
# 5. Main entry point
# ---------------------
if __name__ == "__main__":
    model = train_model()

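III. Experiment Results

In the left plot, the training loss decreases steadily, while the validation loss fluctuates but trends downward overall. In the right plot, the Dice curve rises as training progresses, which indicates that the predicted vessels overlap the annotations increasingly well; the curve fluctuates strongly, however, reflecting unstable generalization on such a small dataset.

Comparing the Input Image, Ground Truth, and Prediction panels shows that the model captures the rough outline of the vessels.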
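
Since the fluctuation is attributed to the small dataset, the data augmentation ablation planned in Section I is a natural next step. The sketch below is one possible way to do it with NumPy, not part of the experiment reported here: it applies the same random flip and 90-degree rotation to an image and its label, and a brightness change to the image only. The function name, the probability, and the gain range are illustrative assumptions.

```python
import numpy as np

def augment_pair(image, label, rng):
    """Apply the same random flip / 90-degree rotation to image and label,
    plus a brightness change on the image only."""
    # image: [H, W, C] float in [0, 1]; label: [H, W] binary
    if rng.random() < 0.5:
        image, label = image[:, ::-1], label[:, ::-1]      # horizontal flip
    k = int(rng.integers(0, 4))                             # 0-3 quarter turns
    image = np.rot90(image, k, axes=(0, 1))
    label = np.rot90(label, k, axes=(0, 1))
    gain = rng.uniform(0.8, 1.2)                            # brightness factor
    image = np.clip(image * gain, 0.0, 1.0)
    return np.ascontiguousarray(image), np.ascontiguousarray(label)

rng = np.random.default_rng(42)
image = rng.random((128, 128, 3)).astype(np.float32)
label = (rng.random((128, 128)) > 0.9).astype(np.float32)
aug_image, aug_label = augment_pair(image, label, rng)
print(aug_image.shape, aug_label.shape)
```

Flips and quarter turns only move pixels around, so the label stays binary and keeps the same number of vessel pixels; arbitrary-angle rotations would need interpolation and a re-binarization step.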
