MaNi: Maximizing Mutual Information for Nuclei Cross-Domain Unsupervised Segmentation
Code: https://github.com/YashSharma/MaNi
Goal
Extract foreground and background features from the segmentation network's projection head and contrast them with a mutual-information objective.
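In other words (my own paraphrase with illustrative notation, not text from the paper): the projection-head feature map f is averaged under the predicted foreground mask M and under its complement, giving one foreground and one background prototype per image,

p^{+} = \frac{\sum_{i,j} M_{ij}\, f_{ij}}{\sum_{i,j} M_{ij}}, \qquad
p^{-} = \frac{\sum_{i,j} (1 - M_{ij})\, f_{ij}}{\sum_{i,j} (1 - M_{ij})}

and training maximizes a Jensen-Shannon lower bound on the mutual information between target-domain foreground prototypes and source-domain foreground prototypes, using source-domain background prototypes as negatives.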
Network architecture:

PyTorch code
UNet model
import numpy as np
import torch
import torch.nn as nn

# DoubleConv, Down, Up, OutConv are the standard UNet building blocks from the repo's
# unet_parts module; RevGrad is a gradient-reversal layer (e.g. from the pytorch_revgrad package).

class UNet(nn.Module):
    def __init__(self, n_channels, n_classes, bilinear=True, domain=False, proto=False, projection=True):
        super(UNet, self).__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.bilinear = bilinear

        self.inc = DoubleConv(n_channels, 64)
        self.down1 = Down(64, 128)
        self.down2 = Down(128, 256)
        self.down3 = Down(256, 512)
        factor = 2 if bilinear else 1
        self.down4 = Down(512, 1024 // factor)
        self.up1 = Up(1024, 512 // factor, bilinear)
        self.up2 = Up(512, 256 // factor, bilinear)
        self.up3 = Up(256, 128 // factor, bilinear)
        self.up4 = Up(128, 64, bilinear)
        self.outc = OutConv(64, n_classes)

        # Optional domain classifier behind a gradient-reversal layer (adversarial baseline)
        self.domain_classifier = nn.Sequential(
            RevGrad(),
            OutConv(64, 1)
        )

        # Projection head applied to the decoder features before prototype pooling
        if projection:
            self.proto_projection = nn.Sequential(
                OutConv(64, 64),
                nn.BatchNorm2d(64),
                nn.ReLU(inplace=True)
            )
        else:
            self.proto_projection = nn.Sequential(
                nn.Identity()
            )
        self.proto_pool = nn.Sequential(
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Flatten(),
        )
        self.domain = domain
        self.proto = proto

    def forward_one(self, x, x_map=None):
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)
        x = self.up1(x5, x4)
        x = self.up2(x, x3)
        x = self.up3(x, x2)
        x = self.up4(x, x1)
        logits = self.outc(x)

        domain_pred = None
        proto_pred_pos = None
        proto_pred_neg = None
        if self.domain:
            domain_pred = self.domain_classifier(x)  # [b, 1, 128, 128]
        if self.proto:
            # Generate masks: use the given map if available, otherwise the thresholded prediction
            if x_map is None:
                logits_map_pos = torch.round(torch.sigmoid(logits))  # foreground mask
                logits_map_neg = 1. - logits_map_pos                 # background mask
            else:
                logits_map_pos = x_map
                logits_map_neg = 1. - logits_map_pos
            # x_map_pos: pooled foreground features, x_map_neg: pooled background features
            x_map_pos = self.proto_pool(self.proto_projection(x) * logits_map_pos)  # proto_projection(x): [b, 64, 128, 128] -> x_map_pos: [b, 64]
            x_map_neg = self.proto_pool(self.proto_projection(x) * logits_map_neg)  # x_map_neg: [b, 64]
            # Rescale the average pooling into a masked average; normalizing_factor = H * W (e.g. 128 * 128)
            normalizing_factor = np.prod(np.array(logits_map_pos.shape[-2:]))
            # proto_pred_pos / proto_pred_neg: [b, 64]
            proto_pred_pos = x_map_pos * (normalizing_factor / (
                logits_map_pos.sum(list(range(1, logits_map_pos.ndim))).reshape((logits_map_pos.size(0), 1)) + 1e-6))
            proto_pred_neg = x_map_neg * (normalizing_factor / (
                logits_map_neg.sum(list(range(1, logits_map_neg.ndim))).reshape((logits_map_neg.size(0), 1)) + 1e-6))
        return logits, domain_pred, proto_pred_pos, proto_pred_neg

    def forward(self, x_label, x_unlabel=None, x_label_map=None, x_unlabel_map=None, validation=False):
        if validation:
            logit_label, domain_label, proto_label_pos, proto_label_neg = self.forward_one(x_label)
            return logit_label, domain_label, proto_label_pos, proto_label_neg
        if x_unlabel is None:
            return self.forward_one(x_label)
        logit_label, domain_label, proto_label_pos, proto_label_neg = self.forward_one(x_label, x_label_map)
        logit_unlabel, domain_unlabel, proto_unlabel_pos, proto_unlabel_neg = self.forward_one(x_unlabel, x_unlabel_map)
        return logit_label, domain_label, proto_label_pos, proto_label_neg, \
            logit_unlabel, domain_unlabel, proto_unlabel_pos, proto_unlabel_neg
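A quick shape check (my own minimal sketch, not code from the repo; it assumes the building blocks above are importable, and the 128x128 input size is only an example):

import torch

model = UNet(n_channels=3, n_classes=1, proto=True, projection=True)
x_label = torch.randn(2, 3, 128, 128)                    # labelled source-domain batch
x_unlabel = torch.randn(2, 3, 128, 128)                  # unlabelled target-domain batch
y_label = torch.randint(0, 2, (2, 1, 128, 128)).float()  # source ground-truth nuclei mask

(logit_l, _, proto_l_pos, proto_l_neg,
 logit_u, _, proto_u_pos, proto_u_neg) = model(x_label, x_unlabel, x_label_map=y_label)

print(logit_l.shape)      # torch.Size([2, 1, 128, 128]) - segmentation logits
print(proto_l_pos.shape)  # torch.Size([2, 64]) - per-image foreground prototype
print(proto_u_neg.shape)  # torch.Size([2, 64]) - per-image background prototype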
InfoMax mutual-information contrastive network
#https://github.com/YashSharma/MaNi/blob/master/infomax_loss.py
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from torch import einsum

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


# Infomax Loss =========================================
class GlobalDiscriminator(nn.Module):
    """Concatenation-based statistics network T(x, y) built from linear layers."""
    def __init__(self, sz):
        super(GlobalDiscriminator, self).__init__()
        self.l0 = nn.Linear(sz, 512)
        self.l1 = nn.Linear(512, 512)
        self.l2 = nn.Linear(512, 1)

    def forward(self, x1, x2):
        x = torch.cat((x1, x2), dim=1)
        h = F.relu(self.l0(x))
        h = F.relu(self.l1(h))
        return self.l2(h)


class GlobalDiscriminatorConv(nn.Module):
    """Convolutional variant of the concatenation discriminator (1x1 convolutions)."""
    def __init__(self, sz):
        super(GlobalDiscriminatorConv, self).__init__()
        self.l0 = nn.Conv2d(sz, 128, kernel_size=1)
        self.l1 = nn.Conv2d(128, 128, kernel_size=1)
        self.l2 = nn.Conv2d(128, 1, kernel_size=1)

    def forward(self, x1, x2):
        x = torch.cat((x1, x2), dim=1)
        h = F.relu(self.l0(x))
        h = F.relu(self.l1(h))
        return self.l2(h)


class PriorDiscriminator(nn.Module):
    """Prior-matching discriminator from Deep InfoMax (not used by DeepInfoMaxLoss below)."""
    def __init__(self, sz):
        super(PriorDiscriminator, self).__init__()
        self.l0 = nn.Linear(sz, 1000)
        self.l1 = nn.Linear(1000, 200)
        self.l2 = nn.Linear(200, 1)

    def forward(self, x):
        h = F.relu(self.l0(x))
        h = F.relu(self.l1(h))
        return torch.sigmoid(self.l2(h))


class MILinearBlock(nn.Module):
    def __init__(self, feature_sz, units=2048, bln=True):
        super(MILinearBlock, self).__init__()
        # Pre-dot-product encoder for the "encode and dot" architecture on 1D feature maps
        self.feature_nonlinear = nn.Sequential(
            nn.Linear(feature_sz, units, bias=False),
            nn.BatchNorm1d(units),
            nn.ReLU(),
            nn.Linear(units, units),
        )
        self.feature_shortcut = nn.Linear(feature_sz, units)
        self.feature_block_ln = nn.LayerNorm(units)

        # Initialize the shortcut projection to a noisy copy of the input features
        eye_mask = np.zeros((units, feature_sz), dtype=bool)
        for i in range(feature_sz):
            eye_mask[i, i] = 1
        self.feature_shortcut.weight.data.uniform_(-0.01, 0.01)
        self.feature_shortcut.weight.data.masked_fill_(
            torch.tensor(eye_mask), 1.0)
        self.bln = bln

    def forward(self, feat):
        f = self.feature_nonlinear(feat) + self.feature_shortcut(feat)
        if self.bln:
            f = self.feature_block_ln(f)
        return f


class GlobalDiscriminatorDot(nn.Module):
    """Dot-product statistics network with a learnable (CLIP-style) temperature."""
    def __init__(self, sz, units=2048, bln=True):
        super(GlobalDiscriminatorDot, self).__init__()
        self.block_a = MILinearBlock(sz, units=units, bln=bln)
        self.block_b = MILinearBlock(sz, units=units, bln=bln)
        self.cos = nn.CosineSimilarity(dim=1, eps=1e-6)
        self.temperature = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))

    def forward(self, features1=None, features2=None):
        # Compute the cross-modal score
        feat1 = self.block_a(features1)
        feat2 = self.block_b(features2)
        feat1, feat2 = map(lambda t: F.normalize(
            t, p=2, dim=-1), (feat1, feat2))

        # ## Method 1
        # # Dot product and sum
        # o = torch.sum(feat1 * feat2, dim=1) * self.temperature.exp()
        # ## Method 2
        # o = self.cos(feat1, feat2) * self.temperature.exp()
        # Method 3
        o = einsum("n d, n d -> n", feat1, feat2) * self.temperature.exp()
        return o


class DeepInfoMaxLoss(nn.Module):
    def __init__(self, type="concat"):
        super().__init__()
        if type == "concat":
            self.global_d = GlobalDiscriminator(sz=64 + 64)
        elif type == "dot":
            self.global_d = GlobalDiscriminatorDot(sz=64)
        else:
            self.global_d = GlobalDiscriminatorConv(sz=64 + 64)

    def forward(self, proto_label_pos, proto_label_neg, proto_unlabel_pos):
        # Positive pair: target foreground vs. source foreground; negative pair: target foreground vs. source background
        Ej = -F.softplus(-self.global_d(proto_unlabel_pos, proto_label_pos)).mean()
        Em = F.softplus(self.global_d(proto_unlabel_pos, proto_label_neg)).mean()
        LOSS = (Em - Ej)
        return LOSS
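For reference (my own reading of the code above, not text from the repo): this is the Jensen-Shannon mutual-information estimator from Deep InfoMax. Writing T for the discriminator global_d, p_t^+ for a target (unlabelled) foreground prototype, and p_s^+, p_s^- for the source foreground/background prototypes,

E_j = \mathbb{E}\big[-\mathrm{softplus}(-T(p_t^{+}, p_s^{+}))\big], \qquad
E_m = \mathbb{E}\big[\mathrm{softplus}(T(p_t^{+}, p_s^{-}))\big], \qquad
\mathcal{L}_{\mathrm{MI}} = E_m - E_j

so minimizing the loss pulls target foreground prototypes toward source foreground prototypes and pushes them away from source background prototypes.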
Usage in the main training script
#https://github.com/YashSharma/MaNi/blob/master/train.py
# Forward pass on the labelled source batch (x1) and the unlabelled target batch (x2)
logits1, domain_pred1, proto_pred_pos1, proto_pred_neg1 = model(x1)
logits2, domain_pred2, proto_pred_pos2, proto_pred_neg2 = model(x2)
# MI loss: pull the target foreground prototype toward the source foreground prototype,
# push it away from the source background prototype
loss_proto = loss_fn(proto_pred_pos1, proto_pred_neg1, proto_pred_pos2)
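Putting it together: a minimal sketch of one training step that reuses UNet, DeepInfoMaxLoss and device defined above (my own sketch, not the repo's train.py; the dummy loaders, the BCEWithLogitsLoss segmentation loss and the weight lambda_mi are illustrative assumptions):

from torch.utils.data import DataLoader, TensorDataset

# Dummy data just to make the sketch runnable; swap in the real source/target datasets
source_ds = TensorDataset(torch.randn(8, 3, 128, 128),
                          torch.randint(0, 2, (8, 1, 128, 128)).float())
target_ds = TensorDataset(torch.randn(8, 3, 128, 128))
source_loader = DataLoader(source_ds, batch_size=4)
target_loader = DataLoader(target_ds, batch_size=4)

model = UNet(n_channels=3, n_classes=1, proto=True).to(device)
loss_fn = DeepInfoMaxLoss(type="dot").to(device)
seg_criterion = nn.BCEWithLogitsLoss()
# The MI discriminator has trainable parameters of its own, so optimize it jointly with the UNet
optimizer = torch.optim.Adam(list(model.parameters()) + list(loss_fn.parameters()), lr=1e-4)
lambda_mi = 0.1  # illustrative weight for the MI term

for (x1, y1), (x2,) in zip(source_loader, target_loader):
    x1, y1, x2 = x1.to(device), y1.to(device), x2.to(device)
    logits1, _, proto_pred_pos1, proto_pred_neg1 = model(x1)   # labelled source pass
    logits2, _, proto_pred_pos2, _ = model(x2)                 # unlabelled target pass
    loss_seg = seg_criterion(logits1, y1)                      # supervised segmentation loss (source only)
    loss_proto = loss_fn(proto_pred_pos1, proto_pred_neg1, proto_pred_pos2)  # cross-domain MI loss
    loss = loss_seg + lambda_mi * loss_proto
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()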