当天空一只鸟飞过去的时候,往往注意力会追随者鸟儿,天空在视觉系统中,自然成为了一个背景信息。
计算机视觉中的注意力机制(attention)的基本思想就是想让系统学会注意力——能够忽略无关信息而关注重点信息。在深度学习发展的今天,搭建能够具备注意力机制的神经网络则开始显得更加重要,一方面是这种神经网络能够自主学习注意力机制,另一方面是注意力机制能够反过来帮助我们去理解神经网络看到的世界。近几年来,深度学习与视觉注意力机制结合的研究工作,大多数是集中于使用掩码(mask)来形成注意力机制。掩码的原理在于通过另一层新的权重,将图片数据中关键的特征标识出来,通过学习训练,让深度学习网络学到每一张新图片中需要关注的区域,也就是形成了注意力。这种思想,进而演化成两种不同类型的注意力,一种是软注意力机制,另一种是强注意力机制。软注意力的关键点在于,这种注意力更加关注区域或者通道,而且软注意力是确定性注意力,学习完成后直接可以通过网络生成,最关键的地方是软注意力是可微的。强注意力更关注点,也就是图像中的每个点都有可能延申出注意力,同时强注意力是一个随机的预测过程,更强调动态变化,当然,最关键的是强注意力是一个不可微的注意力。注意力域主要有三种:空间域(spatial domain),通道域(channel domain),混合域(mixed domain)。
1.软注意力的注意力域 (soft attention)
1.1 空间域(Spatial domain)
设计思路:通过注意力机制,将原始图片的空间信息变换到另一个空间中并保留了关键信息。因为卷积神经网络中的池化层直接用最大池化或者平均池化的方法,将图片信息压缩,减少运算量提升准确率。但是这样的池化方法太过于暴力,直接将信息合并会导致关键信息无法识别出来,所以提出了一个叫空间转换器(spatial transformer)的模块,将图片中的空间域信息做对应的空间变换,从而将关键信息提取出来。
比如这个直观的实验图:a列是原始图片信息,其中第一个手写数字7没有做任何变换,第二个手写数字5,做了一定的旋转变化,而第三个手写数字6,加上了一些噪声信号;b列中的彩色边框是学习到的spatial transformer的框盒(bounding box),每一个框盒其实就是对应图片学习出来的一个spatial transformer;c列中是通过spatial transformer转换之后的特征图,可以看出7的关键区域被选择出来,5被旋转成为了正向的图片,6的噪声信息没有被识别进入。最终可以通过这些转换后的特征图来预测出d列中手写数字的数值。
spatial transformer其实就是注意力机制的实现,因为训练的空间转换器能够找出图片信息中需要被关注的区域,同时这个转换器又能够具有旋转、缩放变换的功能,这样图片的局部重要信息能够通过变换而被框盒提取出来。
模型结构如下:
这是空间变换网络(spatial transformer network)中最重要的空间变换模块,这个模块可以作为新的层直接加入到原有的网络结构,比如ResNet中。来仔细研究这个模型的输入:。神经网络训练中使用的数据类型都是张量(tensor),H是上一层tensor的高度(height),W是上一层tensor的宽度(width),而C代表tensor的通道(channel),比如图片基本的三通道(RGB),或者是经过卷积层(convolutional layer)之后,不同卷积核(kernel)都会产生不同的通道信息。之后这个输入进入两条路线,一条路线是信息进入定位网络(localisation net),另一条路线是原始信号直接进入采样层(sampler)。其中定位网络会学习到一组参数θ,而这组参数就能够作为网格生成器(grid generator)的参数,生成一个采样信号,这个采样信号其实是一个变换矩阵Tθ(G),与原始图片相乘之后,可以得到变换之后的矩阵V。
通过这张转换图片,可以看出空间转换器中产生的采样矩阵是能够将原图中关键的信号提取出来,(a)中的采样矩阵是单位矩阵,不做任何变换,(b)中的矩阵是可以产生缩放旋转变换的采样矩阵。
最右边式子左边的θ矩阵就是对应的采样矩阵。
STN 主要可以分为三个部分:
- Localisation net:是一个自己定义的网络,它输入U,输出变化参数ΘΘ,这个参数用来映射U和V的坐标关系
- Grid generator:根据V中的坐标点和变化参数ΘΘ,计算出U中的坐标点。这里是因为V的大小是自己先定义好的,当然可以得到V的所有坐标点,而填充V中每个坐标点的像素值的时候,要从U中去取,所以根据V中每个坐标点和变化参数ΘΘ进行运算,得到一个坐标。在sampler中就是根据这个坐标去U中找到像素值,这样子来填充V
- Sampler:要做的是填充V,根据Grid generator得到的一系列坐标和原图U(因为像素值要从U中取)来填充,因为计算出来的坐标可能为小数,要用另外的方法来填充,比如双线性插值。