训练过程中的本质就是在最小化损失,在定义损失之后,接下来就是训练网络参数了,优化器可以让神经网络更快收敛到最小值。

本文介绍几种 tensorflow 常用的优化器函数。

 

1、GradientDescentOptimizer

梯度下降算法需要用到全部样本,训练速度比较慢。

tf.train.GradientDescentOptimizer(
learning_rate,
use_locking=False, 
name="GradientDescent"
)

 

2、AdagradOptimizer

自适应学习率,加入一个正则化项 

sag 优化器 python 优化器函数_sag 优化器 python

,对学习率进行约束,前期学习率小的时候,正则化项大,能够放大梯度;后期,学习率大的时候,正则化项大,可以减少梯度,适合处理稀疏数据。缺点:依赖于全局学习率。

tf.train.AdagradOptimizer(
learning_rate, # 学习率
initial_accumulator_value=FLAGS.adagrad_initial_accumulator_value, # 累计初始值
use_locking=False, 
name="Adagrad"
)

 

3、AdadeltaOptimizer

Adadelta和Adagrad一样,都是自适应学习率,它是对Adagrad的改进,在计算上有所区别,Adagrad累加梯度的平方,依赖于全部学习率,而Adadelta加固定大小的值,不依赖于全局学习率。

tf.train.AdadeltaOptimizer(
learning_rate, # 学习率
rho=FLAGS.adadelta_rho, # 衰减率
epsilon=FLAGS.opt_epsilon, # 用于更好的调节梯度更新的常量
use_locking=False, # 若为True,锁住更新操作
name="Adadelta" # 操作名
)

 

4、RMSPropOptimizer

它是Adagrad的改进、Adadelta的变体,仍然依赖于全局学习率,效果位于两者之间,对于RNN效果较好。

tf.train.RMSPropOptimizer(
learning_rate,
decay=FLAGS.rmsprop_decay, # 梯度的系数
momentum=FLAGS.rmsprop_momentum, # 动量
epsilon=FLAGS.opt_epsilon, # 用于更好的调节梯度更新的常量
use_locking=False,
centered=False, # 如果为True,则通过梯度的估计方差对梯度进行归一化
name="RMSProp"
)

 

5、MomentumOptimizer

就像物理上的动量一样,梯度大的时候,动量大,梯度小的时候,动量也会变小,能够更加平稳、快速地冲向局部最小点。

tf.train.MomentumOptimizer(
learning_rate, # 学习率
momentum=FLAGS.momentum, # 动量
use_locking=False,
name='Momentum',
use_nesterov=False # 若为True,则使用Nesterov动量
)

 

6、AdamOptimizer

可以看作是带有Momentum的RMSProp,可以将学习率控制在一定范围内,参数较平稳。

tf.train.AdamOptimizer(
learning_rate, # 学习率
beta1=FLAGS.adam_beta1, # 一阶矩估计衰减率
beta2=FLAGS.adam_beta2, # 二阶矩估计衰减率
epsilon=FLAGS.opt_epsilon, # 用于更好的调节梯度更新的常量
use_locking=False,
name="Adam"
)

 

7、FtrlOptimizer

FTRL 就是正则项为0的SGD算法。

tf.train.FtrlOptimizer(
learning_rate, # 学习率
learning_rate_power=FLAGS.ftrl_learning_rate_power, # 控制训练期间学习率衰减方式
initial_accumulator_value=FLAGS.ftrl_initial_accumulator_value, # 累计器初始值
l1_regularization_strength=FLAGS.ftrl_l1, # L1正则化系数
l2_regularization_strength=FLAGS.ftrl_l2, # L2正则化系数
use_locking=False,
name="Ftrl",
accum_name=None,
linear_name=None,
l2_shrinkage_regularization_strength=0.0 # 惩罚项
)