DANet Attention
论文链接r:Dual Attention Network for Scene Segmentation
模型结构图:
论文主要内容
在论文中采用的backbone是ResNet,50或者101,是融合空洞卷积核并删除了池化层的ResNet。之后分两路都先进过一个卷积层,然后分别送到位置注意力模块和通道注意力模块中去。
Backbone:该模型的主干网络采用了ResNet系列的骨干模型,在此基础上,作者引入了平行的特征提取分支,并在分类网络中使用全局平均池化技术,兼顾全局和局部信息,提高模型特征提取的效果。
注意力模块:该模型使用了两个注意力模块,即位置注意力模块和通道注意力模块。
① 位置注意力模块注意力模块:通过通道元素的相关性计算来捕获多级空间元素上下文信息,提高对场景分割的准确性;
② 通道注意力模块:提升对输入特征图的空间元素进行相似性计算,获得通道的关联矩阵。进行选择性地选着提取感兴趣区域。
特征融合模块:该模型引入了一个特征融合模块,用于将位置注意力模块注意力模块获得特征图和通道注意力模块的特征图结果相加融合在一起。
import numpy as np
import torch
from torch import nn
from torch.nn import init
from model.attention.SelfAttention import ScaledDotProductAttention
from model.attention.SimplifiedSelfAttention import SimplifiedScaledDotProductAttention
#PositionAttentionModule(位置注意力模块)
#通过卷积层将输入张量的各个通道在空间维度上进行混合,得到新的特征图。
class PositionAttentionModule(nn.Module):
def __init__(self,d_model=512,kernel_size=3,H=7,W=7):
super().__init__()
self.cnn=nn.Conv2d(d_model,d_model,kernel_size=kernel_size,padding=(kernel_size-1)//2)
self.pa=ScaledDotProductAttention(d_model,d_k=d_model,d_v=d_model,h=1)
def forward(self,x):
bs,c,h,w=x.shape
y=self.cnn(x)
y=y.view(bs,c,-1).permute(0,2,1) #bs,h*w,c
#在注意力机制中,输入的Q、K、V都相同,即均为上一步得到的新特征图。
y=self.pa(y,y,y) #bs,h*w,c
return y
#ChannelAttentionModule(通道注意力模块)
#重排成三维张量后,输入到的注意力机制中,Q、K、V分别是上一步输入的各个通道。注意力机制将对应位置上的Q和K相乘并除以一个归一化因子,再与V相乘,最后##得到加权后的输出特征,形状与输入通道数相同(batch_size,channels,h*w)。最后,将加权特征重新排列为形状为(batch_size,channels,h,w)的##四维张量。将位置注意力模块和通道注意力模块得到的输出特征进行相加操作,即可得到双注意力机制模块的最终输出结果。
class ChannelAttentionModule(nn.Module):
def __init__(self,d_model=512,kernel_size=3,H=7,W=7):
super().__init__()
self.cnn=nn.Conv2d(d_model,d_model,kernel_size=kernel_size,padding=(kernel_size-1)//2)
self.pa=SimplifiedScaledDotProductAttention(H*W,h=1)
def forward(self,x):
bs,c,h,w=x.shape
y=self.cnn(x)
y=y.view(bs,c,-1) #bs,c,h*w
y=self.pa(y,y,y) #bs,c,h*w
return y
#双注意力机制的模块(DAModule),用于增强图像特征的表示能力
class DAModule(nn.Module):
def __init__(self,d_model=512,kernel_size=3,H=7,W=7):
super().__init__()
self.position_attention_module=PositionAttentionModule(d_model=512,kernel_size=3,H=7,W=7)
self.channel_attention_module=ChannelAttentionModule(d_model=512,kernel_size=3,H=7,W=7)
def forward(self,input):
bs,c,h,w=input.shape
p_out=self.position_attention_module(input)
c_out=self.channel_attention_module(input)
p_out=p_out.permute(0,2,1).view(bs,c,h,w)
c_out=c_out.view(bs,c,h,w)
return p_out+c_out
if __name__ == '__main__':
input=torch.randn(50,512,7,7)
danet=DAModule(d_model=512,kernel_size=3,H=7,W=7)
print(danet(input).shape)