函数定义
tensorflow中提供了tf.train.ExponentialMovingAverage来实现滑动平均模型,他使用指数衰减来计算变量的移动平均值。
tf.train.ExponentialMovingAverage.__init__(self, decay, num_updates=None, zero_debias=False, name="ExponentialMovingAverage"):
decay是衰减率在创建ExponentialMovingAverage对象时,需指定衰减率(decay),用于控制模型的更新速度。影子变量的初始值与训练变量的初始值相同。当运行变量更新时,每个影子变量都会更新为:
shadowvariable=decay∗shadowvariable+(1−decay)∗variable
num_updates是ExponentialMovingAverage提供用来动态设置decay的参数,当初始化时提供了参数,即不为none时,每次的衰减率是:
min{decay,(1+num_updates)/(10+num_updates)}
apply()方法添加了训练变量的影子副本,并保持了其影子副本中训练变量的移动平均值操作。在每次训练之后调用此操作,更新移动平均值。
average()和average_name()方法可以获取影子变量及其名称。
decay设置为接近1的值比较合理,通常为:0.999,0.9999等
实例代码如下:
v1 = tf.Variable(0, dtype=tf.float32) # 定义一个变量,初始值为0 step = tf.Variable(0, trainable=False) # step为迭代轮数变量,控制衰减率 ema = tf.train.ExponentialMovingAverage(0.99, step) # 初始设定衰减率为0.99 maintain_averages_op = ema.apply([v1]) # 更新列表中的变量 with tf.Session() as sess: init_op = tf.global_variables_initializer() # 初始化所有变量 sess.run(init_op) print(sess.run([v1, ema.average(v1)])) # 输出初始化后变量v1的值和v1的滑动平均值 sess.run(tf.assign(v1, 5)) # 更新v1的值 sess.run(maintain_averages_op) # 更新v1的滑动平均值 print(sess.run([v1, ema.average(v1)])) sess.run(tf.assign(step, 10000)) # 更新迭代轮转数step sess.run(tf.assign(v1, 10)) sess.run(maintain_averages_op) print(sess.run([v1, ema.average(v1)])) # 再次更新滑动平均值, sess.run(maintain_averages_op) print(sess.run([v1, ema.average(v1)])) # 更新v1的值为15 sess.run(tf.assign(v1, 15)) sess.run(maintain_averages_op) print(sess.run([v1, ema.average(v1)])) # # [0.0, 0.0] # [5.0, 4.5] # [10.0, 4.5549998] # [10.0, 4.6094499] # [15.0, 4.7133551]
计算步骤如下:
滑动平均模型的作用是提高测试值上的健壮性。那它是如何实现这个功能的呢?其实滑动平均模型的原理就是一阶滞后滤波法,其表达式如下:
上面的实例
**********************************************
输入 0.0
输出计算:
decay = min(0.99,(1+0)/(10+0)) =0.1
输出 = 0.1 * 0+(1-0.1)*0 = 0
**********************************************
输入 5.0
输出计算:
decay = min(0.99,(1+0)/(10+0)) =0.1
输出 = 0.1 * 0+(1-0.1)*5= 4.5
**********************************************
输入 10.0
输出计算:
decay = min(0.99,(1+10000)/(10+10000)) =0.99
输出 = 0.99 * 4.5+(1-0.99)*10= 4.555
**********************************************
输入 10.0
输出计算:
decay = min(0.99,(1+10000)/(10+10000)) =0.99
输出 = 0.99 * 4.555+(1-0.99)*15= 4.60945
**********************************************
输入 15.0
输出计算:
decay = min(0.99,(1+10000)/(10+10000)) =0.99
输出 = 0.99 * 4.60945+(1-0.99)*15= 4.713355
**********************************************