MViT模型
1.多头池化注意力(MHPA)
Multi Head Pooling Attention是本文的核心,它使得多尺度变换器已逐渐变化的时空分辨率进行操作。与原始的多头注意力(MHA)不同,在原始的多头注意力中,通道维度和时空分辨率保持不变,MHPA将潜在张量序列合并,以减少参与输入的序列长度(分辨率)。如下图所示,
Transformer只能处理1维数据,video通过 patch处理后形状改变为(L,D)即图中的THW,Self-attention计算公式主要是,假设,则
为了使公式成立,必须保证,即图中THW,所以为了降低空间分辨率,只需要改变Q向量的序列长度,所以对Q向量进行pooling操作即可,同时实验证明K,V向量pooling会提高指标,所以对K,V向量也进行了pooling操作,但是不会影响空间分辨率的大小,为了保证res connection成立,需要对输入X同样进行和Q向量一样的pooling操作
pooling操作又分为max/ average/ conv等,论文实验部分对不同的pooling操作进行了消融实验,最终确定为333 核的conv pooling操作。
如何提高通道数?
提高通道数就是通过简单的全连接层对向量维度D进行映射即可
代码
# pool通常是MaxPool3d或AvgPool3d
def attention_pool(tensor, pool, thw_shape, has_cls_embed=True, norm=None):
if pool is None:
return tensor, thw_shape
tensor_dim = tensor.ndim
if tensor_dim == 4:
pass
elif tensor_dim == 3:
tensor = tensor.unsqueeze(1)
else:
raise NotImplementedError(f"Unsupported input dimension {tensor.shape}")
if has_cls_embed:
cls_tok, tensor = tensor[:, :, :1, :], tensor[:, :, 1:, :]
B, N, L, C = tensor.shape
T, H, W = thw_shape
tensor = (tensor.reshape(B * N, T, H, W, C).permute(0, 4, 1, 2, 3).contiguous())
# 执行pooling操作
tensor = pool(tensor)
thw_shape = [tensor.shape[2], tensor.shape[3], tensor.shape[4]]
L_pooled = tensor.shape[2] * tensor.shape[3] * tensor.shape[4]
tensor = tensor.reshape(B, N, C, L_pooled).transpose(2, 3)
if has_cls_embed:
tensor = torch.cat((cls_tok, tensor), dim=2)
if norm is not None:
tensor = norm(tensor)
# Assert tensor_dim in [3, 4]
if tensor_dim == 4:
pass
else: # tensor_dim == 3:
tensor = tensor.squeeze(1)
return tensor, thw_shape
2.多尺度变换器网络(Multiscale Transformer Networks)
基于多头集中注意力(MHPA),本文创造了专门使用MHPA和MLP层进行视觉表征学习的多尺度变换器模型。在此之前,了解一下ViT模型。
2.1ViT
这里需要注意的是模型基于纯Transformer架构的,所以采用了Patch操作,详情参考ViT,所以图中的1,2,3,4是patch的大小,随着模型深入,patch是变大的,但是空间分辨(Patch分辨率)是降低的。
MViT
逐步增加信道维度,同时降低整个网络的时空分辨率(即序列长度)。MViT在早期层中具有精细的时空分辨率和低信道维度,而在后期层中,变为粗略的时空分辨率和高信道维度。MViT如表2所示,
需要注意之前提到了需要对数据进行patch操作,通过卷积实现(cube1),但是视频信号还有一个维度T,如图3所示,参数 代表cube1中计算卷积时对T维度的步长,是一个超参数,后续实验中出现的例如MViT-B 16*4 指的是输入16帧视频帧,取值4。
代表的是每个stage使用的Transformer个数,MHPA(D)代表的是其处理向量的维度为D,MLP(4D)表示Transformer block中全连接层隐藏单元数为4D,即输入维度的四倍。
Scale stages
尺度阶段定义为一组N个变换器块,在相同的尺度上跨信道和时空维度以相同的分辨率运行。在阶段转换时,信道维度上采样,而序列的长度下采样。
每个stage都应用了若干个Transformer blocks,图2所示的是每个Transformer block都采用了pooling的操作,所以为了保证每个stage中只对空间分辨率进行一次下采样,只在每个stage的第一个Transformer block对向量Q进行的操作,通stagetage其余Transformer block的向量Q进行的操作。对K,V的pooling操作不影响空间分辨率,所以论文中在同一个stage的所有Transformer block的K,V都进行了同样的pooling操作,即,
随着stage变深衰减,
之前说过通过全连接层进行通道数增加的操作,图3中并未显示的展示,其实在两个stages之间存在一个过渡操作,即每个stage的output sizes需要通过一个全连接层将维度D进行映射,只进行通道数增加的操作,然后送入下一个stage进行计算。
MViT更改Transformer结构之后的multi-head的个数如何确定?
论文中维度D对应一个head,即图中scale5使用的Transformer blocks的multi-head个数等于8,scale4中head个数等于4,以此类推
图4(a)是ViT的框架,MHA就是普通的multi-head-attention,可以发现,在ViT中是没有分层结构的,输出和输入形状是一样的,MViT采用MHPA引入了分层结构,提出了两个不同大小的模型MViT-B/ MViT-S,值得注意的是,两个模型的体量都比较小,不到7G的显存就可以运行MViT-B,对在校学生来说非常友好的,显存足够的情况下有很大的改进空间。