一、从问题出发:医学图像分割的挑战
在医学图像分割中,最大的问题是:
-
前景(如细胞、器官)区域小
-
背景区域非常大(比如黑色背景、组织间隙)
这时候,常规的损失函数(比如 BCE
)可能会因为背景正确率很高而掩盖前景预测差的问题。
于是引入了 Dice Loss ——一种关注“重叠区域”的损失函数。
二、Dice 系数的直观理解
我们希望预测出来的前景区域和真实标签的前景区域“越重合越好”。
Dice 系数计算的是:
预测和真实之间重叠的面积 × 2 / 总面积
公式如下:
Dice(A,B)=2*∣A∩B∣ / |A|+|B|
-
A:预测的前景区域(预测图中值为1的像素)
-
B:真实的前景区域(标签中值为1的像素)
-
∣A∩B∣:预测和标签都为1的地方(重叠区域)
三、转换成代码层面的公式
在神经网络中,预测输出是一个浮点型概率图 p∈[0,1],标签是二值图 y∈{0,1}。Dice 计算写成:
-
pip_ipi:第 i 个像素的预测概率
-
yiy_iyi:第 i 个像素的真实标签(0 或 1)
四、Dice Loss 的 PyTorch 实现详细讲解
下面我们来逐行讲解 PyTorch 实现:
import torch
import torch.nn as nn
class DiceLoss(nn.Module):
def __init__(self, smooth=1e-7):
super(DiceLoss, self).__init__()
self.smooth = smooth
-
定义一个 DiceLoss 类,继承自
nn.Module
。 -
smooth
是一个极小的数,防止分母为 0。
def forward(self, preds, targets):
preds = torch.sigmoid(preds) # 把网络输出转成概率
-
preds
是模型的输出,通常是 logit 值(未经过 sigmoid)。 -
所以这里加上
torch.sigmoid()
把它转为概率 p∈[0,1]p \in [0, 1]p∈[0,1]
preds = preds.view(-1)
targets = targets.view(-1)
-
把 2D(或 3D)图像展平为 1D 向量。
-
这样方便统一计算,不管图像大小。
intersection = (preds * targets).sum()
-
preds * targets
表示重叠的部分(预测为1且标签为1) -
.sum()
得到所有像素重叠的总值,相当于 Dice 中的 ∣A∩B∣|A \cap B|∣A∩B∣
dice_score = (2. * intersection + self.smooth) / (preds.sum() + targets.sum() + self.smooth)
-
这是 Dice 系数的核心公式。
-
分子是重叠区域 × 2,分母是预测前景 + 真实前景
return 1 - dice_score
-
损失函数要“越小越好”
-
所以用
1 - Dice
,如果重叠越多(dice 越大),损失就越小
总结关键点
步骤 | 解释 |
---|---|
sigmoid | 将输出转换为概率 |
view(-1) | 展平图像数据 |
intersection | 求预测和真实标签的重叠区域 |
smooth | 防止除以 0 的技巧 |
1 - dice | 将重叠度转为损失函数 |
五、适用场景
场景 | 推荐 |
---|---|
医学图像小目标 | ✅ 推荐使用 Dice Loss 或 Dice + BCE |
背景太多时 | ✅ 可以避免 BCE 偏向背景 |
多分类 | ❌ Dice 需要适配成 soft-Dice 或 one-hot + 类别循环 |
为什么要组合 Dice + BCE?
Dice Loss 对于 前景较少、样本不平衡 的任务很友好,但它有一些弱点:
问题 | 原因 |
---|---|
收敛不稳定 | Dice 是归一化比值,对 early stage 的小预测不敏感 |
没有像素级指导 | 它不逐像素比较,只关注区域重叠度 |
对小区域不敏感 | 若预测和真实都接近 0,Dice 值可能仍然很高 |
BCE(Binary Cross Entropy)则:
-
强调每一个像素的分类正确性
-
学习初期对小目标更敏感
-
收敛较快
所以组合优势互补:
-
BCE → 提供逐像素监督(稳定训练)
-
Dice → 强调结构整体重叠(优化性能)
二、直观理解 Dice + BCE
组合方式(一般为加权和):
-
α 和 β 是权重,一般初始设为 0.5 和 0.5,或都为 1。
-
可以理解为“局部(像素) + 全局(结构)”结合训练目标。
三、PyTorch 实现代码
import torch
import torch.nn as nn
import torch.nn.functional as F
class DiceBCELoss(nn.Module):
def __init__(self, weight_dice=1.0, weight_bce=1.0, smooth=1e-6):
super(DiceBCELoss, self).__init__()
self.weight_dice = weight_dice
self.weight_bce = weight_bce
self.smooth = smooth
self.bce = nn.BCEWithLogitsLoss() # 更安全,内部自动加 sigmoid
初始化部分说明:
-
weight_dice
和weight_bce
:两个损失函数的权重比例 -
使用
BCEWithLogitsLoss()
是更推荐的做法,它内部包含sigmoid
操作
def forward(self, preds, targets):
# BCE Loss
bce_loss = self.bce(preds, targets)
-
preds
是网络输出的 logits,targets
是二值标签 -
BCEWithLogitsLoss
自动对 preds 做 sigmoid
# Dice Loss
preds = torch.sigmoid(preds)
preds = preds.view(-1)
targets = targets.view(-1)
intersection = (preds * targets).sum()
dice_loss = 1 - (2. * intersection + self.smooth) / (preds.sum() + targets.sum() + self.smooth)
-
和之前讲的一样:计算 Dice Loss
-
注意和 BCE 是分开计算的
# 加权组合
total_loss = self.weight_bce * bce_loss + self.weight_dice * dice_loss
return total_loss
四、怎么调节权重?
训练目标 | 推荐权重设置 |
---|---|
强调初期稳定训练 | BCE:Dice = 1.0:0.5 或 1.0:1.0 |
结构更重要(如细胞) | BCE:Dice = 0.5:1.0 或 0.3:1.0 |
不知道 | 从 1:1 开始实验,观察结果调节 |
建议使用可视化工具(如 TensorBoard)查看损失下降曲线。
五、实践效果分析
评估点 | 组合损失的优势 |
---|---|
收敛速度 | 比单独 Dice 快很多 |
最终精度 | 通常高于单独 BCE |
小目标分割 | 效果较好(尤其在医学图像) |
鲁棒性 | 对不同病例更稳健,不易过拟合背景 |
六、总结:适合用 Dice + BCE 的情况
推荐使用的场景:
-
医学图像分割(肿瘤、细胞等)
-
小目标不平衡任务(前景少、背景多)
-
对结构和边界要求高的任务
一、IoU(交并比)基本概念
1. 定义:
IoU(Intersection over Union)即“交并比”,用于衡量两个区域之间的重叠程度:
其中:
-
A 表示预测的分割掩膜(预测 mask);
-
B 表示真实的分割掩膜(ground truth);
-
∩:表示交集,即两者都为正的像素;
-
∪:表示并集,即两者至少有一个为正的像素。
2. 范围:
-
IoU 值在 0~1 之间,越大表示预测越准确。
二、为什么要使用 IoU Loss?
交并比(IoU)是常用的评估指标之一,但它不可导,无法直接作为损失函数使用,因此我们使用一种近似的、可微的方式来构造 IoU Loss。
IoU Loss 主要解决如下问题:
-
对目标区域较小的图像很敏感;
-
能避免类别不平衡对损失函数的干扰;
-
直接优化模型在评估时使用的指标(IoU 本身)。
三、IoU Loss 数学表达式
为了让 IoU 可导,我们使用了“软”版本的交集和并集:
设:
-
p:模型的预测输出(经过 sigmoid,范围 [0,1]);
-
g:ground truth 掩膜(0/1);
软交集和并集如下:
其中 ϵ 是一个很小的数,防止除以 0。
最终的 IoU Loss 为:
四、PyTorch 实现
import torch
import torch.nn as nn
class IoULoss(nn.Module):
def __init__(self, smooth=1e-6):
super(IoULoss, self).__init__()
self.smooth = smooth
def forward(self, preds, targets):
# 预测和标签都需为 float 类型
preds = preds.contiguous()
targets = targets.contiguous()
intersection = (preds * targets).sum(dim=(2, 3))
union = (preds + targets - preds * targets).sum(dim=(2, 3))
iou = (intersection + self.smooth) / (union + self.smooth)
loss = 1 - iou
return loss.mean()
1. class IoULoss(nn.Module):
-
自定义一个损失函数类,继承
nn.Module
,这样可以像其它损失函数一样被调用(如criterion(pred, target)
)。
2. def __init__(self, smooth=1e-6):
-
构造函数,传入
smooth
参数,默认值很小(1e−61e{-6}1e−6),用于防止除以 0(即在分母加上epsilon
)。
3. super(IoULoss, self).__init__()
-
调用父类的初始化方法,这是标准写法,用于正确注册
nn.Module
的子模块(例如你的 loss 中可能包含其它参数)。
4. def forward(self, preds, targets):
-
前向传播定义。当你调用
loss = IoULoss()(preds, targets)
时,会自动执行这段 forward。
5. preds = preds.contiguous()
-
**目的:**确保预测张量在内存中是连续的。某些操作(如
.view()
)要求输入数据是连续的。 -
通常作用不大,但作为规范性处理可以加上。
6. intersection = (preds * targets).sum(dim=(2, 3))
-
对每个像素点做乘法,表示预测和标签都为 1 的位置,即“交集”。
-
.sum(dim=(2, 3))
是在 H 和 W 两个维度上进行求和,保留的是每个样本的结果(batch size 不变)。
7. union = (preds + targets - preds * targets).sum(dim=(2, 3))
-
表达式来源于集合公式:
A∪B=A+B−A∩BA \cup B = A + B - A \cap BA∪B=A+B−A∩B -
预测 + 标签 - 交集,得到并集区域。
8. iou = (intersection + self.smooth) / (union + self.smooth)
-
得到每张图的 IoU 值。加上
smooth
是为了防止 0 除错误。 -
**注意:**这个是 soft IoU,不是硬 0/1 的,而是基于概率(预测 mask 是 [0,1])。
9. loss = 1 - iou
-
IoU 越大越好,所以我们取反作为损失:IoU 越接近 1,损失越小。
10. return loss.mean()
-
返回所有样本的平均 IoU Loss。
✅ 输入格式要求
为了正常使用这个 IoU Loss,有几个要求:
条件 | 要求 |
---|---|
preds | 模型输出(通过 sigmoid,范围为 0~1) |
targets | Ground Truth(通常为 0 或 1 的二值 mask) |
维度 | 一般为 (B, 1, H, W) |
类型 | float32 ,不能是 int 类型,否则乘法出错 |