论文114:Dynamic graph representation with knowledge-aware attention for WSI (CVPR‘24)

文章目录

  • 1 要点
  • 2 方法
    • 2.1 动态图构建
    • 2.2 知识感知注意力机制

1 要点

题目:用于WSI分类的知识感知动态图表征 (WiKG)

代码:https://github.com/WonderLandxD/WiKG

背景

  1. 已有方法重点关注重要的实例,但难以捕捉实例之间的交互性;
  2. 传统的图表征方法利用显式的空间位置构建拓扑结构,但限制了任意位置的实例之间的灵活交互能力,尤其是空间距离远时;

方法:利用动态图构建来量化WSI中区块 (实例) 之间的位置关系:

  1. 动态构建邻居和有向边缘嵌入:
    a) 使用Otsu阈值函数区分前景组织区域,并将其划分为非重叠的实例;
    b) 使用特征编码器获取实例嵌入,并使用两个单独的线性层对其嵌入,分别获得头部尾部
    c) 使用点积和softmax函数计算头部和尾部的相似度;
    d) 选择具有最高相似度得分的前 k k k个实例作为每个实例的邻居;
    e) 使用头部和尾部嵌入分配有向边嵌入;
  2. 设计了一种知识感知注意力机制:
    a) 计算邻居的尾部嵌入的线性组合,以描述一阶连续结构;
    b) 使用softmax函数对这些组合归一化;
  3. 生成图级嵌入及分类:
    a) 将聚合的邻居信息与原始头部融合,以获得新的头部表征;
    b) softmax函数获得WSI的概率分数;
  4. 交叉熵损失引导训练;

数据集

  1. TCGA:包括食管癌、肾癌,以及肺癌;

图1:表征方法示意:(a) 实例包表征;(b) 传统图表征;(c) 所提出的图表征WiKG

2 方法

本节主要介绍动态图构建和用于节点更新的知识感知注意力机制,如图2。

图2: WiKG架构,包括实例特征提取、基于头和尾的动态边构建、图表征学习,以及WSI的预测

2.1 动态图构建

与传统方法利用空间关系构建图表征,WiKG基于可学习的潜在特征来评估实例之间的位置关系:

  1. 对于给定的一个WSI,利用Otsu阈值方法来区分前景组织区域,并将其划分为不交叠的实例 X = { x 1 , x 2 , … , x n } X=\{x_1,x_2,\dots,x_n\} X={x1,x2,,xn},这也同时表示图节点
  2. 利用特征编码器,例如在ImageNet上训练的Vision Transformer来获取实例嵌入
  3. 使用两个独立的线性层来来分别将实例嵌入为头部 h i h_i hi和尾部 t i t_i ti,其中头部用于探索其它实例与其本身的相关性,尾部用于探索当前实例对其它实例的贡献:
    h i = W h f ( X ) , t i = W t f ( X ) , (1) \tag{1} h_i=W_hf(X),\quad t_i=W_tf(X), hi=Whf(X),ti=Wtf(X),(1)
  4. 计算头部和尾部的点乘,并通过softmax来获取最终的相似性:
    ω i , j = h i T t j ∑ j = 1 N ( h i T t j ) , (2) \tag{2} \omega_{i,j}=\frac{h_i^Tt_j}{\sum_{j=1}^N(h_i^Tt_j)}, ωi,j=j=1N(hiTtj)hiTtj,(2)其中 ω i , j \omega_{i,j} ωi,j表示第 i i i个头部和第 j j j个尾部之间的相似性。对于每个实例,其相似性最高的 k k k个实例被选择作为其邻居实例:
    N ( i ) = { j ∈ V : ω i , j ∈ Topk { w i , j } j = 1 N } , (3) \tag{3} \mathcal{N}(i)=\{ j\in V:\omega_{i,j}\in \text{Topk}\{w_i,j\}_{j=1}^N \}, N(i)={jV:ωi,jTopk{wi,j}j=1N},(3)其中 V V V表示实例集、 ∣ V ∣ = N |V|=N V=N,以及 ∣ N ( i ) ∣ = k |\mathcal{N}(i)|=k N(i)=k
  5. 依据拓扑结构将嵌入分配给有向边:
    r i , j = ω i , j t j + ( 1 − ω i , j ) h i , for every  j ∈ N ( i ) , (4) \tag{4} r_{i,j}=\omega_{i,j}t_j+(1-\omega_{i,j})h_i,\quad\text{for every }j\in\mathcal{N}(i), ri,j=ωi,jtj+(1ωi,j)hi,for every jN(i),(4)其中 r i , j r_{i,j} ri,j表示实例 j j j到实例 i i i的边嵌入;
  6. 通过以上步骤,每个WSI可以被描述为一个动态图表征 G = ( V , E , F , R ) G=(V,\mathcal{E,F,R}) G=(V,E,F,R),其中 V V V表示节点的集合、 E \mathcal{E} E表示边的集合、 F \mathcal{F} F表示头部和尾部嵌入的集合,以及 R \mathcal{R} R表示有向边嵌入的集合。具体地, E = { ( h , r , t ) : ( h , t ) ∈ F , r ∈ R } \mathcal{E}=\{(h,r,t):(h,t)\in\mathcal{F},r\in\mathcal{R}\} E={(h,r,t):(h,t)F,rR}是头部、尾部,以及在每个有向边上的高位相似性的集合。
  7. 算法1提供了动态图构建的Pytorch风格的伪代码。

2.2 知识感知注意力机制

为了充分利用图结构中图节点的关联性,设计用于节点之间信息传递和汇聚的知识感知注意力机制:

  1. 对于每个实例 i i i,计算头部嵌入和其邻居 N i \mathcal{N}_i Ni之间的线性组合,以刻画其一阶连接结构
    h N ( i ) = ∑ j ∈ N ( i ) π ( h i , r i , j , t j ) t j , (5) \tag{5} h_{\mathcal{N}(i)}=\sum_{j\in\mathcal{N}(i)}\pi(h_i,r_{i,j},t_j)t_j, hN(i)=jN(i)π(hi,ri,j,tj)tj,(5)其中 π ( h , r , t ) \pi(h,r,t) π(h,r,t)是引导尾部信息汇聚至头部的权重因子,这里使用三元组的非线性组合来计算:
    u ( h i , r i , j , t j ) = t j T tanh ( h i + r i , j ) . (6) \tag{6} u(h_i,r_{i,j},t_j)=t_j^T\text{tanh}(h_i+r_{i,j}). u(hi,ri,j,tj)=tjTtanh(hi+ri,j).(6)这种组合方式
  2. 利用softmax标准化:
    π ( h i , r i , j , t j ) = exp ⁡ { u ( h i , r i , j , t j ) } ∑ j ∈ N ( i ) exp ⁡ { u ( h i , r i , j , t j ) } . (7) \tag{7} \pi(h_i,r_{i,j},t_j)=\frac{\exp\{ u(h_i,r_{i,j},t_j) \}}{\sum_{j\in\mathcal{N}(i)}\exp\{ u(h_i,r_{i,j},t_j) \}}. π(hi,ri,j,tj)=jN(i)exp{u(hi,ri,j,tj)}exp{u(hi,ri,j,tj)}.(7)
    通过建模三元组的关系,并将它们描述为关于边的知识信息,头节点可以高效地度量来自尾部节点的信号。图3展示了所提出的知识感知注意力机制的实现。
  3. 使用一个双重互动机制来促进节点之间的信息交换:
    h i = σ 1 ( W 1 ( h i + h N ( i ) ) ) + σ 2 ( W 2 ( h 2 ⊙ h N ( i ) ) ) , (8) \tag{8} h_i=\sigma_1(W_1(h_i+h_{\mathcal{N}_{(i)}}))+\sigma_2(W_2(h_2\odot h_{N_{(i)}})), hi=σ1(W1(hi+hN(i)))+σ2(W2(h2hN(i))),(8)其中 σ \sigma σ是激活函数,类似于LeakyReLU;
  4. 使用Readout函数生成图级别嵌入并用softmax获得最终的WSI预测概率:
    Y ^ = Softmax ( Readout ( G ) ) , (9) \tag{9} \hat{Y}=\text{Softmax}(\text{Readout}(G)), Y^=Softmax(Readout(G)),(9)其中Readout是一个全局池化函数,类似于均值和最大池化;
  5. 在训练期间,选用交叉熵作为损失函数:
    L c e = 1 1 M ∑ m = 1 M ∑ c = 1 C Y m , c ln ⁡ ( Y ^ m , c ) , (10) \tag{10} \mathcal{L}_{ce}=1\frac{1}{M}\sum_{m=1}^M\sum_{c=1}^CY_{m,c}\ln(\hat{Y}_{m,c}), Lce=1M1m=1Mc=1CYm,cln(Y^m,c),(10)其中 C C C是类别数、 M M M是训练样本数,以及 Y Y Y是one-hot标签。
"Structure-Aware Transformer for Graph Representation Learning"是一篇使用Transformer模型进行图表示学习的论文。这篇论文提出了一种名为SAT(Structure-Aware Transformer)的模型,它利用了图中节点之间的结构信息,以及节点自身的特征信息。SAT模型在多个图数据集上都取得了非常好的结果。 以下是SAT模型的dgl实现代码,代码中使用了Cora数据集进行示例: ``` import dgl import numpy as np import torch import torch.nn as nn import torch.nn.functional as F class GraphAttentionLayer(nn.Module): def __init__(self, in_dim, out_dim, num_heads): super(GraphAttentionLayer, self).__init__() self.num_heads = num_heads self.out_dim = out_dim self.W = nn.Linear(in_dim, out_dim*num_heads, bias=False) nn.init.xavier_uniform_(self.W.weight) self.a = nn.Parameter(torch.zeros(size=(2*out_dim, 1))) nn.init.xavier_uniform_(self.a.data) def forward(self, g, h): h = self.W(h).view(-1, self.num_heads, self.out_dim) # Compute attention scores with g.local_scope(): g.ndata['h'] = h g.apply_edges(fn.u_dot_v('h', 'h', 'e')) e = F.leaky_relu(g.edata.pop('e'), negative_slope=0.2) g.edata['a'] = torch.cat([e, e], dim=1) g.edata['a'] = torch.matmul(g.edata['a'], self.a).squeeze() g.edata['a'] = F.leaky_relu(g.edata['a'], negative_slope=0.2) g.apply_edges(fn.e_softmax('a', 'w')) # Compute output features g.ndata['h'] = h g.update_all(fn.u_mul_e('h', 'w', 'm'), fn.sum('m', 'h')) h = g.ndata['h'] return h.view(-1, self.num_heads*self.out_dim) class SATLayer(nn.Module): def __init__(self, in_dim, out_dim, num_heads): super(SATLayer, self).__init__() self.attention = GraphAttentionLayer(in_dim, out_dim, num_heads) self.dropout = nn.Dropout(0.5) self.norm = nn.LayerNorm(out_dim*num_heads) def forward(self, g, h): h = self.attention(g, h) h = self.norm(h) h = F.relu(h) h = self.dropout(h) return h class SAT(nn.Module): def __init__(self, in_dim, hidden_dim, out_dim, num_heads): super(SAT, self).__init__() self.layer1 = SATLayer(in_dim, hidden_dim, num_heads) self.layer2 = SATLayer(hidden_dim*num_heads, out_dim, 1) def forward(self, g, h): h = self.layer1(g, h) h = self.layer2(g, h) return h.mean(0) # Load Cora dataset from dgl.data import citation_graph as citegrh data = citegrh.load_cora() g = data.graph features = torch.FloatTensor(data.features) labels = torch.LongTensor(data.labels) train_mask = torch.BoolTensor(data.train_mask) val_mask = torch.BoolTensor(data.val_mask) test_mask = torch.BoolTensor(data.test_mask) # Add self loop g = dgl.remove_self_loop(g) g = dgl.add_self_loop(g) # Define model and optimizer model = SAT(features.shape[1], 64, data.num_classes, 8) optimizer = torch.optim.Adam(model.parameters(), lr=0.005, weight_decay=5e-4) # Train model for epoch in range(200): model.train() logits = model(g, features) loss = F.cross_entropy(logits[train_mask], labels[train_mask]) optimizer.zero_grad() loss.backward() optimizer.step() acc = (logits[val_mask].argmax(1) == labels[val_mask]).float().mean() if epoch % 10 == 0: print('Epoch {:03d} | Loss {:.4f} | Accuracy {:.4f}'.format(epoch, loss.item(), acc.item())) # Test model model.eval() logits = model(g, features) acc = (logits[test_mask].argmax(1) == labels[test_mask]).float().mean() print('Test accuracy {:.4f}'.format(acc.item())) ``` 在这个示例中,我们首先加载了Cora数据集,并将其转换为一个DGL图。然后,我们定义了一个包含两个SAT层的模型,以及Adam优化器。在训练过程中,我们使用交叉熵损失函数和验证集上的准确率来监控模型的性能。在测试阶段,我们计算测试集上的准确率。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值