【AI】Swin Transformer源码解析

0. Swin Transformer简介

Swin Transformer指出,CV中的Token(处理单元)的大小不固定,并且数量相较于NLP要多。为解决这两个问题,Swin Transformer不仅使用了分层结构(金字塔结构),同时还提出了一种线性复杂度的Attention计算,最终得到了几种不同大小的结构。
在这里插入图片描述
看图说话:
(1)使用了类似卷积神经网络中的层次化构建方法,通过金字塔分层结构和下采样拓宽了感受野,这样的backbone有助于在此基础上构建目标检测,实例分割等任务
(2)在Swin Transformer中使用了Windows Multi-Head Self-Attention(W-MSA)的概念,将特征图划分成了多个不相交的区域(Window),Multi-Head Self-Attention只在每个窗口(Window)内进行。这样的结构可以减少计算量,但也会导致不同window中的结果缺少关联性,所以又添加了滑动窗口的Windows Multi-Head Self-Attention(W-MSA)

所以Swin Transform 主要的要点就是Patch Merging, W-MSA, SW-MSA, 相对位置编码这四个方面。

Source code:

https://github.com/microsoft/Swin-Transformer

1.源码解析

由Swin Tranformer的网络结构图我们可以看出来数据处理流程:Patch Embeding -> Basic Layer (SwinTransBlock * n -> Patch Merging) -> avgpool -> mlp head

PatchEmbedding

一张图展示数据处理的方式:
在这里插入图片描述

class PatchEmbed(nn.Module):
    r""" Image to Patch Embedding

    Args:
        img_size (int): Image size.  Default: 224.
        patch_size (int): Patch token size. Default: 4.
        in_chans (int): Number of input image channels. Default: 3.
        embed_dim (int): Number of linear projection output channels. Default: 96.
        norm_layer (nn.Module, optional): Normalization layer. Default: None
    """

    def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
        self.img_size = img_size
        self.patch_size = patch_size
        self.patches_resolution = patches_resolution
        self.num_patches = patches_resolution[0] * patches_resolution[1]

        self.in_chans = in_chans
        self.embed_dim = embed_dim

        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
        if norm_layer is not None:
            self.norm = norm_layer(embed_dim)
        else:
            self.norm = None

    def forward(self, x):
        B, C, H, W = x.shape
        # FIXME look at relaxing size constraints
        assert H == self.img_size[0] and W == self.img_size[1], \
            f"Input image size ({
     
     H}*{
     
     W}) doesn't match model ({
     
     self.img_size[0]}*{
     
     self.img_size[1]})."
        x = self.proj(x).flatten(2).transpose(1, 2)  # B Ph*Pw C
        if self.norm is not None:
            x = self.norm(x)
        return x

    def flops(self):
        Ho, Wo = self.patches_resolution
        flops = Ho * Wo * self.embed_dim * self.in_chans * (self
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值