参考文献:Chen, Z., Bei, Y. & Rudin, C. Concept whitening for interpretable image recognition. Nat Mach Intell 2, 772–782 (2020). https://doi.org/10.1038/s42256-020-00265-z
项目代码链接:https://github.com/zhiCHEN96/ConceptWhitening
1. 概括
这篇文章提出了一个叫做Concept Whitening(CW)的模块,翻译作“概念白化”。CW模块能够将深度神经网络的隐空间解耦合,并且赋予每个维度一个人为定义的“概念”。加入了CW模块的深度神经网络在性能上没有很大的差别,但是具有了更好的可解释性。
2. 概念 Concept
所谓的概念,可以理解成原始数据当中提取出来的初级特征。例如,对于一个图片场景分类的任务,终极的特征是图片当中的场景,而初级特征可以是图片当中所包含的物品、人物等。理论上来说,根据一张图片中出现的物品种类,是可以推测出图片当中的场景的。我们把这种初级特征,就叫做“概念”。本篇文献的一个假设就是,经过学习后的深度神经,可以提取出原始数据中的“概念”,并根据所提取出的“概念”进行分类任务。也就是如下图所示:
3.概念在隐空间中的表征
假设我们有这样一个深度神经网络函数fff,它在X=Rn\mathcal{X}=\mathbb{R}^nX=Rn上有定义,它的值分布在Y=Rm\mathcal{Y}=\mathbb{R}^mY=Rm。现在我们把这个网络拆分成浅层和深层部分:f=g∘Φf=g\circ\Phif=g∘Φ,其中Φ\PhiΦ是浅层部分,ggg是深层部分。那么我们将会得到一个新的空间:Z=Φ(X)\mathcal{Z}=\Phi(\mathcal{X})Z=Φ(X)(严格来说,应该是Φ(X)∈Z\Phi(\mathcal{X})\in\mathcal{Z}Φ(X)∈Z,因为Φ(X)\Phi(\mathcal{X})Φ(X)不一定能够铺满整个Z\mathcal{Z}Z空间)。这个空间是X\mathcal{X}X变换到Y\mathcal{Y}Y的一个中间态,我们称其为隐空间。再接下来的讨论里面,我们假设dim(Z)=ddim(\mathcal{Z})=ddim(Z)=d.
现在有这样一个分类器训练问题,即给定一批样本D={xi,yi}i=1N\mathcal{D}=\{x_i,y_i\}_{i=1}^ND={xi,yi}i=1N,其中xi∈Xx_i\in\mathcal{X}xi∈X,而yiy_iyi是表示类别序号的整数,一共有MMM种类别。要求通过这些样本,将神经网络函数fθf_\thetafθ训练成为一个分类器。我们首先假设,所有的xix_ixi,都是X\mathcal{X}X空间上的概率密度分布p(x)p(x)p(x)的采样。而对于所有的yi=j∈[1,M]y_i=j\in[1,M]yi=j∈[1,M],xix_ixi都是X\mathcal{X}X空间上的类条件概率密度分布pcj(x)=p(x∣cj)p_{c_j}(x)=p(x|c_j)pcj(x)=p(x∣cj)的采样。
现在我们人为的规定出kkk个概念(k<dk<dk<d),分别是c1,...,ckc_1,...,c_kc1,...,ck. 我们按照数据是否含有对应的概念,从集合{xi}i=1N\{x_i\}_{i=1}^N{xi}i=1N里抽取元素,构造出kkk个子集X1,X2,...,XkX_1,X_2,...,X_kX1,X2,...,Xk. 其中对于任意xix_ixi,若xix_ixi含有概念cjc_jcj,则xi∈Xjx_i\in X_jxi∈Xj,否则xi∉Xjx_i\notin X_jxi∈/Xj.
在以上这些前提假设之下,我们希望经过一定算法训练之后的神经网络fff,它的隐空间Z\mathcal{Z}Z要具有对概念的表征能力。也就是要满足这样一个假设:
假设3.1:隐空间Z\mathcal{Z}Z上存在一组基E={e1,e2,...,ed}E=\{e_1,e_2,..., e_d\}E={e1,e2,...,ed},对于每一个zi=Φ(xi),xi∈Dz_i=\Phi(x_i), x_i\in \mathcal{D}zi=Φ(xi),xi∈D,都可以被这组线性表出为:zi=λ1e1+...+λkek+...+λdedz_i=\lambda_1 e_1+...+\lambda_k e_k+...+\lambda_d e_dzi=λ1e1+...+λkek+...+λded. 并且满足:
(1) 若xi∈Xjx_i\in X_jxi∈Xj,则λj\lambda_jλj相对较大,否则λj\lambda_jλj相对较小。
(2) 在xi∼p(x)x_i\sim p(x)xi∼p(x)的前提下,∀j1,j2∈[1,k]且j1≠j2,有Cov(λj1,λj2)=0\forall j_1,j_2\in[1,k]且j_1\neq j_2, 有\text{Cov}(\lambda_{j_1},\lambda_{j_2})=0∀j1,j2∈[1,k]且j1=j2,有Cov(λj1,λj2)=0,其中Cov\text{Cov}Cov表示变量之间的协方差。
在这个假设里面,第一个条件的“相对较大”和“相对较小”是比较主观的判断,并没有一个客观标准。大致来说就是,对于xi∈Xjx_i\in X_jxi∈Xj,它的λj\lambda_jλj相对而言,要比xk∉Xjx_k\notin X_jxk∈/Xj的λj\lambda_jλj要更大。
4.概念白化
上面我们虽然已经用假设3.1,较为严谨地定义了“隐空间能够表征概念”的具体含义,但是这个假设确实不容易验证的。一方面是“相对较大”和“相对较小”这样的表述存在一定的模糊性,另一方面是因为这样一组基EEE,在Z\mathcal{Z}Z的坐标系下面,并不与坐标轴重合,这就导致要去寻找EEE是较为困难的。
然而,我们注意到,以EEE作为坐标轴所构建出来的新空间V(E)V(E)V(E),只是隐空间Z\mathcal{Z}Z的另一种表示方法。假设同一个向量α\alphaα在空间V(E)V(E)V(E)和Z\mathcal{Z}Z下分别具有数组坐标vvv和zzz,那么一定存在可逆矩阵,使得v=Azv = Azv=Az. 所以我们自然而然地会想到,通过一定算法来找到这个矩阵AAA。倘若能够找到这个变换矩阵,那么我们就可以将z∈Zz\in\mathcal{Z}z∈Z变换到v∈V(E)v\in V(E)v∈V(E),使得数据是否含有概念cjc_jcj直观地表现在第jjj个神经元的响应值上面。
下面介绍一种称为白化的算法。
4.1.白化
还是和前面一样,我们把隐空间Z\mathcal{Z}Z上面,样本xix_ixi变换后得到的特征zi=Φ(xi)z_i=\Phi(x_i)zi=Φ(xi)视为是概率分布p(z)p(z)p(z)的采样。这个概率分布具有协方差矩阵Σ∈Rn×n\Sigma\in\mathbb{R}^{n\times n}Σ∈Rn×n,根据协方差的计算规则,Σ\SigmaΣ是实对称矩阵。
假设我们得到的是一批样本特征的矩阵:z=(z1,z2,...,zb)∈Rd×b\mathbf{z} = (z_1,z_2,...,z_b)\in\mathbb{R}^{d\times b}z=(z1,z2,...,zb)∈Rd×b,那么我们可以估算出协方差矩阵Σ=1b⋅z⋅zT\Sigma=\frac{1}{b}\cdot z\cdot z^TΣ=b1⋅z⋅zT.
我们尝试找到一个矩阵WWW,满足WT⋅W=Σ−1W^T\cdot W=\Sigma^{-1}WT⋅W=Σ−1。这个矩阵WWW就被称为白化矩阵。这一步可以通过很多方法实现,其中最常用的方法就是ZCAZCAZCA白化,此处不赘述,可以查阅相关资料了解。
找到这个矩阵之后,我们对特征进行变换:z^=W(z−μ)\hat{z}=W(z-\mu)z^=W(z−μ),其中μ\muμ是zzz的均值。这一步就被称为白化操作(Whitening)。
白化操作的结果是,变换后的随机变量z^\hat{z}z^,其协方差矩阵变成了单位矩阵:Σ^=WΣWT=Id×d\hat{\Sigma}=W\Sigma W^T=I_{d\times d}Σ^=WΣWT=Id×d
并且白化矩阵还有一个特点,那就是白化矩阵左乘以任何的正交矩阵,其结果依旧是白化矩阵,证明如下:
(QW)T⋅QW=WTQTQW=WTIW=Σ−1(QW)^T\cdot QW=W^TQ^TQW=W^TIW=\Sigma^{-1}(QW)T⋅QW=WTQTQW=WTIW=Σ−1
4.2.目标坐标系
由于经过白化操作后的数据,其协方差矩阵是单位矩阵。而前面我们提到,在V(E)V(E)V(E)空间下,有Cov(λj1,λj2)=0,∀j1,j2∈[1,k]且j1≠j2\text{Cov}(\lambda_{j_1},\lambda_{j_2})=0, \forall j_1,j_2\in [1,k]且j_1\neq j_2Cov(λj1,λj2)=0,∀j1,j2∈[1,k]且j1=j2。也就是说,数据在V(E)V(E)V(E)空间下的协方差矩阵,其左上角的k×kk \times kk×k子块是对角矩阵。那么我们就思考,能否通过一个白化操作,使得隐空间的原坐标系变换到EEE呢?可惜并不行,因为白化操作所得到的坐标系下,数据的协方差矩阵一定是严格的单位矩阵。然而假设3.1当中,V(E)V(E)V(E)下数据的协方差矩阵,只能够写成这样的形式:
Σ=[ΛAATS]\Sigma = \begin{bmatrix} \Lambda &A\\ A^T &S \end{bmatrix}Σ=[ΛATAS]
其中Λ=diag(λ1,λ2,...,λk)\Lambda = diag(\lambda_1,\lambda_2,...,\lambda_k)Λ=diag(λ1,λ2,...,λk),A∈Rk×(d−k)A\in \mathbb{R}^{k\times(d-k)}A∈Rk×(d−k)是一个普通的的矩阵,S∈R(d−k)×(d−k)S\in \mathbb{R}^{(d-k)\times(d-k)}S∈R(d−k)×(d−k)是一个实对称矩阵。
那么我们退一步想,在所有符合假设3.1的基{E}\{E\}{E}里面,是否存在一组基E∗E^*E∗,在它构成的坐标系下,数据的协方差矩阵就是单位矩阵呢?答案是肯定的,而且这组基可以从任意一个满足假设3.1的基EEE出发,推导出来。下面就介绍这个推导的方法:
(1) 想要保证由EEE变换得到的新坐标系E∗E^*E∗,依旧满足假设3.1的条件(1),最好使得EEE的前kkk个基向量方向不发生改变,仅仅改变其模长。所以我们想到的第一个变换,就是将EEE的前kkk个向量长度进行一定改变,使得在这些方向上,数据投影的方差为单位1. 这个变换用矩阵来表示,就是P1=[Λ−1200I]P_1 = \begin{bmatrix}\Lambda^{-\frac{1}{2}} &0\\ 0 &I\end{bmatrix}P1=[Λ−2100I],其中Λ−12=diag(1λ1,...,1λk)\Lambda^{-\frac{1}{2}}=diag(\frac{1}{\sqrt{\lambda_1}}, ... , \frac{1}{\sqrt{\lambda_k}})Λ−21=diag(λ11,...,λk1). 经过这一步变换,协方差变为P1⋅[ΛAATS]⋅P1T=[IΛ−12AATΛ−12S]P_1\cdot\begin{bmatrix} \Lambda &A\\ A^T &S \end{bmatrix}\cdot P_1^T=\begin{bmatrix}I &\Lambda^{-\frac{1}{2}}A\\ A^T\Lambda^{-\frac{1}{2}} &S\end{bmatrix}P1⋅[ΛATAS]⋅P1T=[IATΛ−21Λ−21AS]
(2) 对协方差矩阵进行成对的线性变换,用左上角的单位阵,将非对角矩阵块消去为0. 用矩阵的形式表示就是P2=[I0−ATΛ−12I]P_2=\begin{bmatrix}I &0\\ -A^T\Lambda^{-\frac{1}{2}} &I\end{bmatrix}P2=[I−ATΛ−210I],经过变换后的协方差矩阵是P2[IΛ−12AATΛ−12S]P2T=[I00S−ATΛ−1A]P_2\begin{bmatrix}I &\Lambda^{-\frac{1}{2}}A\\ A^T\Lambda^{-\frac{1}{2}} &S\end{bmatrix}P_2^T=\begin{bmatrix}I &0\\ 0 &S-A^T\Lambda^{-1}A\end{bmatrix}P2[IATΛ−21Λ−21AS]P2T=[I00S−ATΛ−1A].
(3) 找到实对称矩阵S−ATΛ−1AS-A^T\Lambda^{-1}AS−ATΛ−1A的白化矩阵DDD(一定存在,而且可以通过特征分解找到),令P3=[I00D]P_3=\begin{bmatrix}I &0\\ 0 &D\end{bmatrix}P3=[I00D],经过变换后的协方差矩阵是P3[I00S−ATΛ−1A]P3T=Id×dP_3\begin{bmatrix}I &0\\ 0 &S-A^T\Lambda^{-1}A\end{bmatrix}P_3^T=I_{d\times d}P3[I00S−ATΛ−1A]P3T=Id×d
所以,令P=P3P2P1P=P_3P_2P_1P=P3P2P1,则E∗=E⋅P−1E^*=E\cdot P^{-1}E∗=E⋅P−1就是我们要找的目标坐标系。这个坐标系,可以通过白化操作,由原坐标系变换得到。将隐空间的原坐标系,变换到能够用第iii个神经元激活值来表示数据包含或不包含概念cic_ici的坐标系E∗E^*E∗,所使用到的白化操作,在这里被称为概念白化(Concetp Whitening)。
5.CW模块
参考文献创造性地提出了一种深度神经网络模块:CW层。CW层能够较好的拟合4.2部分我们所期望得到的概念白化操作。下面我们一步步地来推导出CW层的训练算法。
5.1.两阶段的概念白化
虽然直接求出概念白化矩阵是困难的,但是求出一个白化矩阵却是简单的。想要找到隐空间Z\mathcal{Z}Z上的一个白化矩阵,只需要根据所有的样本{zi}i=1N\{z_i\}_{i=1}^N{zi}i=1N求出协方差矩阵Σ\SigmaΣ,然后用ZCAZCAZCA白化操作,就能够得到一个白化矩阵WWW. 那么剩下的事情,就是找到一个旋转矩阵QTQ^TQT,使得白化矩阵WWW被旋转到目标的概念白化矩阵。在这种情况下,一个概念白化操作,可以表示为旋转矩阵左乘以一个白化矩阵:QTWQ^TWQTW。也就是我们把概念白化操作当成了两个阶段:白化和旋转。假定我们已经得到了白化矩阵WWW,我们可以先构造一个衡量QQQ优劣的损失函数:
(1)令Ψ(z;W,μ)=W(z−μ)\Psi(z;W,\mu)=W(z-\mu)Ψ(z;W,μ)=W(z−μ)
(2)F=∑j=1k∑xi(cj)∈XjqjTΨ(Φ(xi(cj);θ);W,μ)F=\sum_{j=1}^k\sum_{x_i^{(c_j)}\in X_j}q_j^T\Psi(\Phi(x_i^{(c_j)};\theta);W,\mu)F=∑j=1k∑xi(cj)∈XjqjTΨ(Φ(xi(cj);θ);W,μ),其中qjq_jqj指的是QQQ的第jjj列,而qjTΨ(Φ(xi(cj);θ);W,μ)q_j^T\Psi(\Phi(x_i^{(c_j)};\theta);W,\mu)qjTΨ(Φ(xi(cj);θ);W,μ)其实就是变换后的向量QTW⋅Φ(xi(cj))Q^TW\cdot \Phi(x_i^{(c_j)})QTW⋅Φ(xi(cj))的第jjj个元素值。FFF的值越大,说明QTWQ^TWQTW越符合要求。
现在我们要求解最优化问题:maxq1,...,qjF(Q)max_{q_1, ... ,q_j} F(Q)maxq1,...,qjF(Q)s.t.QTQ=Ids.t. Q^TQ=I_ds.t.QTQ=Id
这就是理论上寻找概念白化矩阵的算法。
5.2.基于小批量数据的白化矩阵计算
虽然在5.1里面,我们说可以对所有训练数据求解出ZCAZCAZCA白化矩阵,但是这样做会使得算法耗时过长。一般在深度神经网络的训练当中,我们习惯于将数据分成若干个小批次(batch)。以一个批次为单位,进行参数的更新。那么我们可不可以对于每一个批次的数据,都单独计算它们的白化矩阵WWW和均值μ\muμ,作为概念白化层所使用的参数呢?也不可以,因为白化层后面所对接的深层网络需要根据白化层的输出进行训练。如果每一个批次的数据都单独使用一个白化矩阵和均值,容易使得网络的随机性过大,后续的深层网络难以收敛。所以CW层使用的方法,是指数滑动平均值(Exponential Moving Average)。用伪代码的形式来书写,就是这样的:
5.3.更新旋转矩阵
5.1当中给出了旋转矩阵最优化问题的形式化表达:maxq1,...,qjF(Q)max_{q_1, ... ,q_j} F(Q)maxq1,...,qjF(Q)s.t.QTQ=Ids.t. Q^TQ=I_ds.t.QTQ=Id
但是并没有给出解决该最优化问题的算法。
实际上,CW模块当中,也是通过随机梯度下降的思路来更新旋转矩阵的。但是由于需要保证旋转矩阵的正交性,所以采用了一种称为在黎曼流形上的梯度下降(Gradient Methods on the Stiefeld Manifold)来进行更新。具体算法如下:
事实上,原文里面,旋转矩阵更新时候的学习率η\etaη并不是一个固定的参数,而是会在每一次更新之后,获取一个新的值。采用更新的η\etaη可以加速训练过程,但是这里为了简单起见,暂时认为η\etaη是固定参数。
5.4.CW模块的完整结构
我们用伪代码来表示CW模块的所有成员参数,还有各参数的初始值,如下:
带有CW模块的深度神经网络,和一般的深度网络一样,可以通过SGD算法进行训练。CW模块对输入数据的前向操作是一个线性变换,所以其反向算法是很容易得出的。CW模块本身的参数并不参与整个网络的SGD调参,其白化矩阵是通过指数滑动平均值更新的,而旋转矩阵则是会在固定的时刻(如每20个周期的训练后),从概念子集中抽取出采样,根据其本身的更新算法进行更新。整个网络的训练算法如下:
6.结束
本篇博客分析了参考文献里面,CW模块的具体构成、更新算法,以及CW模块背后蕴含的数理思想。文献的结果部分会在后续博客当中进行分析。