基于互信息的正负样本对比学习

Maximizing Mutual Information for Nuclei Cross-Domain Unsupervised Segmentation

Code: https://github.com/YashSharma/MaNi

目标

在分割网络的 projection head 上分别提取前景和背景特征,并通过最大化互信息进行正负样本对比学习。

网络框架:

PyTorch 代码

  1. UNet 模型框架

class UNet(nn.Module):
    """U-Net segmentation network with optional domain and prototype heads.

    Besides per-pixel segmentation logits, the network can optionally return:

    * a domain prediction map from ``domain_classifier`` (``RevGrad`` followed
      by a 1-channel ``OutConv``; presumably a gradient-reversal layer for
      domain-adversarial training), enabled with ``domain=True``;
    * masked-average "prototype" vectors of the foreground and background
      regions of the last decoder feature map, enabled with ``proto=True``.

    Args:
        n_channels: number of input image channels.
        n_classes: number of output logit channels.
        bilinear: use bilinear upsampling in the decoder; this also halves
            the bottleneck / decoder widths via ``factor``.
        domain: if True, ``forward_one`` also computes ``domain_pred``.
        proto: if True, ``forward_one`` also computes the prototype vectors.
        projection: if True, decoder features go through
            ``OutConv(64, 64) -> BatchNorm2d -> ReLU`` before pooling;
            otherwise they are pooled unchanged.
    """

    def __init__(self, n_channels, n_classes, bilinear=True, domain=False, proto=False, projection=True):
        super(UNet, self).__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.bilinear = bilinear

        self.inc = DoubleConv(n_channels, 64)
        self.down1 = Down(64, 128)
        self.down2 = Down(128, 256)
        self.down3 = Down(256, 512)
        factor = 2 if bilinear else 1
        self.down4 = Down(512, 1024 // factor)
        self.up1 = Up(1024, 512 // factor, bilinear)
        self.up2 = Up(512, 256 // factor, bilinear)
        self.up3 = Up(256, 128 // factor, bilinear)
        self.up4 = Up(128, 64, bilinear)
        self.outc = OutConv(64, n_classes)

        self.domain_classifier = nn.Sequential(
            RevGrad(),
            OutConv(64, 1)
        )

        if projection:
            self.proto_projection = nn.Sequential(
                OutConv(64, 64),
                nn.BatchNorm2d(64),
                nn.ReLU(inplace=True)
            )
        else:
            self.proto_projection = nn.Sequential(
                nn.Identity()
            )

        # Global average pool + flatten: [b, C, H, W] -> [b, C].
        self.proto_pool = nn.Sequential(
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Flatten(),
        )

        self.domain = domain
        self.proto = proto

    def _masked_mean(self, features, mask):
        """Average ``features`` over the spatial positions selected by ``mask``.

        ``proto_pool`` averages ``features * mask`` over all H*W positions.
        Rescaling by ``H*W / mask.sum()`` turns that into the mean over the
        masked positions only. The ``1e-6`` keeps an empty mask from
        dividing by zero.

        Returns:
            A ``[b, C]`` tensor, one prototype vector per sample.
        """
        pooled = self.proto_pool(features * mask)
        normalizing_factor = mask.shape[-2] * mask.shape[-1]
        # Per-sample mask area, summed over every non-batch dimension.
        mask_area = mask.sum(list(range(1, mask.ndim))).reshape((mask.size(0), 1))
        return pooled * (normalizing_factor / (mask_area + 1e-6))

    def forward_one(self, x, x_map=None):
        """Run one batch through the network.

        Args:
            x: input image batch.
            x_map: optional foreground mask used for prototype pooling
                instead of the thresholded prediction (e.g. a ground-truth
                mask). It must broadcast against the projected features.

        Returns:
            ``(logits, domain_pred, proto_pred_pos, proto_pred_neg)``. The last
            three are ``None`` when the corresponding head is disabled.
        """
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)
        x = self.up1(x5, x4)
        x = self.up2(x, x3)
        x = self.up3(x, x2)
        x = self.up4(x, x1)
        logits = self.outc(x)

        domain_pred = None
        proto_pred_pos = None
        proto_pred_neg = None

        if self.domain:
            # Single-channel map; presumably [b, 1, H, W] -- depends on OutConv.
            domain_pred = self.domain_classifier(x)

        if self.proto:
            if x_map is None:
                # Hard foreground mask from the model's own prediction.
                logits_map_pos = torch.round(torch.sigmoid(logits))
            else:
                logits_map_pos = x_map
            logits_map_neg = 1. - logits_map_pos  # background mask

            # Project once and reuse the result. Calling the projection twice
            # doubled the compute and, in training mode, updated the
            # BatchNorm running statistics twice per forward pass.
            projected = self.proto_projection(x)
            proto_pred_pos = self._masked_mean(projected, logits_map_pos)
            proto_pred_neg = self._masked_mean(projected, logits_map_neg)

        return logits, domain_pred, proto_pred_pos, proto_pred_neg

    def forward(self, x_label, x_unlabel=None, x_label_map=None, x_unlabel_map=None, validation=False):
        """Run the labelled batch and, optionally, an unlabelled batch.

        Args:
            x_label: labelled (or only) input batch.
            x_unlabel: optional second batch, processed with the same weights.
            x_label_map / x_unlabel_map: optional masks forwarded to
                ``forward_one`` for prototype pooling.
            validation: when True, only ``x_label`` is processed and any
                masks are ignored.

        Returns:
            The 4-tuple from ``forward_one`` for a single batch, or the two
            4-tuples concatenated (labelled first) when ``x_unlabel`` is given.
        """
        if validation:
            return self.forward_one(x_label)

        if x_unlabel is None:
            # Previously a supplied x_label_map was silently dropped here.
            return self.forward_one(x_label, x_label_map)

        label_out = self.forward_one(x_label, x_label_map)
        unlabel_out = self.forward_one(x_unlabel, x_unlabel_map)
        return label_out + unlabel_out

  2. 基于互信息(InfoMax)的对比学习损失

# https://github.com/YashSharma/MaNi/blob/master/infomax_loss.py
import torch
import torch.nn as nn
import torch.nn.functional as F

import numpy as np
from torch import einsum

# Module-level default device: the first CUDA GPU if one is visible, else CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

# Infomax Loss =========================================
class GlobalDiscriminator(nn.Module):
    """Concat-MLP critic that scores a pair of feature vectors.

    The two inputs are concatenated along dim 1 and passed through
    ``Linear(sz, 512) -> ReLU -> Linear(512, 512) -> ReLU -> Linear(512, 1)``.
    The output is an unnormalised score of shape ``[b, 1]``.

    Args:
        sz: width of the concatenated input (features of ``x1`` + ``x2``).
    """

    def __init__(self, sz):
        super().__init__()
        self.l0 = nn.Linear(sz, 512)
        self.l1 = nn.Linear(512, 512)
        self.l2 = nn.Linear(512, 1)

    def forward(self, x1, x2):
        """Score the pair ``(x1, x2)``; returns raw logits."""
        hidden = torch.cat([x1, x2], dim=1)
        for layer in (self.l0, self.l1):
            hidden = F.relu(layer(hidden))
        return self.l2(hidden)
    
class GlobalDiscriminatorConv(nn.Module):
    """Per-pixel critic built from 1x1 convolutions.

    Concatenates two feature maps along the channel axis (dim 1) and scores
    every spatial location independently. The result is a single-channel
    map of raw logits.

    Args:
        sz: channel count of the concatenated input
            (channels of ``x1`` + channels of ``x2``).
    """

    def __init__(self, sz):
        super().__init__()
        self.l0 = nn.Conv2d(sz, 128, kernel_size=1)
        self.l1 = nn.Conv2d(128, 128, kernel_size=1)
        self.l2 = nn.Conv2d(128, 1, kernel_size=1)

    def forward(self, x1, x2):
        """Score every spatial position of the pair ``(x1, x2)``."""
        stacked = torch.cat([x1, x2], dim=1)
        for conv in (self.l0, self.l1):
            stacked = F.relu(conv(stacked))
        return self.l2(stacked)
    
class PriorDiscriminator(nn.Module):
    """MLP that maps a feature vector to a probability in (0, 1).

    Architecture: ``Linear(sz, 1000) -> ReLU -> Linear(1000, 200) -> ReLU ->
    Linear(200, 1) -> sigmoid``. Not used by ``DeepInfoMaxLoss`` in this file.

    Args:
        sz: input feature dimension.
    """

    def __init__(self, sz):
        super().__init__()
        self.l0 = nn.Linear(sz, 1000)
        self.l1 = nn.Linear(1000, 200)
        self.l2 = nn.Linear(200, 1)

    def forward(self, x):
        """Return sigmoid probabilities of shape ``[b, 1]``."""
        out = x
        for layer in (self.l0, self.l1):
            out = F.relu(layer(out))
        return torch.sigmoid(self.l2(out))
        
class MILinearBlock(nn.Module):
    """Residual MLP encoder used by the "encode and dot" critic.

    Computes ``MLP(x) + W_s x`` and optionally applies LayerNorm. The
    shortcut weight ``W_s`` is initialised to a noisy rectangular identity:
    uniform noise in [-0.01, 0.01], with exactly 1.0 on the main diagonal.
    At initialisation the block therefore roughly copies its input into the
    first ``min(feature_sz, units)`` output units.

    Args:
        feature_sz: input feature dimension.
        units: hidden and output dimension.
        bln: apply LayerNorm to the output when True.
    """

    def __init__(self, feature_sz, units=2048, bln=True):
        super(MILinearBlock, self).__init__()
        # Pre-dot-product encoder for the "encode and dot" architecture on
        # 1-D feature vectors.
        self.feature_nonlinear = nn.Sequential(
            nn.Linear(feature_sz, units, bias=False),
            nn.BatchNorm1d(units),
            nn.ReLU(),
            nn.Linear(units, units),
        )
        self.feature_shortcut = nn.Linear(feature_sz, units)
        self.feature_block_ln = nn.LayerNorm(units)

        # torch.eye builds the rectangular diagonal mask directly. The old
        # NumPy loop relied on the `np.bool` alias (removed in NumPy 1.24)
        # and raised IndexError whenever units < feature_sz.
        eye_mask = torch.eye(units, feature_sz, dtype=torch.bool)
        with torch.no_grad():
            self.feature_shortcut.weight.uniform_(-0.01, 0.01)
            self.feature_shortcut.weight.masked_fill_(eye_mask, 1.0)
        self.bln = bln

    def forward(self, feat):
        """Encode ``feat`` into a ``units``-wide representation.

        Presumably ``feat`` is a 2-D ``(N, feature_sz)`` batch, since the
        BatchNorm1d inside the MLP normalises over ``units`` at dim 1.
        """
        out = self.feature_nonlinear(feat) + self.feature_shortcut(feat)
        return self.feature_block_ln(out) if self.bln else out
    
class GlobalDiscriminatorDot(nn.Module):
    """"Encode and dot" critic: temperature-scaled cosine similarity.

    Each input is embedded by its own ``MILinearBlock`` and L2-normalised.
    The per-sample dot product of the two embeddings (a cosine similarity in
    [-1, 1]) is multiplied by ``exp(temperature)``. ``temperature`` is a
    learnable scalar initialised to ``log(1 / 0.07)``, an initial scale of
    about 14.3.

    NOTE(review): the learned scale is not clamped, so it can grow without
    bound during training.

    Args:
        sz: feature dimension of both inputs.
        units: embedding width of the two encoders.
        bln: forwarded to ``MILinearBlock`` (LayerNorm on/off).
    """

    def __init__(self, sz, units=2048, bln=True):
        super().__init__()
        self.block_a = MILinearBlock(sz, units=units, bln=bln)
        self.block_b = MILinearBlock(sz, units=units, bln=bln)
        # Not used by forward(); kept so the module's attributes are unchanged.
        self.cos = nn.CosineSimilarity(dim=1, eps=1e-6)
        self.temperature = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))

    def forward(
        self,
        features1=None,
        features2=None,
    ):
        """Return one score per row of ``features1`` / ``features2``, shape ``[N]``."""
        # Embed each side with its own encoder, then project onto the unit sphere.
        emb_a = F.normalize(self.block_a(features1), p=2, dim=-1)
        emb_b = F.normalize(self.block_b(features2), p=2, dim=-1)
        # Row-wise dot product == cosine similarity of the unit vectors.
        similarity = einsum("n d, n d -> n", emb_a, emb_b)
        return similarity * self.temperature.exp()
    
class DeepInfoMaxLoss(nn.Module):
    """Jensen-Shannon mutual-information loss between prototype vectors.

    With critic ``T`` and anchor ``u = proto_unlabel_pos``:

        loss = mean(softplus(T(u, proto_label_neg)))
               + mean(softplus(-T(u, proto_label_pos)))

    Minimising it pushes the critic score up for the positive pair
    (foreground/foreground) and down for the negative pair
    (foreground/background).

    Args:
        type: critic architecture. ``"concat"`` selects ``GlobalDiscriminator``
            (input width 64 + 64) and ``"dot"`` selects
            ``GlobalDiscriminatorDot`` (width 64). Any other value selects
            ``GlobalDiscriminatorConv`` (64 + 64 channels).
            NOTE(review): the conv critic uses Conv2d and so expects feature
            maps rather than 2-D ``[b, 64]`` prototype vectors -- confirm
            before using it.
    """

    def __init__(self, type="concat"):
        super().__init__()
        if type == "dot":
            self.global_d = GlobalDiscriminatorDot(sz=64)
        elif type == "concat":
            self.global_d = GlobalDiscriminator(sz=64 + 64)
        else:
            self.global_d = GlobalDiscriminatorConv(sz=64 + 64)

    def forward(self, proto_label_pos, proto_label_neg, proto_unlabel_pos):
        """Return the scalar InfoMax loss for one batch of prototypes."""
        anchor = proto_unlabel_pos
        positive_score = self.global_d(anchor, proto_label_pos)
        negative_score = self.global_d(anchor, proto_label_neg)
        joint_term = -F.softplus(-positive_score).mean()
        marginal_term = F.softplus(negative_score).mean()
        return marginal_term - joint_term

  3. 在主训练程序中的使用

# https://github.com/YashSharma/MaNi/blob/master/train.py
# Excerpt from the training loop: run two batches through the model and
# contrast their prototype vectors with the InfoMax loss.
# NOTE(review): x1/x2 are presumably the labelled (source) and unlabelled
# (target) batches, and loss_fn a DeepInfoMaxLoss instance -- confirm in train.py.

# Called with a single argument, UNet.forward returns the 4-tuple
# (logits, domain_pred, proto_pos, proto_neg).
logits1, domain_pred1, proto_pred_pos1, proto_pred_neg1 = model(x1)
logits2, domain_pred2, proto_pred_pos2, proto_pred_neg2 = model(x2)
# Batch-2 foreground prototypes are the anchor; batch-1 foreground is the
# positive and batch-1 background the negative.
loss_proto = loss_fn(proto_pred_pos1, proto_pred_neg1, proto_pred_pos2)

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值