目录
1. Patch Partition & Patch Embedding
1.1 Patch Partition
2. State
2.1 Patch Merging
3. Transformer Block
3.1 window
3.2 Cyclic-shift
1. Patch Partition & Patch Embedding
1.1 Patch Partition
首先将输入的图像 [H,W,3] ,切割成patches ,每个Patch大小是[4,4,3],比如一张[3,224,224]大小的图像,会被分成 224/4 * 224/4 = 3136个patch 。这样图像的维度变成 [224/4,224/4,4*4*3] 即[56,56,48]。
1.2 Linear Embedding
Linear Embedding 将特征维度从48映射到C(=96) 。
class PatchEmbed(nn.Module) :
"""
将image split into no-overlapping patch
"""
def __init__(self,patch_size=4,in_c=3,embed_dim=96,norm_layer=None) :
super(PatchEmbed,self).__init__()
patch_size = (patch_size ,patch_size)
self.patch_size = patch_size
self.in_chans = in_c
self.embed_dim = embed_dim
self.proj = nn.Conv2d(in_c,embed_dim,kernel_size=patch_size,stride=patch_size)
self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
def forward(self,x) :
_,_,H,W = x.shape
pad_input = (H % self.patch_size[0] !=0 ) or (W % self.patch_size[1] !=0)
if pad_input :
# 需要 padding
x = F.pad(x, (0, self.patch_size[1] - W % self.patch_size[1], # 表示宽度方向右侧填充数
0, self.patch_size[0] - H % self.patch_size[0], # 表示高度方向底部填充数
0, 0))
x = self.proj(x)
_,_,H,W = x.shape
# H: 224//4 = 56
# W: 224//4 = 56
# flatten : [B,C,H,W] -> [B,C,H*W]
# transpose : [B,C,HW] -> [B,HW,C]
x = x.flatten(2).transpose(1,2)
x = self.norm(x)
return x,H,W
2. State
2.1 Patch Merging
在卷积神经网络中,通过池化层可以获得多尺度的特征信息。而Swin Transformer则是通过 合并相邻的patch实现。
- 第一个过程是 x = x.view(B,H,W,C) ,将 [B,H*W,C] -> [B,H,W,C]
- 第二个过程是采样,得到4个x0,x1,x2,x3 4块patches, 每个Patch 尺寸变成原来的一半
x0 = x[:, 0::2, 0::2, :] # [B, H/2, W/2, C]
x1 = x[:, 1::2, 0::2, :] # [B, H/2, W/2, C]
x2 = x[:, 0::2, 1::2, :] # [B, H/2, W/2, C]
x3 = x[:, 1::2, 1::2, :] # [B, H/2, W/2, C]
- 第三个过程是将4个patch 在通道维度上合并。
x = torch.cat([x0, x1, x2, x3], -1) # [B, H/2, W/2, 4*C]
- 第四个过程是合并 H和W维度
x = x.view(B, -1, 4 * C) # [B, H/2*W/2, 4*C]
这时候维度变成 [B,H/2*W/2 , 4C] ,深度从原来的C 变成了4C,加深了4倍而不是2倍,因此还需要做一次1*1的卷积。
nn.Linear(4*dim,2*dim,bias=False)
State1 输出后,再经过同样类似的操作 state2,维度变成 28*28*192,依次类推state3输出后变成14*14*384, state4后变成 7*7 *768。这样是不是和CNN网络类似,有了池化过程。
- state1 特征图大小 56*56
- state2 特征图大小 28*28
- state3 特征图大小 14*14
- state4 特征图大小 7*7
3. Transformer Block
Swin Transformer Block, 第一个Block是左边的W-MSA(窗口多头自注意力),第二个Block是SW-MSA(shift-windows MSA)。
左边和右边Block的不同之处就是所用的注意力机制稍有差异,W-MSA(图Layer l),SW-MSA(Layer l +1)。在l层(左)中,采用规则的窗口划分方案,在每个窗口内计算自注意。在下一层l + 1(右)中,窗口分区被移动,产生了新的窗口。新窗口中的自注意计算跨越了层l中以前窗口的边界,提供了它们之间的连接。
3.1 window
窗口 是指图像的特定区域内(不重叠的方式)进行自注意力计算的一组Patch(MxM个Patch)。这也意味着,在计算自注意力的时候,Patch只能与同一窗口内的patch相互作用,而不是像Vit那样整个图像的patch之间相互作用。
- patch大小是 4*4的(单位是像素),window的大小是7*7的(单位是Patch),每一个window包含7*7个patch。
上面的常规的窗口虽然可以降低计算的复杂度,但是又产生了新问题,将图像分成多个窗口后,只能窗口内计算自注意力,窗口间没办法交互,限制了模型的建模能力。为此,提出了可以跨多个窗口交互的 Shifted window 。
但是,我们发现他的窗口的数量增加了,从4个窗口变成9个窗口,计算量上来了,为此,提出Cyclic-shift 循环移位。
3.2 Cyclic-shift
具体来说,将9宫格划出来A,B,C三个区域,其中A从左上角移动到右下角,B从左边移动到右边,C从上面移动到下面。这样,重新获得4宫格如下图,但是这四个window 和原来不同的就是,窗口内有原来不属于该区域的patch,计算自注意力时,不是同一区域的patch不能计算注意力。
作者巧妙的使用 Masked技术处理窗口内和窗口间的自注意力。
4个区域分别使用不同的掩码。
举例子来说,
以Window2来说,
将窗口内的patch按照通道方向展平,红色的部分和红色部分的转置以及绿色部分和绿色部分的转置计算的才可以使用,其他的部分“扔掉”,就是使用Masked,给他赋值-100,在计算softmax就是0了。
这样W-MSA和SW-MSA模块就差不多介绍清楚了。
未完待续。。