神经网络学习小记录67——Pytorch版 Vision Transformer(VIT)模型的复现详解
- 学习前言
- 什么是Vision Transformer(VIT)
- 代码下载
- Vision Transforme的实现思路
- 一、整体结构解析
- 二、网络结构解析
- 1、特征提取部分介绍
- a、Patch+Position Embedding
- b、Transformer Encoder
- I、Self-attention结构解析
- II、Self-attention的矩阵运算
- III、MultiHead多头注意力机制
- IV、TransformerBlock的构建。
- c、整个VIT模型的构建
- 2、分类部分
- Vision Transforme的构建代码
学习前言
视觉Transformer最近非常的火热,从VIT开始,我先学学看。
什么是Vision Transformer(VIT)
Vision Transformer是Transformer的视觉版本,Transformer基本上已经成为了自然语言处理的标配,但是在视觉中的运用还受到限制。
Vision Transformer打破了这种NLP与CV的隔离,将Transformer应用于图像图块(patch)序列上,进一步完成图像分类任务。简单来理解,Vision Transformer就是将输入进来的图片,每隔一定的区域大小划分图片块。然后将划分后的图片块组合成序列,将组合后的结果传入Transformer特有的Multi-head Self-attention进行特征提取。最后利用Cls Token进行分类。
代码下载
Github源码下载地址为:
https://github.com/bubbliiiing/classification-pytorch 复制该路径到地址栏跳转。
Vision Transforme的实现思路
一、整体结构解析
与寻常的分类网络类似,整个Vision Transformer可以分为两部分,一部分是特征提取部分,另一部分是分类部分。
在特征提取部分,VIT所做的工作是特征提取。特征提取部分在图片中的对应区域是Patch+Position Embedding和Transformer Encoder。Patch+Position Embedding的作用主要是对输入进来的图片进行分块处理,每隔一定的区域大小划分图片块。然后将划分后的图片块组合成序列。在获得序列信息后,传入Transformer Encoder进行特征提取,这是Transformer特有的Multi-head Self-attention结构,通过自注意力机制,关注每个图片块的重要程度。
在分类部分,VIT所做的工作是利用提取到的特征进行分类。在进行特征提取的时候,我们会在图片序列中添加上Cls Token,该Token会作为一个单位的序列信息一起进行特征提取,提取的过程中,该Cls Token会与其它的特征进行特征交互,融合其它图片序列的特征。最终,我们利用Multi-head Self-attention结构提取特征后的Cls Token进行全连接分类。
二、网络结构解析
1、特征提取部分介绍
a、Patch+Position Embedding
Patch+Position Embedding的作用主要是对输入进来的图片进行分块处理,每隔一定的区域大小划分图片块。然后将划分后的图片块组合成序列。
该部分首先对输入进来的图片进行分块处理,处理方式其实很简单,使用的是现成的卷积。由于卷积使用的是滑动窗口的思想,我们只需要设定特定的步长,就可以输入进来的图片进行分块处理了。
在VIT中,我们常设置这个卷积的卷积核大小为16x16,步长也为16x16,此时卷积就会每隔16个像素点进行一次特征提取,由于卷积核大小为16x16,两个图片区域的特征提取过程就不会有重叠。当我们输入的图片是224, 224, 3的时候,我们可以获得一个14, 14, 768的特征层。
下一步就是将这个特征层组合成序列,组合的方式非常简单,就是将高宽维度进行平铺,14, 14, 768在高宽维度平铺后,获得一个196, 768的特征层。平铺完成后,我们会在图片序列中添加上Cls Token,该Token会作为一个单位的序列信息一起进行特征提取,图中的这个0*就是Cls Token,我们此时获得一个197, 768的特征层。
添加完成Cls Token后,再为所有特征添加上位置信息,这样网络才有区分不同区域的能力。添加方式其实也非常简单,我们生成一个197, 768的参数矩阵,这个参数矩阵是可训练的,把这个矩阵加上197, 768的特征层即可。
到这里,Patch+Position Embedding就构建完成了,构建代码如下:
class PatchEmbed(nn.Module):
def __init__(self, input_shape=[224, 224], patch_size=16, in_chans=3, num_features=768, norm_layer=None, flatten=True):
super().__init__()
self.num_patches = (input_shape[0] // patch_size) * (input_shape[1] // patch_size)
self.flatten = flatten
self.proj = nn.Conv2d(in_chans, num_features, kernel_size=patch_size, stride=patch_size)
self.norm = norm_layer(num_features) if norm_layer else nn.Identity()
def forward(self, x):
x = self.proj(x)
if self.flatten:
x = x.flatten(2).transpose(1, 2) # BCHW -> BNC
x = self.norm(x)
return x
class VisionTransformer(nn.Module):
def __init__(
self, input_shape=[224, 224], patch_size=16, in_chans=3, num_classes=1000, num_features=768,
depth=12, num_heads=12, mlp_ratio=4., qkv_bias=True, drop_rate=0.1, attn_drop_rate=0.1, drop_path_rate=0.1,
norm_layer=partial(nn.LayerNorm, eps=1e-6), act_layer=GELU
):
super().__init__()
#-----------------------------------------------#
# 224, 224, 3 -> 196, 768
#-----------------------------------------------#
self.patch_embed = PatchEmbed(input_shape=input_shape, patch_size=patch_size, in_chans=in_chans, num_features=num_features)
num_patches = (224 // patch_size) * (224 // patch_size)
self.num_features = num_features
self.new_feature_shape = [int(input_shape[0] // patch_size), int(input_shape[1] // patch_size)]
self.old_feature_shape = [int(224 // patch_size), int(224 // patch_size)]
#--------------------------------------------------------------------------------------------------------------------#
# classtoken部分是transformer的分类特征。用于堆叠到序列化后的图片特征中,作为一个单位的序列特征进行特征提取。
#
# 在利用步长为16x16的卷积将输入图片划分成14x14的部分后,将14x14部分的特征平铺,一幅图片会存在序列长度为196的特征。
# 此时生成一个classtoken,将classtoken堆叠到序列长度为196的特征上,获得一个序列长度为197的特征。
# 在特征提取的过程中,classtoken会与图片特征进行特征的交互。最终分类时,我们取出classtoken的特征,利用全连接分类。
#--------------------------------------------------------------------------------------------------------------------#
# 196, 768 -> 197, 768
self.cls_token = nn.Parameter(torch.zeros(1, 1, num_features))
#--------------------------------------------------------------------------------------------------------------------#
# 为网络提取到的特征添加上位置信息。
# 以输入图片为224, 224, 3为例,我们获得的序列化后的图片特征为196, 768。加上classtoken后就是197, 768
# 此时生成的pos_Embedding的shape也为197, 768,代表每一个特征的位置信息。
#--------------------------------------------------------------------------------------------------------------------#
# 197, 768 -> 197, 768
self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, num_features))
def forward_features(self, x):
x = self.patch_embed(x)
cls_token = self.cls_token.expand(x.shape[0], -1, -1)
x = torch.cat((cls_token, x), dim=1)
cls_token_pe = self.pos_embed[:, 0:1, :]
img_token_pe = self.pos_embed[:, 1: , :]
img_token_pe = img_token_pe.view(1, *self.old_feature_shape, -1).permute(0, 3, 1, 2)
img_token_pe = F.interpolate(img_token_pe, size=self.new_feature_shape, mode='bicubic', align_corners=False)
img_token_pe = img_token_pe.permute(0, 2, 3, 1).flatten(1, 2)
pos_embed = torch.cat([cls_token_pe, img_token_pe], dim=1)
x = self.pos_drop(x + pos_embed)
b、Transformer Encoder
在上一步获得shape为197, 768的序列信息后,将序列信息传入Transformer Encoder进行特征提取,这是Transformer特有的Multi-head Self-attention结构,通过自注意力机制,关注每个图片块的重要程度。
I、Self-attention结构解析
看懂Self-attention结构,其实看懂下面这个动图就可以了,动图中存在一个序列的三个单位输入,每一个序列单位的输入都可以通过三个处理(比如全连接)获得Query、Key、Value,Query是查询向量、Key是键向量、Value值向量。
如果我们想要获得input-1的输出,那么我们进行如下几步:
1、利用input-1的查询向量,分别乘上input-1、input-2、input-3的键向量,此时我们获得了三个score。
2、然后对这三个score取softmax,获得了input-1、input-2、input-3各自的重要程度。
3、然后将这个重要程度乘上input-1、input-2、input-3的值向量,求和。
4、此时我们获得了input-1的输出。
如图所示,我们进行如下几步:
1、input-1的查询向量为[1, 0, 2],分别乘上input-1、input-2、input-3的键向量,获得三个score为2,4,4。
2、然后对这三个score取softmax,获得了input-1、input-2、input-3各自的重要程度,获得三个重要程度为0.0,0.5,0.5。
3、然后将这个重要程度乘上input-1、input-2、input-3的值向量,求和,即
。
4、此时我们获得了input-1的输出 [2.0, 7.0, 1.5]。
上述的例子中,序列长度仅为3,每个单位序列的特征长度仅为3,在VIT的Transformer Encoder中,序列长度为197,每个单位序列的特征长度为768 // num_heads。但计算过程是一样的。在实际运算时,我们采用矩阵进行运算。
II、Self-attention的矩阵运算
实际的矩阵运算过程如下图所示。我以实际矩阵为例子给大家解析:
输入的Query、Key、Value如下图所示:
首先利用 查询向量query 叉乘 转置后的键向量key,这一步可以通俗的理解为,利用查询向量去查询序列的特征,获得序列每个部分的重要程度score。输出的每一行,都代表input-1、input-2、input-3,对当前input的贡献,我们对这个贡献值取一个softmax。
然后利用 score 叉乘 value,这一步可以通俗的理解为,将序列每个部分的重要程度重新施加到序列的值上去。
这个矩阵运算的代码如下所示,各位同学可以自己试试。
import numpy as np
def soft_max(z):
t = np.exp(z)
a = np.exp(z) / np.expand_dims(np.sum(t, axis=1), 1)
return a
Query = np.array([
[1,0,2],
[2,2,2],
[2,1,3]
])
Key = np.array([
[0,1,1],
[4,4,0],
[2,3,1]
])
Value = np.array([
[1,2,3],
[2,8,0],
[2,6,3]
])
scores = Query @ Key.T
print(scores)
scores = soft_max(scores)
print(scores)
out = scores @ Value
print(out)
III、MultiHead多头注意力机制
多头注意力机制的示意图如图所示:
这幅图给人的感觉略显迷茫,我们跳脱出这个图,直接从矩阵的shape入手会清晰很多。
在第一步进行图像的分割后,我们获得的特征层为197, 768。
在施加多头的时候,我们直接对196, 768的最后一维度进行分割,比如我们想分割成12个头,那么矩阵的shepe就变成了196, 12, 64。
然后我们将196, 12, 64进行转置,将12放到前面去,获得的特征层为12, 196, 64。之后我们忽略这个12,把它和batch维度同等对待,只对196, 64进行处理,其实也就是上面的注意力机制的过程了。
#--------------------------------------------------------------------------------------------------------------------#
# Attention机制
# 将输入的特征qkv特征进行划分,首先生成query, key, value。query是查询向量、key是键向量、v是值向量。
# 然后利用 查询向量query 叉乘 转置后的键向量key,这一步可以通俗的理解为,利用查询向量去查询序列的特征,获得序列每个部分的重要程度score。
# 然后利用 score 叉乘 value,这一步可以通俗的理解为,将序列每个部分的重要程度重新施加到序列的值上去。
#--------------------------------------------------------------------------------------------------------------------#
class Attention(nn.Module):
def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.):
super().__init__()
self.num_heads = num_heads
self.scale = (dim // num_heads) ** -0.5
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
def forward(self, x):
B, N, C = x.shape
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
q, k, v = qkv[0], qkv[1], qkv[2]
attn = (q @ k.transpose(-2, -1)) * self.scale
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
x = (attn @ v).transpose(1, 2).reshape(B, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x
IV、TransformerBlock的构建。
在完成MultiHeadSelfAttention的构建后,我们需要在其后加上两个全连接。就构建了整个TransformerBlock。
class Mlp(nn.Module):
""" MLP as used in Vision Transformer, MLP-Mixer and related networks
"""
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=GELU, drop=0.):
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
drop_probs = (drop, drop)
self.fc1 = nn.Linear(in_features, hidden_features)
self.act = act_layer()
self.drop1 = nn.Dropout(drop_probs[0])
self.fc2 = nn.Linear(hidden_features, out_features)
self.drop2 = nn.Dropout(drop_probs[1])
def forward(self, x):
x = self.fc1(x)
x = self.act(x)
x = self.drop1(x)
x = self.fc2(x)
x = self.drop2(x)
return x
class Block(nn.Module):
def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0.,
drop_path=0., act_layer=GELU, norm_layer=nn.LayerNorm):
super().__init__()
self.norm1 = norm_layer(dim)
self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
self.norm2 = norm_layer(dim)
self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=drop)
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
def forward(self, x):
x = x + self.drop_path(self.attn(self.norm1(x)))
x = x + self.drop_path(self.mlp(self.norm2(x)))
return x
c、整个VIT模型的构建
整个VIT模型由一个Patch+Position Embedding加上多个TransformerBlock组成。典型的TransforerBlock的数量为12个。
class VisionTransformer(nn.Module):
def __init__(
self, input_shape=[224, 224], patch_size=16, in_chans=3, num_classes=1000, num_features=768,
depth=12, num_heads=12, mlp_ratio=4., qkv_bias=True, drop_rate=0.1, attn_drop_rate=0.1, drop_path_rate=0.1,
norm_layer=partial(nn.LayerNorm, eps=1e-6), act_layer=GELU
):
super().__init__()
#-----------------------------------------------#
# 224, 224, 3 -> 196, 768
#-----------------------------------------------#
self.patch_embed = PatchEmbed(input_shape=input_shape, patch_size=patch_size, in_chans=in_chans, num_features=num_features)
num_patches = (224 // patch_size) * (224 // patch_size)
self.num_features = num_features
self.new_feature_shape = [int(input_shape[0] // patch_size), int(input_shape[1] // patch_size)]
self.old_feature_shape = [int(224 // patch_size), int(224 // patch_size)]
#--------------------------------------------------------------------------------------------------------------------#
# classtoken部分是transformer的分类特征。用于堆叠到序列化后的图片特征中,作为一个单位的序列特征进行特征提取。
#
# 在利用步长为16x16的卷积将输入图片划分成14x14的部分后,将14x14部分的特征平铺,一幅图片会存在序列长度为196的特征。
# 此时生成一个classtoken,将classtoken堆叠到序列长度为196的特征上,获得一个序列长度为197的特征。
# 在特征提取的过程中,classtoken会与图片特征进行特征的交互。最终分类时,我们取出classtoken的特征,利用全连接分类。
#--------------------------------------------------------------------------------------------------------------------#
# 196, 768 -> 197, 768
self.cls_token = nn.Parameter(torch.zeros(1, 1, num_features))
#--------------------------------------------------------------------------------------------------------------------#
# 为网络提取到的特征添加上位置信息。
# 以输入图片为224, 224, 3为例,我们获得的序列化后的图片特征为196, 768。加上classtoken后就是197, 768
# 此时生成的pos_Embedding的shape也为197, 768,代表每一个特征的位置信息。
#--------------------------------------------------------------------------------------------------------------------#
# 197, 768 -> 197, 768
self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, num_features))
self.pos_drop = nn.Dropout(p=drop_rate)
#-----------------------------------------------#
# 197, 768 -> 197, 768 12次
#-----------------------------------------------#
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]
self.blocks = nn.Sequential(
*[
Block(
dim = num_features,
num_heads = num_heads,
mlp_ratio = mlp_ratio,
qkv_bias = qkv_bias,
drop = drop_rate,
attn_drop = attn_drop_rate,
drop_path = dpr[i],
norm_layer = norm_layer,
act_layer = act_layer
)for i in range(depth)
]
)
self.norm = norm_layer(num_features)
self.head = nn.Linear(num_features, num_classes) if num_classes > 0 else nn.Identity()
def forward_features(self, x):
x = self.patch_embed(x)
cls_token = self.cls_token.expand(x.shape[0], -1, -1)
x = torch.cat((cls_token, x), dim=1)
cls_token_pe = self.pos_embed[:, 0:1, :]
img_token_pe = self.pos_embed[:, 1: , :]
img_token_pe = img_token_pe.view(1, *self.old_feature_shape, -1).permute(0, 3, 1, 2)
img_token_pe = F.interpolate(img_token_pe, size=self.new_feature_shape, mode='bicubic', align_corners=False)
img_token_pe = img_token_pe.permute(0, 2, 3, 1).flatten(1, 2)
pos_embed = torch.cat([cls_token_pe, img_token_pe], dim=1)
x = self.pos_drop(x + pos_embed)
x = self.blocks(x)
x = self.norm(x)
return x[:, 0]
def forward(self, x):
x = self.forward_features(x)
x = self.head(x)
return x
def freeze_backbone(self):
backbone = [self.patch_embed, self.cls_token, self.pos_embed, self.pos_drop, self.blocks[:8]]
for module in backbone:
try:
for param in module.parameters():
param.requires_grad = False
except:
module.requires_grad = False
def Unfreeze_backbone(self):
backbone = [self.patch_embed, self.cls_token, self.pos_embed, self.pos_drop, self.blocks[:8]]
for module in backbone:
try:
for param in module.parameters():
param.requires_grad = True
except:
module.requires_grad = True
2、分类部分
在分类部分,VIT所做的工作是利用提取到的特征进行分类。
在进行特征提取的时候,我们会在图片序列中添加上Cls Token,该Token会作为一个单位的序列信息一起进行特征提取,提取的过程中,该Cls Token会与其它的特征进行特征交互,融合其它图片序列的特征。
最终,我们利用Multi-head Self-attention结构提取特征后的Cls Token进行全连接分类。
class VisionTransformer(nn.Module):
def __init__(
self, input_shape=[224, 224], patch_size=16, in_chans=3, num_classes=1000, num_features=768,
depth=12, num_heads=12, mlp_ratio=4., qkv_bias=True, drop_rate=0.1, attn_drop_rate=0.1, drop_path_rate=0.1,
norm_layer=partial(nn.LayerNorm, eps=1e-6), act_layer=GELU
):
super().__init__()
#-----------------------------------------------#
# 224, 224, 3 -> 196, 768
#-----------------------------------------------#
self.patch_embed = PatchEmbed(input_shape=input_shape, patch_size=patch_size, in_chans=in_chans, num_features=num_features)
num_patches = (224 // patch_size) * (224 // patch_size)
self.num_features = num_features
self.new_feature_shape = [int(input_shape[0] // patch_size), int(input_shape[1] // patch_size)]
self.old_feature_shape = [int(224 // patch_size), int(224 // patch_size)]
#--------------------------------------------------------------------------------------------------------------------#
# classtoken部分是transformer的分类特征。用于堆叠到序列化后的图片特征中,作为一个单位的序列特征进行特征提取。
#
# 在利用步长为16x16的卷积将输入图片划分成14x14的部分后,将14x14部分的特征平铺,一幅图片会存在序列长度为196的特征。
# 此时生成一个classtoken,将classtoken堆叠到序列长度为196的特征上,获得一个序列长度为197的特征。
# 在特征提取的过程中,classtoken会与图片特征进行特征的交互。最终分类时,我们取出classtoken的特征,利用全连接分类。
#--------------------------------------------------------------------------------------------------------------------#
# 196, 768 -> 197, 768
self.cls_token = nn.Parameter(torch.zeros(1, 1, num_features))
#--------------------------------------------------------------------------------------------------------------------#
# 为网络提取到的特征添加上位置信息。
# 以输入图片为224, 224, 3为例,我们获得的序列化后的图片特征为196, 768。加上classtoken后就是197, 768
# 此时生成的pos_Embedding的shape也为197, 768,代表每一个特征的位置信息。
#--------------------------------------------------------------------------------------------------------------------#
# 197, 768 -> 197, 768
self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, num_features))
self.pos_drop = nn.Dropout(p=drop_rate)
#-----------------------------------------------#
# 197, 768 -> 197, 768 12次
#-----------------------------------------------#
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]
self.blocks = nn.Sequential(
*[
Block(
dim = num_features,
num_heads = num_heads,
mlp_ratio = mlp_ratio,
qkv_bias = qkv_bias,
drop = drop_rate,
attn_drop = attn_drop_rate,
drop_path = dpr[i],
norm_layer = norm_layer,
act_layer = act_layer
)for i in range(depth)
]
)
self.norm = norm_layer(num_features)
self.head = nn.Linear(num_features, num_classes) if num_classes > 0 else nn.Identity()
def forward_features(self, x):
x = self.patch_embed(x)
cls_token = self.cls_token.expand(x.shape[0], -1, -1)
x = torch.cat((cls_token, x), dim=1)
cls_token_pe = self.pos_embed[:, 0:1, :]
img_token_pe = self.pos_embed[:, 1: , :]
img_token_pe = img_token_pe.view(1, *self.old_feature_shape, -1).permute(0, 3, 1, 2)
img_token_pe = F.interpolate(img_token_pe, size=self.new_feature_shape, mode='bicubic', align_corners=False)
img_token_pe = img_token_pe.permute(0, 2, 3, 1).flatten(1, 2)
pos_embed = torch.cat([cls_token_pe, img_token_pe], dim=1)
x = self.pos_drop(x + pos_embed)
x = self.blocks(x)
x = self.norm(x)
return x[:, 0]
def forward(self, x):
x = self.forward_features(x)
x = self.head(x)
return x
def freeze_backbone(self):
backbone = [self.patch_embed, self.cls_token, self.pos_embed, self.pos_drop, self.blocks[:8]]
for module in backbone:
try:
for param in module.parameters():
param.requires_grad = False
except:
module.requires_grad = False
def Unfreeze_backbone(self):
backbone = [self.patch_embed, self.cls_token, self.pos_embed, self.pos_drop, self.blocks[:8]]
for module in backbone:
try:
for param in module.parameters():
param.requires_grad = True
except:
module.requires_grad = True
Vision Transforme的构建代码
import math
from collections import OrderedDict
from functools import partial
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
#--------------------------------------#
# Gelu激活函数的实现
# 利用近似的数学公式
#--------------------------------------#
class GELU(nn.Module):
def __init__(self):
super(GELU, self).__init__()
def forward(self, x):
return 0.5 * x * (1 + F.tanh(np.sqrt(2 / np.pi) * (x + 0.044715 * torch.pow(x,3))))
def drop_path(x, drop_prob: float = 0., training: bool = False):
if drop_prob == 0. or not training:
return x
keep_prob = 1 - drop_prob
shape = (x.shape[0],) + (1,) * (x.ndim - 1)
random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
random_tensor.floor_()
output = x.div(keep_prob) * random_tensor
return output
class DropPath(nn.Module):
def __init__(self, drop_prob=None):
super(DropPath, self).__init__()
self.drop_prob = drop_prob
def forward(self, x):
return drop_path(x, self.drop_prob, self.training)
class PatchEmbed(nn.Module):
def __init__(self, input_shape=[224, 224], patch_size=16, in_chans=3, num_features=768, norm_layer=None, flatten=True):
super().__init__()
self.num_patches = (input_shape[0] // patch_size) * (input_shape[1] // patch_size)
self.flatten = flatten
self.proj = nn.Conv2d(in_chans, num_features, kernel_size=patch_size, stride=patch_size)
self.norm = norm_layer(num_features) if norm_layer else nn.Identity()
def forward(self, x):
x = self.proj(x)
if self.flatten:
x = x.flatten(2).transpose(1, 2) # BCHW -> BNC
x = self.norm(x)
return x
#--------------------------------------------------------------------------------------------------------------------#
# Attention机制
# 将输入的特征qkv特征进行划分,首先生成query, key, value。query是查询向量、key是键向量、v是值向量。
# 然后利用 查询向量query 叉乘 转置后的键向量key,这一步可以通俗的理解为,利用查询向量去查询序列的特征,获得序列每个部分的重要程度score。
# 然后利用 score 叉乘 value,这一步可以通俗的理解为,将序列每个部分的重要程度重新施加到序列的值上去。
#--------------------------------------------------------------------------------------------------------------------#
class Attention(nn.Module):
def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.):
super().__init__()
self.num_heads = num_heads
self.scale = (dim // num_heads) ** -0.5
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
def forward(self, x):
B, N, C = x.shape
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
q, k, v = qkv[0], qkv[1], qkv[2]
attn = (q @ k.transpose(-2, -1)) * self.scale
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
x = (attn @ v).transpose(1, 2).reshape(B, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x
class Mlp(nn.Module):
""" MLP as used in Vision Transformer, MLP-Mixer and related networks
"""
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=GELU, drop=0.):
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
drop_probs = (drop, drop)
self.fc1 = nn.Linear(in_features, hidden_features)
self.act = act_layer()
self.drop1 = nn.Dropout(drop_probs[0])
self.fc2 = nn.Linear(hidden_features, out_features)
self.drop2 = nn.Dropout(drop_probs[1])
def forward(self, x):
x = self.fc1(x)
x = self.act(x)
x = self.drop1(x)
x = self.fc2(x)
x = self.drop2(x)
return x
class Block(nn.Module):
def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0.,
drop_path=0., act_layer=GELU, norm_layer=nn.LayerNorm):
super().__init__()
self.norm1 = norm_layer(dim)
self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
self.norm2 = norm_layer(dim)
self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=drop)
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
def forward(self, x):
x = x + self.drop_path(self.attn(self.norm1(x)))
x = x + self.drop_path(self.mlp(self.norm2(x)))
return x
class VisionTransformer(nn.Module):
def __init__(
self, input_shape=[224, 224], patch_size=16, in_chans=3, num_classes=1000, num_features=768,
depth=12, num_heads=12, mlp_ratio=4., qkv_bias=True, drop_rate=0.1, attn_drop_rate=0.1, drop_path_rate=0.1,
norm_layer=partial(nn.LayerNorm, eps=1e-6), act_layer=GELU
):
super().__init__()
#-----------------------------------------------#
# 224, 224, 3 -> 196, 768
#-----------------------------------------------#
self.patch_embed = PatchEmbed(input_shape=input_shape, patch_size=patch_size, in_chans=in_chans, num_features=num_features)
num_patches = (224 // patch_size) * (224 // patch_size)
self.num_features = num_features
self.new_feature_shape = [int(input_shape[0] // patch_size), int(input_shape[1] // patch_size)]
self.old_feature_shape = [int(224 // patch_size), int(224 // patch_size)]
#--------------------------------------------------------------------------------------------------------------------#
# classtoken部分是transformer的分类特征。用于堆叠到序列化后的图片特征中,作为一个单位的序列特征进行特征提取。
#
# 在利用步长为16x16的卷积将输入图片划分成14x14的部分后,将14x14部分的特征平铺,一幅图片会存在序列长度为196的特征。
# 此时生成一个classtoken,将classtoken堆叠到序列长度为196的特征上,获得一个序列长度为197的特征。
# 在特征提取的过程中,classtoken会与图片特征进行特征的交互。最终分类时,我们取出classtoken的特征,利用全连接分类。
#--------------------------------------------------------------------------------------------------------------------#
# 196, 768 -> 197, 768
self.cls_token = nn.Parameter(torch.zeros(1, 1, num_features))
#--------------------------------------------------------------------------------------------------------------------#
# 为网络提取到的特征添加上位置信息。
# 以输入图片为224, 224, 3为例,我们获得的序列化后的图片特征为196, 768。加上classtoken后就是197, 768
# 此时生成的pos_Embedding的shape也为197, 768,代表每一个特征的位置信息。
#--------------------------------------------------------------------------------------------------------------------#
# 197, 768 -> 197, 768
self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, num_features))
self.pos_drop = nn.Dropout(p=drop_rate)
#-----------------------------------------------#
# 197, 768 -> 197, 768 12次
#-----------------------------------------------#
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]
self.blocks = nn.Sequential(
*[
Block(
dim = num_features,
num_heads = num_heads,
mlp_ratio = mlp_ratio,
qkv_bias = qkv_bias,
drop = drop_rate,
attn_drop = attn_drop_rate,
drop_path = dpr[i],
norm_layer = norm_layer,
act_layer = act_layer
)for i in range(depth)
]
)
self.norm = norm_layer(num_features)
self.head = nn.Linear(num_features, num_classes) if num_classes > 0 else nn.Identity()
def forward_features(self, x):
x = self.patch_embed(x)
cls_token = self.cls_token.expand(x.shape[0], -1, -1)
x = torch.cat((cls_token, x), dim=1)
cls_token_pe = self.pos_embed[:, 0:1, :]
img_token_pe = self.pos_embed[:, 1: , :]
img_token_pe = img_token_pe.view(1, *self.old_feature_shape, -1).permute(0, 3, 1, 2)
img_token_pe = F.interpolate(img_token_pe, size=self.new_feature_shape, mode='bicubic', align_corners=False)
img_token_pe = img_token_pe.permute(0, 2, 3, 1).flatten(1, 2)
pos_embed = torch.cat([cls_token_pe, img_token_pe], dim=1)
x = self.pos_drop(x + pos_embed)
x = self.blocks(x)
x = self.norm(x)
return x[:, 0]
def forward(self, x):
x = self.forward_features(x)
x = self.head(x)
return x
def freeze_backbone(self):
backbone = [self.patch_embed, self.cls_token, self.pos_embed, self.pos_drop, self.blocks[:8]]
for module in backbone:
try:
for param in module.parameters():
param.requires_grad = False
except:
module.requires_grad = False
def Unfreeze_backbone(self):
backbone = [self.patch_embed, self.cls_token, self.pos_embed, self.pos_drop, self.blocks[:8]]
for module in backbone:
try:
for param in module.parameters():
param.requires_grad = True
except:
module.requires_grad = True
def vit(input_shape=[224, 224], pretrained=False, num_classes=1000):
model = VisionTransformer(input_shape)
if pretrained:
model.load_state_dict(torch.load("model_data/vit-patch_16.pth"))
if num_classes!=1000:
model.head = nn.Linear(model.num_features, num_classes)
return model