论文链接:https://arxiv.org/pdf/1710.09829v1.pdf
代码实现参考:https://github.com/hula-ai/capsule_network_dynamic_routing/blob/master/capsnet.py
capsule网络motivation
capsule网络提出的原因是针对CNN进行特征提取的时候会忽略特征之间的关系,也就是说一张image里打乱region,CNN仍然能够正确识别。参考:https://baijiahao.baidu.com/s?id=1585376284135321218&wfr=spider&for=pc
原因是CNN网络通过卷积和池化虽然能解决平移不变性和一定程度的旋转不变性和放缩不变性,但是没有考虑到特征之间的关系如位置、大小。
参考:https://zhuanlan.zhihu.com/p/35154414
以人脸识别为例:
CNN模型通常只会考虑如果存在鼻子、眼睛和嘴巴等的特征就会激活神经元,判断这张脸为人脸。但它其实没有考虑到底层特征和更高一层特征的方向位置和大小的关系。
CapsNet的每个神经元不再是标量,而是一个向量。每个向量都隐式的包含一些信息,比如概率,方向,大小,比如眼睛和人脸的关系,眼睛的大小是人脸的1/10,而且方向是跟人脸是一致的。只有眼睛等等这些特征跟要预测的人脸满足这种正向的对应关系,才能对预测做出贡献。
capsule网络流程
每一个layer包含多个capsule,下面展示layer
l
l
l到下一个层
l
+
1
l+1
l+1的路由算法:
具体步骤可以参考原始论文,这里只讲述重点:
(1)
b
i
j
b_{ij}
bij表示capsule
i
i
i在聚合信息到capsule
j
j
j时未进行softmax归一化的权重,初始化为0表示初始权重是一样的。
(2)
u
i
\mathbf{u}_{i}
ui表示capsule
i
i
i的输出,
u
^
j
∣
i
\mathbf{\hat{u}}_{j|i}
u^j∣i表示capsule
i
i
i向capsule
j
j
j传递的向量,通过一个转换矩阵
W
i
j
\mathbf{W}_{ij}
Wij施加在
u
i
\mathbf{u}_{i}
ui得到。对于capsule
j
j
j,对上一layer传来的信息进行加权的聚合得到向量
s
j
\mathbf{s}_{j}
sj
(3)对向量
s
j
\mathbf{s}_{j}
sj进行激活:
v
j
=
∣
∣
s
j
∣
∣
2
1
+
∣
∣
s
j
∣
∣
2
s
j
∣
∣
s
j
∣
∣
\mathbf{v}_{j}=\frac{||\mathbf{s}_{j}||^{2}}{1+||\mathbf{s}_{j}||^{2}}\frac{\mathbf{s}_{j}}{||\mathbf{s}_{j}||}
vj=1+∣∣sj∣∣2∣∣sj∣∣2∣∣sj∣∣sj
前面的
∣
∣
s
j
∣
∣
2
1
+
∣
∣
s
j
∣
∣
2
∈
[
0
,
1
)
\frac{||\mathbf{s}_{j}||^{2}}{1+||\mathbf{s}_{j}||^{2}}\in [0,1)
1+∣∣sj∣∣2∣∣sj∣∣2∈[0,1)为压缩因子,
s
j
∣
∣
s
j
∣
∣
\frac{\mathbf{s}_{j}}{||\mathbf{s}_{j}||}
∣∣sj∣∣sj为单位向量,因此
∣
∣
v
j
∣
∣
∈
[
0
,
1
)
||\mathbf{v}_{j}||\in [0,1)
∣∣vj∣∣∈[0,1),且向量方向是不变的。
(4)使用点积来计算
v
j
\mathbf{v}_{j}
vj与每个capsule
i
i
i的输出
u
^
j
∣
i
\mathbf{\hat{u}}_{j|i}
u^j∣i的一致程度,根据余弦相似性原理,两个向量在方向上越接近,点积越大, 同时没有实施模归一化,考虑了
∣
∣
u
^
j
∣
i
∣
∣
||\mathbf{\hat{u}}_{j|i}||
∣∣u^j∣i∣∣越大,权重也越大。然后反馈权重
b
i
j
b_{ij}
bij。
margin loss来进行类别判断
这里借鉴了SVM里的hinge loss,对于capsule
k
k
k的输出,也就是类别
k
k
k的判别向量
v
k
\mathbf{v}_{k}
vk,向量的模长
∣
∣
v
k
∣
∣
||\mathbf{v}_{k}||
∣∣vk∣∣表示了capsule对应的类别存在的概率。对于正确的类别,
∣
∣
v
k
∣
∣
||\mathbf{v}_{k}||
∣∣vk∣∣趋向于越大越好,称为present loss,对于其他类别,
∣
∣
v
k
∣
∣
||\mathbf{v}_{k}||
∣∣vk∣∣趋向于越小越好,成为absent loss。假设某个
L
k
=
T
k
max
(
0
,
m
+
−
∣
∣
v
k
∣
∣
)
2
+
λ
(
1
−
T
k
)
max
(
0
,
∣
∣
v
k
∣
∣
−
m
−
)
L_{k}=T_{k}\max{(0,m^{+}-||\mathbf{v}_{k}||)}^{2}+\lambda (1-T_{k})\max(0,||\mathbf{v}_{k}||-m^{-})
Lk=Tkmax(0,m+−∣∣vk∣∣)2+λ(1−Tk)max(0,∣∣vk∣∣−m−)
前半部分是present loss,
T
k
=
1
T_{k}=1
Tk=1如果是真实类别。对于10-class的mnist数据集,假设某个样本的真实标签是3,那么误差为:
L
=
max
(
0
,
m
+
−
∣
∣
v
3
∣
∣
)
2
+
λ
∑
i
∈
{
c
∣
c
≠
3
,
c
∈
[
0
,
9
]
}
max
(
0
,
∣
∣
v
i
∣
∣
−
m
−
)
L=\max{(0,m^{+}-||\mathbf{v}_{3}||)}^{2}+\lambda \sum_{i\in\{c|c\neq 3,c\in[0,9]\}}\max(0,||\mathbf{v}_{i}||-m^{-})
L=max(0,m+−∣∣v3∣∣)2+λi∈{c∣c=3,c∈[0,9]}∑max(0,∣∣vi∣∣−m−)
实例化capsule网络
作者论文中展示了3层的CapsNet来实现MNIST分类,架构如下:
(1)第一层是一个普通的卷积层,卷积核为9×9,输入channel是3,输出channel是256。
(2)第二层为PrimaryCaps层,首先是一个卷积核为9×9,输入channel是256,输出channel是256的卷积,得到一个6×6×256的feature map,并沿着channel dimension分成32份feature maps,因此每一个feature map的维度为6×6×8。在这里我们视为具有6×6×32个capsules,因此每一个capsule的输出向量维度为8.
class PrimaryCapsLayer(nn.Module):
def __init__(self, in_channels, out_channels, kernel_size, stride, cap_dim, num_cap_map):
super(PrimaryCapsLayer, self).__init__()
self.capsule_dim = cap_dim
self.num_cap_map = num_cap_map
self.conv_out = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding=0)
def forward(self, x):
batch_size = x.size(0)
outputs = self.conv_out(x)
map_dim = outputs.size(-1)
outputs = outputs.view(batch_size, self.capsule_dim, self.num_cap_map, map_dim, map_dim) # [bs, 8, 32, 6, 6]
outputs = outputs.view(batch_size, self.capsule_dim, -1).transpose(-1, -2) # [bs, 32*6*6, 8]
outputs = squash(outputs)
return outputs
(3)第三层为DigitCaps层,进行PrimaryCaps层输出向量的加权聚合。
class DigitCapsLayer(nn.Module):
def __init__(self, num_digit_cap, num_prim_cap, in_cap_dim, out_cap_dim, num_iterations):
super(DigitCapsLayer, self).__init__()
self.num_iterations = num_iterations
self.W = nn.Parameter(0.01 * torch.randn(1, num_prim_cap, num_digit_cap, out_cap_dim, in_cap_dim))
# [1, 32*6*6, 10, 16, 8]
def forward(self, x):
batch_size = x.size(0) # [bs, 32*6*6, 8]
u = x[:, :, None, :, None]
u_hat = torch.matmul(self.W, u) # [bs, 32*6*6, 10, 16]
# detach u_hat during routing iterations to prevent gradients from flowing
temp_u_hat = u_hat.detach()
b_ij = torch.zeros(batch_size, u_hat.size(1), u_hat.size(2), 1, 1).cuda()
for i in range(self.num_iterations - 1):
c_ij = F.softmax(b_ij, dim=2)
s_j = (c_ij * temp_u_hat).sum(dim=1, keepdim=True) # [bs, 1, 10, 16, 1]
v = squash(s_j, dim=-2)
# [bs, 1152, 10, 16, 1]T . [bs, 1, 10, 16, 1]
u_produce_v = torch.matmul(temp_u_hat.transpose(-1, -2), v)
b_ij = b_ij + u_produce_v
# [bs, 1152, 10, 1, 1]
c_ij = F.softmax(b_ij, dim=2)
s_j = (c_ij * u_hat).sum(dim=1, keepdim=True) # *u_hat for gradient update
v = squash(s_j, dim=-2)
return v
DigitCaps层一共有10个capsule,每个capsule输出16维的向量,作者分析了每一维的意义,比如某一维代表数字的宽度,几个维度表示variations的组合。
(4)在inference阶段,通过判断DigitCaps层的class-specific向量的模长来判断类别,模越大,表明类别概率越高。