tf.train.ExponentialMovingAverage

函数定义

tensorflow中提供了tf.train.ExponentialMovingAverage来实现滑动平均模型,他使用指数衰减来计算变量的移动平均值。

tf.train.ExponentialMovingAverage.__init__(self, decay, num_updates=None, zero_debias=False, name="ExponentialMovingAverage"):


  • decay是衰减率在创建ExponentialMovingAverage对象时,需指定衰减率(decay),用于控制模型的更新速度。影子变量的初始值与训练变量的初始值相同。当运行变量更新时,每个影子变量都会更新为:
    shadowvariable=decay∗shadowvariable+(1−decay)∗variable



  • num_updates是ExponentialMovingAverage提供用来动态设置decay的参数,当初始化时提供了参数,即不为none时,每次的衰减率是:
    min{decay,(1+num_updates)/(10+num_updates)}



  • apply()方法添加了训练变量的影子副本,并保持了其影子副本中训练变量的移动平均值操作。在每次训练之后调用此操作,更新移动平均值。



  • average()和average_name()方法可以获取影子变量及其名称。



  • decay设置为接近1的值比较合理,通常为:0.999,0.9999等



实例代码如下:

v1 = tf.Variable(0, dtype=tf.float32)   # 定义一个变量,初始值为0 step = tf.Variable(0, trainable=False)  # step为迭代轮数变量,控制衰减率 ema = tf.train.ExponentialMovingAverage(0.99, step)  # 初始设定衰减率为0.99 maintain_averages_op = ema.apply([v1])                 # 更新列表中的变量 with tf.Session() as sess:     init_op = tf.global_variables_initializer()        # 初始化所有变量 sess.run(init_op) print(sess.run([v1, ema.average(v1)]))                # 输出初始化后变量v1的值和v1的滑动平均值 sess.run(tf.assign(v1, 5))                            # 更新v1的值 sess.run(maintain_averages_op)                        # 更新v1的滑动平均值 print(sess.run([v1, ema.average(v1)])) sess.run(tf.assign(step, 10000))                      # 更新迭代轮转数step sess.run(tf.assign(v1, 10)) sess.run(maintain_averages_op) print(sess.run([v1, ema.average(v1)]))                                                       # 再次更新滑动平均值, sess.run(maintain_averages_op) print(sess.run([v1, ema.average(v1)]))                                                       # 更新v1的值为15 sess.run(tf.assign(v1, 15))  sess.run(maintain_averages_op) print(sess.run([v1, ema.average(v1)])) # # [0.0, 0.0] # [5.0, 4.5] # [10.0, 4.5549998] # [10.0, 4.6094499] # [15.0, 4.7133551]

计算步骤如下:

滑动平均模型的作用是提高测试值上的健壮性。那它是如何实现这个功能的呢?其实滑动平均模型的原理就是一阶滞后滤波法,其表达式如下:

上面的实例

**********************************************

输入 0.0

输出计算:

decay = min(0.99,(1+0)/(10+0)) =0.1

输出 = 0.1 * 0+(1-0.1)*0 = 0

**********************************************

输入 5.0

输出计算:

decay = min(0.99,(1+0)/(10+0)) =0.1

输出 = 0.1 * 0+(1-0.1)*5= 4.5

**********************************************

输入 10.0

输出计算:

decay = min(0.99,(1+10000)/(10+10000)) =0.99

输出 = 0.99 * 4.5+(1-0.99)*10= 4.555

**********************************************

输入 10.0

输出计算:

decay = min(0.99,(1+10000)/(10+10000)) =0.99

输出 = 0.99 * 4.555+(1-0.99)*15= 4.60945

**********************************************

输入 15.0

输出计算:

decay = min(0.99,(1+10000)/(10+10000)) =0.99

输出 = 0.99 * 4.60945+(1-0.99)*15= 4.713355

**********************************************