0. Introduction to Swin Transformer
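Swin Transformer points out two differences from NLP. In computer vision, tokens (the basic processing units) do not have a fixed scale, and there are far more of them. To address both problems, Swin Transformer uses a hierarchical (pyramid) architecture and proposes an attention computation whose cost grows linearly with image size. This yields several model variants of different sizes (Swin-T, Swin-S, Swin-B and Swin-L).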
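The "linear complexity" claim can be made concrete with the two estimates given in the Swin Transformer paper (softmax ignored) for an h×w feature map with C channels and window size M: Ω(MSA) = 4hwC² + 2(hw)²C for global attention and Ω(W-MSA) = 4hwC² + 2M²hwC for window attention. The short sketch below simply evaluates both formulas. The helper names `msa_flops` and `wmsa_flops` are hypothetical, and the stage-1 setting (h = w = 56, C = 96, M = 7) is an assumed Swin-T configuration.

```python
def msa_flops(h: int, w: int, c: int) -> int:
    """Global multi-head self-attention on an h x w map with c channels.

    4*h*w*c^2    : Q/K/V and output projections.
    2*(h*w)^2*c  : Q @ K^T and attention @ V over all h*w tokens (quadratic in h*w).
    """
    return 4 * h * w * c * c + 2 * (h * w) ** 2 * c


def wmsa_flops(h: int, w: int, c: int, m: int) -> int:
    """Window MSA: each token attends only to the m*m tokens of its own window."""
    return 4 * h * w * c * c + 2 * m * m * h * w * c


if __name__ == "__main__":
    # Assumed Swin-T stage 1: 224 / 4 = 56 tokens per side, C = 96, window size M = 7.
    h = w = 56
    c, m = 96, 7
    print(f"MSA  : {msa_flops(h, w, c):,}")      # MSA  : 2,003,828,736
    print(f"W-MSA: {wmsa_flops(h, w, c, m):,}")  # W-MSA: 145,108,992
    # Doubling the side length quadruples h*w, and the W-MSA cost grows by exactly 4x.
    print(wmsa_flops(2 * h, 2 * w, c, m) / wmsa_flops(h, w, c, m))  # 4.0
```

Because M is a fixed constant, every term of Ω(W-MSA) is proportional to hw, whereas the (hw)² term dominates global MSA at high resolution.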
Reading the architecture figure:
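(1) It builds the network hierarchically, much like a convolutional neural network: the pyramid of stages, with downsampling between them, enlarges the receptive field. A backbone of this kind makes it easy to build dense prediction tasks such as object detection and instance segmentation on top of it.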
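(2) Swin Transformer introduces Window Multi-Head Self-Attention (W-MSA): the feature map is divided into non-overlapping regions (windows), and multi-head self-attention is computed only inside each window. This reduces computation, but tokens in different windows no longer exchange information, so Shifted Window Multi-Head Self-Attention (SW-MSA) is added as well.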
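A minimal NumPy sketch of how a feature map can be cut into windows and put back together, plus the cyclic shift that SW-MSA applies before partitioning. It mirrors the `window_partition` / `window_reverse` logic of the official repository but uses NumPy instead of PyTorch; `cyclic_shift` is a hypothetical helper name, and the attention mask that SW-MSA additionally needs after the shift is not shown.

```python
import numpy as np


def window_partition(x: np.ndarray, window_size: int) -> np.ndarray:
    """(B, H, W, C) -> (num_windows * B, window_size, window_size, C)."""
    b, h, w, c = x.shape
    x = x.reshape(b, h // window_size, window_size, w // window_size, window_size, c)
    # Reorder to (B, window rows, window cols, in-window row, in-window col, C).
    x = x.transpose(0, 1, 3, 2, 4, 5)
    return x.reshape(-1, window_size, window_size, c)


def window_reverse(windows: np.ndarray, window_size: int, h: int, w: int) -> np.ndarray:
    """Inverse of window_partition: (num_windows * B, ws, ws, C) -> (B, H, W, C)."""
    b = windows.shape[0] // ((h // window_size) * (w // window_size))
    x = windows.reshape(b, h // window_size, w // window_size, window_size, window_size, -1)
    x = x.transpose(0, 1, 3, 2, 4, 5)
    return x.reshape(b, h, w, -1)


def cyclic_shift(x: np.ndarray, shift: int) -> np.ndarray:
    """Roll a (B, H, W, C) map up and to the left by `shift` (the SW-MSA shift)."""
    return np.roll(x, shift=(-shift, -shift), axis=(1, 2))


x = np.arange(8 * 8).reshape(1, 8, 8, 1)        # one 8x8 single-channel map
windows = window_partition(x, 4)                 # four 4x4 windows
print(windows.shape)                             # (4, 4, 4, 1)
restored = window_reverse(windows, 4, 8, 8)      # identical to x

shifted = cyclic_shift(x, 2)                     # shift = window_size // 2
shifted_windows = window_partition(shifted, 4)   # windows now straddle the old borders
```

After attention is computed on `shifted_windows`, the result would be reassembled with `window_reverse` and rolled back by `+shift`, so each output token returns to its original position.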
The key points of Swin Transformer are therefore these four: Patch Merging, W-MSA, SW-MSA, and relative position encoding (a relative position bias).
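For the relative position encoding, each window of Wh×Ww tokens looks up a learnable bias for every query/key pair from a table with (2Wh−1)(2Ww−1) rows (one column per head), and the bias is added to the attention logits before softmax. The NumPy sketch below builds the lookup index with the same arithmetic as the official `WindowAttention` class; the function names and the random table are assumptions for illustration only.

```python
import numpy as np


def relative_position_index(wh: int, ww: int) -> np.ndarray:
    """(N, N) index into the bias table for every (query, key) pair, N = Wh*Ww."""
    coords = np.stack(np.meshgrid(np.arange(wh), np.arange(ww), indexing="ij"))  # (2, Wh, Ww)
    flat = coords.reshape(2, -1)                    # (2, N)
    rel = flat[:, :, None] - flat[:, None, :]       # (2, N, N) relative offsets
    rel = rel.transpose(1, 2, 0).copy()             # (N, N, 2)
    rel[:, :, 0] += wh - 1                          # rows: [-(Wh-1), Wh-1] -> [0, 2Wh-2]
    rel[:, :, 1] += ww - 1                          # cols: [-(Ww-1), Ww-1] -> [0, 2Ww-2]
    rel[:, :, 0] *= 2 * ww - 1                      # flatten the 2-D offset into one index
    return rel.sum(-1)


def relative_position_bias(table: np.ndarray, index: np.ndarray) -> np.ndarray:
    """Gather a (num_heads, N, N) bias from a ((2Wh-1)*(2Ww-1), num_heads) table."""
    n = index.shape[0]
    bias = table[index.reshape(-1)].reshape(n, n, -1)
    return bias.transpose(2, 0, 1)


wh = ww = 7
num_heads = 3
idx = relative_position_index(wh, ww)
table = np.random.default_rng(0).standard_normal(((2 * wh - 1) * (2 * ww - 1), num_heads))
bias = relative_position_bias(table, idx)
print(idx.shape, bias.shape)  # (49, 49) (3, 49, 49)
# Inside attention the bias would be used as: softmax(q @ k.T * scale + bias).
```

Two token pairs with the same relative offset (for example tokens 0→1 and 1→2 in the same row) share one table entry, which is what makes the bias translation-invariant within a window.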
Source code:
https://github.com/microsoft/Swin-Transformer
1. Source Code Walkthrough
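From the Swin Transformer architecture diagram, the data flow is: Patch Embedding -> BasicLayer (SwinTransformerBlock × n -> Patch Merging) -> avgpool -> classification head. In the official code the last BasicLayer has no Patch Merging, and the head is a single linear layer.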
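To make this data flow concrete, the NumPy sketch below traces tensor shapes through an assumed Swin-T-like configuration (224×224 input, patch size 4, C = 96, four stages, 1000 classes). The Swin Transformer blocks are skipped because they keep the shape (B, H·W, C) unchanged. Patch Merging follows the structure of the official code (gather each 2×2 neighbourhood, concatenate to 4C, normalize, project to 2C without bias), but here with a parameter-free LayerNorm and random weights; the final LayerNorm the official model applies before pooling is also omitted. `patch_merging` and `trace_shapes` are hypothetical names.

```python
import numpy as np

rng = np.random.default_rng(0)


def patch_merging(x: np.ndarray, h: int, w: int, reduction: np.ndarray) -> np.ndarray:
    """(B, H*W, C) -> (B, H/2 * W/2, 2C)."""
    b, _, c = x.shape
    x = x.reshape(b, h, w, c)
    x0 = x[:, 0::2, 0::2, :]   # top-left of every 2x2 block
    x1 = x[:, 1::2, 0::2, :]   # bottom-left
    x2 = x[:, 0::2, 1::2, :]   # top-right
    x3 = x[:, 1::2, 1::2, :]   # bottom-right
    x = np.concatenate([x0, x1, x2, x3], axis=-1).reshape(b, -1, 4 * c)
    # Simplification: LayerNorm without learnable scale/shift.
    x = (x - x.mean(-1, keepdims=True)) / np.sqrt(x.var(-1, keepdims=True) + 1e-5)
    return x @ reduction       # reduction: (4C, 2C), no bias


def trace_shapes(img_size=224, patch_size=4, embed_dim=96, num_stages=4,
                 num_classes=1000, batch=1):
    h = w = img_size // patch_size
    c = embed_dim
    x = rng.standard_normal((batch, h * w, c))   # stand-in for the PatchEmbed output
    shapes = [x.shape]
    for i in range(num_stages):
        # Swin Transformer blocks would run here; they do not change the shape.
        if i < num_stages - 1:                   # no Patch Merging after the last stage
            reduction = rng.standard_normal((4 * c, 2 * c)) * 0.02
            x = patch_merging(x, h, w, reduction)
            h, w, c = h // 2, w // 2, 2 * c
            shapes.append(x.shape)
    pooled = x.mean(axis=1)                      # average pooling over tokens -> (B, C)
    head = rng.standard_normal((c, num_classes)) * 0.02
    logits = pooled @ head                       # linear classification head
    shapes += [pooled.shape, logits.shape]
    return shapes


print(trace_shapes())
# [(1, 3136, 96), (1, 784, 192), (1, 196, 384), (1, 49, 768), (1, 768), (1, 1000)]
```

Each Patch Merging halves the spatial resolution and doubles the channel count, which produces the feature pyramid described earlier.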
PatchEmbedding
A figure showing how the data is processed:
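The core of `PatchEmbed` in the code below is `nn.Conv2d` with `kernel_size = stride = patch_size`, followed by `flatten(2).transpose(1, 2)`. Because the kernels do not overlap, this is the same as cutting the image into p×p patches, flattening each patch, and applying one shared linear layer. The NumPy sketch below checks that equivalence against a naive strided convolution; both function names are hypothetical, and the tiny shapes are chosen only for illustration.

```python
import numpy as np


def patch_embed_as_linear(x, weight, bias, p):
    """Patchify + linear projection.

    x: (B, C, H, W); weight: (E, C, p, p); bias: (E,). Returns (B, H/p * W/p, E).
    """
    b, c, h, w = x.shape
    e = weight.shape[0]
    patches = x.reshape(b, c, h // p, p, w // p, p)
    # -> (B, patch row, patch col, C, p, p), then one flat vector per patch
    patches = patches.transpose(0, 2, 4, 1, 3, 5).reshape(b, (h // p) * (w // p), c * p * p)
    return patches @ weight.reshape(e, -1).T + bias


def conv_stride_p_naive(x, weight, bias, p):
    """Reference: stride-p convolution (cross-correlation), then flatten(2).transpose(1, 2)."""
    b, c, h, w = x.shape
    e = weight.shape[0]
    out = np.zeros((b, e, h // p, w // p))
    for i in range(h // p):
        for j in range(w // p):
            window = x[:, :, i * p:(i + 1) * p, j * p:(j + 1) * p]      # (B, C, p, p)
            out[:, :, i, j] = np.einsum("bcxy,ecxy->be", window, weight) + bias
    return out.reshape(b, e, -1).transpose(0, 2, 1)


rng = np.random.default_rng(1)
x = rng.standard_normal((2, 3, 8, 8))        # 2 images, 3 channels, 8x8 pixels
weight = rng.standard_normal((5, 3, 4, 4))   # embed_dim = 5, patch_size = 4
bias = rng.standard_normal(5)
out = patch_embed_as_linear(x, weight, bias, 4)
print(out.shape)                             # (2, 4, 5): B, Ph*Pw, embed_dim
print(np.allclose(out, conv_stride_p_naive(x, weight, bias, 4)))  # True
```

With the default arguments (224×224 input, patch size 4, embed_dim 96) the same reshaping gives (B, 3136, 96), matching the `# B Ph*Pw C` comment in `forward`.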
import torch.nn as nn
from timm.models.layers import to_2tuple


class PatchEmbed(nn.Module):
    r""" Image to Patch Embedding
    Args:
        img_size (int): Image size. Default: 224.
        patch_size (int): Patch token size. Default: 4.
        in_chans (int): Number of input image channels. Default: 3.
        embed_dim (int): Number of linear projection output channels. Default: 96.
        norm_layer (nn.Module, optional): Normalization layer. Default: None
    """

    def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        # number of patches along each axis, e.g. 224 // 4 = 56
        patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
        self.img_size = img_size
        self.patch_size = patch_size
        self.patches_resolution = patches_resolution
        self.num_patches = patches_resolution[0] * patches_resolution[1]
        self.in_chans = in_chans
        self.embed_dim = embed_dim
        # kernel_size == stride == patch_size: non-overlapping patches, each
        # linearly projected to embed_dim channels
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
        if norm_layer is not None:
            self.norm = norm_layer(embed_dim)
        else:
            self.norm = None

    def forward(self, x):
        B, C, H, W = x.shape
        # FIXME look at relaxing size constraints
        assert H == self.img_size[0] and W == self.img_size[1], \
            f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
        # (B, embed_dim, Ph, Pw) -> (B, embed_dim, Ph*Pw) -> (B, Ph*Pw, embed_dim)
        x = self.proj(x).flatten(2).transpose(1, 2)  # B Ph*Pw C
        if self.norm is not None:
            x = self.norm(x)
        return x

    def flops(self):
        Ho, Wo = self.patches_resolution
        flops = Ho * Wo * self.embed_dim * self.in_chans * (self