0.引言

目前,向神经网络中引入注意力机制以提高网络表达效果的技术已经日益成熟,尤其是数字图像处理领域,涉及到注意力机制的文章更是数不胜数。我使用tensorflow2.3keras API
搭建注意力模块SE-Net的时候,需要对张量进行尺寸变化,由于我对tensorflow2的模型训练机制不够深入了解,出现了关于None占位符的问题,下面基于SE-Net进行详细介绍。

1.SE-Net简介

SE-Net是一款非常经典的注意力机制,模块结构如下:

tensorflow多头自注意力机制代码实现_tensorflow


该模块作用是:通过学习的方式来自动获取到每个特征通道的重要程度,然后依照这个重要程度去提升有用的特征并抑制对当前任务用处不大的特征,这又叫做“特征重标定”策略。此处参考()

简单来说:就是把每个通道的权重也作为一组超参数送进网络中去学习,这组超参数是由Fsq(·)得到的,如图1x1xC2向量即为通道权重向量,再经过一系列处理,与原张量进行相乘(对应元素相乘),即对C2个通道种每个通道赋予权重如图所示:右边五颜六色那个张量就是经过通道加权的张量。

总结:SE-Net注意力模块作用是:进行张量通道权重的学习,刻画每个通道的重要性。

2. 代码实现

使用tensorflow2.3中keras框架进行搭建SE-Net block,代码如下:

class SE_Att(Model):
    def __init__(self, C): 
        super(SE_Att, self).__init__()
        self.ratio = 16
        self.un = C
        self.p = GlobalAveragePooling2D(data_format='channels_last')
        self.f1 = Dense(units=self.un // self.ratio, use_bias=False, activation='relu')
        self.f2 = Dense(units=self.un, use_bias=False, activation='sigmoid')

    def call(self, x):
        # x:N H W C
        N, H, W, C = x.shape
        context = self.p(x)  # N C
        # 全局池化[N,1,1,C]
        context = self.f1(context)
        context = self.f2(context)
        context = tf.reshape(context, shape=(None, 1, 1, C))  # 语句0 用-1表示None
        y = x * context
        return y

3.错误分析

错误提示:TypeError: Failed to convert object of type <class ‘tuple’> to Tensor. Contents: (None, 1, 1, 128). Consider casting elements to a supported type.

出现问题的地方在上述代码语句0,reshape中形参shape不能接受含有None的元组。但是None在tensorflow中表示一个占位符,指定该维度下为任意正值,常常用于图像处理时batch的占位,例如张量(None,64,64,3)表示任意张64x64x3大小的图片
我们使用**-1来替换 tf.reshape(context, shape=(None, 1, 1, C))中的None**即可。
所以-1实际上就表示了张量在此维度出可以接受任何值。

4.代码纠错

语句0处改为:

context = tf.reshape(context, shape=(-1, 1, 1, C))

修改后,模块可以在网络中正常运行,监控每时刻的输出输入张量大小,均与理论一致。