adam梯度下降 代码 adam算法和梯度下降算法_adam梯度下降 代码

我们在机器学习的过程中,当我们构建好我们的模型后要对输出构建损失函数。然后要不断的减小损失函数的值来不断更新优化我们模型中的参数。

那么如何优化我们的参数呢?

梯度下降法:

adam梯度下降 代码 adam算法和梯度下降算法_数据_02

对1到M这些给出的数据计算损失函数之和的均值

adam梯度下降 代码 adam算法和梯度下降算法_python sklearn 梯度下降法_03

求导

adam梯度下降 代码 adam算法和梯度下降算法_python sklearn 梯度下降法_04

更新参数,a为学习率(用于决定我们学习的步长)

通俗一点将过程就相当于:

1 遍历我们所有的数据(求损失函数均值)

2 环顾四方,寻找一个最优(损失函数下降最快)的方向(求导)

3 朝着损失函数下降最快的地方迈出大小为a的一步(更新参数)

但是这样有哪些不好的地方?

1 遍历所有数据,更新参数缓慢,计算量大,模型训练慢

2 学习率的选择对于是否可以找到最优解很重要

那么我们首先克服第一个问题:遍历所有数据,更新参数缓慢,计算量大,模型训练慢

批量(随机)梯度下降:

简单来说,为了克服我们上述的问题,我们找到问题的根本所在为遍历所有的数据后进行更新,那么,我们选择一部分数据进行求和求导后进行更新,是不是可以解决这个问题?

答案是可以的,有数学推导证明了,小批量的数据的求和为所有数据求和的无偏估计,怎么解释呢?大概意思就是选择一部分数据可以估计个差不太多,将所有的数据按照批次更新完之后是没有偏差的。

但是,这样的作法总会有一点问题,首先,虽然小批次数据的结果是正常结果的无偏估计,但是,虽然期望相同,但是会产生方差,方差的产生会对于我们的收敛有着一定的影响,有可能在最差的情况下会导致函数无法收敛,或者产生局部最优值。并且在更新的后期容易产生不稳定的震荡。

对于上述的问题我们可以采取的一个方法就是在更新的后期手动的缩小学习率,这样对于函数的收敛有着比较好的作用。

随机梯度下降:

批量随机下降的一个特殊情况,每次计算一个数据点后进行更新,参数更新更快,不稳定性更大。

在说明动量法之前我们先讲一个预备知识:指数加权平均EMA

adam梯度下降 代码 adam算法和梯度下降算法_数据_05

首先我们拿到这样一个点状图,我们如何看待他的变化趋势?

其中一个方法就是计算当天以及其前面N天数值的加权平均值作为当天的数值来对点状数据进行平滑处理。

具体的公式为:

adam梯度下降 代码 adam算法和梯度下降算法_数据_06

这里的Vt代表当天估计数值,Vt-1待变前一天的估计数值,后面的值代表当天的确定值,β是常量,1/(1-β)便是我们估计多少天的指数加权平均值。

举个例子,当β为0.98时,我们计算的估计值就是当天以及其前1/1-0.98=20天的指数加权平均值。β愈接近1,我们估计的天数越多,我们的估计曲线越平缓。

那么我们就可以用这种方式解决下面这种情况:

adam梯度下降 代码 adam算法和梯度下降算法_adam算法_07

图为损失函数的等高线

在图中,我们的参数在竖直方向上来回抖动,水平方向上缓慢前行。

因为竖直方向的抖动导致我们的学习率不能选择较大的值,因为很可能因为学习率过大而使得整个函数无法收敛。

所以我们希望整个过程“平滑一些”,这样我们就可以使用我们的指数加权平均,然后整个过程就会变成这样:

adam梯度下降 代码 adam算法和梯度下降算法_数据_08

如图中红线所示,这样我门就可以选择较大的学习率加快我们的训练过程,这个方法也叫做动量法。

上面我们解决了第一个问题与震荡过大的问题,那么我们的第二个问题如何解决呢?

RMSporp法:

他的核心思想就是解决虽然在更新初期,学习率高加快我们的模型训练,但是后期的学习率如果还是比较高的情况下我们很可能无法收敛,那么我们就要在训练速度和训练精度中做出取舍,为了一举两得,我们很简单的想法就是,在训练的前期使用比较高的学习率,但是在训练的后期使用较低的学习率,他的实现数学公式如下:

adam梯度下降 代码 adam算法和梯度下降算法_adam算法_09

我们可以看到,RMSProp算法对梯度计算了微分平方加权平均数。这种做法有利于消除了摆动幅度大的方向,用来修正摆动幅度,使得各个维度的摆动幅度都较小。另一方面也使得网络函数收敛更快。最后在训练次数较多后也会减小学习率的值,方便收敛。

Adam算法:

其实Adam算法综合了动量法和RMSPorp算法,不仅可以进一步缩小更新的抖动,并且平衡各个参数之前的更新速度,加快收敛,最后保证可以收敛。

adam梯度下降 代码 adam算法和梯度下降算法_adam梯度下降 代码_10

adam梯度下降 代码 adam算法和梯度下降算法_adam梯度下降 代码_11

adam梯度下降 代码 adam算法和梯度下降算法_adam算法_12