1. Patch Partition & Patch Embedding
1.1 Patch Partition
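First, the input image of shape [H, W, 3] is split into non-overlapping patches of size [4, 4, 3]. For example, a 224×224 RGB image ([3, 224, 224] in PyTorch's channels-first layout) is divided into (224/4) × (224/4) = 56 × 56 = 3136 patches. Flattening each patch gives 4 × 4 × 3 = 48 values, so the image becomes a tensor of shape [224/4, 224/4, 4*4*3], i.e. [56, 56, 48].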
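The partition step is only a reshape. Below is a minimal NumPy sketch on a channels-last [H, W, 3] array. The function name `patch_partition` is hypothetical. The order of values inside each 48-d vector (row, column, channel) is a choice made here and is not necessarily the order any particular implementation uses.

```python
import numpy as np


def patch_partition(img: np.ndarray, p: int = 4) -> np.ndarray:
    """Split an [H, W, C] image into non-overlapping p x p patches.

    Returns an array of shape [H/p, W/p, p*p*C]; H and W must be multiples of p.
    """
    H, W, C = img.shape
    assert H % p == 0 and W % p == 0, "H and W must be multiples of the patch size"
    # [H, W, C] -> [H/p, p, W/p, p, C]: split each spatial axis into (patch index, offset in patch)
    x = img.reshape(H // p, p, W // p, p, C)
    # -> [H/p, W/p, p, p, C]: move the two patch-index axes to the front
    x = x.transpose(0, 2, 1, 3, 4)
    # -> [H/p, W/p, p*p*C]: flatten each patch into one feature vector
    return x.reshape(H // p, W // p, p * p * C)


img = np.random.rand(224, 224, 3).astype(np.float32)
patches = patch_partition(img)
print(patches.shape)                         # (56, 56, 48)
print(patches.shape[0] * patches.shape[1])   # 3136
```

Each entry `patches[i, j]` is just the flattened pixel block `img[4*i:4*i+4, 4*j:4*j+4, :]`.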
1.2 Linear Embedding
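Linear Embedding then projects each patch's 48-dimensional feature vector to C dimensions (C = 96 for Swin-T and Swin-S, the default embed_dim below). In the implementation below, Patch Partition and Linear Embedding are fused into a single nn.Conv2d whose kernel_size and stride both equal patch_size, followed by an optional normalization layer.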
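The Conv2d can replace "partition + Linear" because a convolution with kernel_size = stride = p touches each p × p patch exactly once. The NumPy check below shows that the two computations agree. The random weights and the channels-last kernel layout [p, p, C_in, C] are assumptions made for this illustration. PyTorch itself stores Conv2d weights as [out_channels, in_channels, kH, kW].

```python
import numpy as np

rng = np.random.default_rng(0)
p, C_in, C = 4, 3, 96
img = rng.standard_normal((224, 224, C_in))

# Hypothetical conv kernel with kernel_size = stride = p, stored channels-last: [p, p, C_in, C]
kernel = rng.standard_normal((p, p, C_in, C))
bias = rng.standard_normal(C)

# View the image as a grid of patches: [56, p, 56, p, C_in]
grid = img.reshape(224 // p, p, 224 // p, p, C_in)

# (a) Strided convolution: each output position sums over exactly one p x p x C_in window
conv_out = np.einsum("ipjqc,pqco->ijo", grid, kernel) + bias

# (b) Patch Partition + Linear Embedding: flatten each patch to 48 values, multiply by a 48 x 96 matrix
patches = grid.transpose(0, 2, 1, 3, 4).reshape(224 // p, 224 // p, p * p * C_in)
linear_out = patches @ kernel.reshape(p * p * C_in, C) + bias

print(conv_out.shape)                     # (56, 56, 96)
print(np.allclose(conv_out, linear_out))  # True
```

Using the strided convolution avoids materializing the [56, 56, 48] patch tensor explicitly.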
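```python
import torch.nn as nn
import torch.nn.functional as F


class PatchEmbed(nn.Module):
    """
    Split an image into non-overlapping patches and linearly embed each patch
    (Patch Partition + Linear Embedding fused into one strided Conv2d).
    """
    def __init__(self, patch_size=4, in_c=3, embed_dim=96, norm_layer=None):
        super(PatchEmbed, self).__init__()
        patch_size = (patch_size, patch_size)
        self.patch_size = patch_size
        self.in_chans = in_c
        self.embed_dim = embed_dim
        # kernel_size == stride == patch_size: each output position covers exactly one patch
        self.proj = nn.Conv2d(in_c, embed_dim, kernel_size=patch_size, stride=patch_size)
        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()

    def forward(self, x):
        _, _, H, W = x.shape
        # Pad only when H or W is not a multiple of the patch size;
        # the outer "% patch" keeps an already-aligned side at 0 instead of adding a full patch
        pad_h = (self.patch_size[0] - H % self.patch_size[0]) % self.patch_size[0]  # bottom padding along height
        pad_w = (self.patch_size[1] - W % self.patch_size[1]) % self.patch_size[1]  # right padding along width
        if pad_h or pad_w:
            # F.pad order for the last three dims: (W_left, W_right, H_top, H_bottom, C_front, C_back)
            x = F.pad(x, (0, pad_w, 0, pad_h, 0, 0))
        # [B, in_c, H, W] -> [B, embed_dim, H/patch, W/patch]
        x = self.proj(x)
        _, _, H, W = x.shape
        # flatten: [B, C, H, W] -> [B, C, H*W]; transpose: -> [B, H*W, C]
        x = x.flatten(2).transpose(1, 2)
        x = self.norm(x)
        return x, H, W
```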
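The padding amount is easy to get wrong. The expression `patch_size - n % patch_size` equals a full `patch_size` when n is already a multiple of the patch size. So if only one of H and W is misaligned, the aligned side would still receive one extra, unnecessary row or column of patches. The pure-Python check below compares that expression with `(patch_size - n % patch_size) % patch_size`. The helper names `pad_amounts` and `num_patches` are hypothetical and exist only for this illustration.

```python
def pad_amounts(H: int, W: int, p: int = 4):
    """Return the (naive, corrected) (pad_h, pad_w) pairs for an H x W input and patch size p."""
    # Naive form: p - n % p yields p (a whole extra patch) when n is already divisible by p
    naive = (p - H % p, p - W % p)
    # Corrected form: the outer "% p" maps that case to 0
    fixed = ((p - H % p) % p, (p - W % p) % p)
    return naive, fixed


def num_patches(H: int, W: int, pad_h: int, pad_w: int, p: int = 4):
    """Number of patches along (height, width) after padding."""
    return ((H + pad_h) // p, (W + pad_w) // p)


naive, fixed = pad_amounts(225, 224)
print(naive, fixed)                                                    # (3, 4) (3, 0)
print(num_patches(225, 224, *naive), num_patches(225, 224, *fixed))   # (57, 57) (57, 56)
```

With the corrected form, the patch grid is always ceil(H/p) × ceil(W/p), here 57 × 56.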