注意力机制 CVPR2021 Coordinate Attention || Pytorch代码实现
- 即插即用!
- 一、Coordinate Attention 简介
- 二、使用步骤
- 1.结构图
- 1.pytorch 代码
即插即用!
提示:这里可以添加本文要记录的大概内容:
CoordAttention简单灵活且高效,可以插入经典的轻量级网络在几乎不带来额外计算开销的前提下,提升网络的精度。实验表明,CoordAttention不仅仅对于分类任务有不错的提高,对目标检测、实例分割这类密集预测的任务,效果提升更加明显。
论文名称
Hou Q, Zhou D, Feng J. Coordinate attention for efficient mobile network design[C]//Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition. 2021: 13713-13722.
论文地址链接
代码链接
https://github.com/Andrew- Qibin/CoordAttention.
提示:以下是本篇文章正文内容,下面案例可供参考
一、Coordinate Attention 简介
尽管通道注意力(例如,Squeeze-and-Excitation attention)对于提升模型的性能具有有效性,但通常忽略了位置信息,这对于生成空间选择性注意力图很重要。在本文中,通过将位置信息嵌入到通道注意力中来为移动网络提出一种新的注意力机制,称之为“坐标注意力”(Coordinate Attention)。**坐标注意力将通道注意力分别沿两个空间方向分解为两个一维特征编码,并沿空间方向聚合特征。通过这种方式,可以沿一个空间方向捕获远程依赖关系,同时可以沿另一个空间方向保留精确的位置信息。然后将得到的特征图单独编码成一对方向感知和位置敏感的注意力图,这些图可以互补地应用于输入特征图以增强感兴趣对象的表示。**坐标注意力很简单,可以灵活地插入到经典的移动网络中,例如 MobileNetV2、MobileNeXt 和 EfficientNet,几乎没有计算开销。大量实验表明,坐标注意力不仅有利于 ImageNet 分类,而且更有趣的是,在下游任务中表现更好,例如对象检测和语义分割。
二、使用步骤
1.结构图
1.pytorch 代码
代码如下(示例):
class CA_Block(nn.Module):
def __init__(self, channel, h, w, reduction=16):
super(CA_Block, self).__init__()
self.h = h
self.w = w
#这里使用了广义平均池化
self.gempool_x = GeneralizedMeanPooling(norm=3,output_size=(h, 1))
self.gempool_y = GeneralizedMeanPooling(norm=3, output_size=(1, w))
#原文使用全局平均池化,可自行更改
self.avg_pool_x = nn.AdaptiveAvgPool2d((h, 1))
self.avg_pool_y = nn.AdaptiveAvgPool2d((1, w))
self.conv_1x1 = nn.Conv2d(in_channels=channel, out_channels=channel // reduction, kernel_size=1, stride=1,
bias=False)
self.relu = nn.ReLU()
self.bn = nn.BatchNorm2d(channel // reduction)
self.F_h = nn.Conv2d(in_channels=channel // reduction, out_channels=channel, kernel_size=1, stride=1,
bias=False)
self.F_w = nn.Conv2d(in_channels=channel // reduction, out_channels=channel, kernel_size=1, stride=1,
bias=False)
self.sigmoid_h = nn.Sigmoid()
self.sigmoid_w = nn.Sigmoid()
def forward(self, x):
x_h = self.gempool_x(x).permute(0, 1, 3, 2)
x_w = self.gempool_y(x)
x_cat_conv_relu = self.relu(self.bn(self.conv_1x1(torch.cat((x_h, x_w),
x_cat_conv_split_h, x_cat_conv_split_w = x_cat_conv_relu.split([self.h, self.w],
s_h = self.sigmoid_h(self.F_h(x_cat_conv_split_h.permute(0, 1, 3, 2)))
s_w = self.sigmoid_w(self.F_w(x_cat_conv_split_w))
out_put = x * s_h.expand_as(x) * s_w.expand_as(x)
return output
if __name__ == '__main__':
x = torch.randn(1, 512, 56, 56) # b, c, h, w
ca_model = CA_Block(channel=512, h=56, w=56)
y = ca_model(x)
print(y.shape)