RNN的时间反向传播原理

  本节将介绍循环神经网络中梯度的计算和存储方法,即通过时间反向传播(back-propagation through time)。正向传播在循环神经网络中比较直观,而通过时间反向传播其实是反向传播在循环神经网络中的具体应用。我们需要将循环神经网络按时间步展开,从而得到模型变量和参数之间的依赖关系,并依据链式求导法则应用反向传播计算并存储梯度。

1. 定义模型

简单起见,我们考虑一个无偏置项的循环神经网络,且激活函数简化为恒等映射 RNN反向传播梯度爆炸_RNN反向传播梯度爆炸。设时间步 RNN反向传播梯度爆炸_反向传播_02 的输入为单样本 RNN反向传播梯度爆炸_循环神经网络_03,即不考虑 RNN反向传播梯度爆炸_RNN反向传播梯度爆炸_04,标签为 RNN反向传播梯度爆炸_反向传播_05,那么隐藏状态 RNN反向传播梯度爆炸_循环神经网络_06 的计算表达式为:
RNN反向传播梯度爆炸_RNN反向传播梯度爆炸_07
其中,RNN反向传播梯度爆炸_深度学习_08RNN反向传播梯度爆炸_深度学习_09 是隐藏层权重系数。设输出层权重系数 RNN反向传播梯度爆炸_反向传播_10,时间步 RNN反向传播梯度爆炸_反向传播_02 的输出层变量 RNN反向传播梯度爆炸_RNN反向传播梯度爆炸_12 计算为:
RNN反向传播梯度爆炸_深度学习_13
设时间步 RNN反向传播梯度爆炸_反向传播_02 的损失为 RNN反向传播梯度爆炸_深度学习_15。时间步数为 RNN反向传播梯度爆炸_依赖关系_16 的损失函数 RNN反向传播梯度爆炸_RNN反向传播梯度爆炸_17 定义为:
RNN反向传播梯度爆炸_依赖关系_18
我们将 RNN反向传播梯度爆炸_RNN反向传播梯度爆炸_17 称为有关给定时间步 RNN反向传播梯度爆炸_依赖关系_16

2.模型计算图

为了可视化循环神经网络中变量模型参数计算中的依赖关系,我们可以绘制模型计算图,假设时间步数为 3 的循环神经网络模型计算中的依赖关系,方框代表变量(无阴影)和模型参数(有阴影),圆圈代表运算符。



RNN反向传播梯度爆炸_循环神经网络_21


如上图所示,所有时间步的模型参数都是共享的,而且相邻时间步之间具有依赖性,即:时间步 2 的计算依赖于时间步 1 得到的隐藏状态,时间步 3 的计算又依赖于时间步 2 得到的隐藏状态。例如:时间步 RNN反向传播梯度爆炸_深度学习_22 的隐藏状态 RNN反向传播梯度爆炸_RNN反向传播梯度爆炸_23 的计算依赖模型参数 RNN反向传播梯度爆炸_反向传播_24RNN反向传播梯度爆炸_RNN反向传播梯度爆炸_25、上一时间步的隐藏状态 RNN反向传播梯度爆炸_RNN反向传播梯度爆炸_26 以及当前时间步输入 RNN反向传播梯度爆炸_依赖关系_27。即:



RNN反向传播梯度爆炸_深度学习_28


有时候,一些教材或者书籍也被习惯性写成转置的形式:



RNN反向传播梯度爆炸_依赖关系_29


这里的 RNN反向传播梯度爆炸_依赖关系_30RNN反向传播梯度爆炸_RNN反向传播梯度爆炸_31RNN反向传播梯度爆炸_深度学习_32。有的文章里面可能会使用下面这种表达形式,大家只要记住转置矩阵满足的运算规律 RNN反向传播梯度爆炸_反向传播_33

3. 方法

我们从上图中可以直观的看出模型的参数就是 RNN反向传播梯度爆炸_反向传播_24RNN反向传播梯度爆炸_RNN反向传播梯度爆炸_25RNN反向传播梯度爆炸_深度学习_36 。所以训练模型反向传播更新参数时,需要计算模型参数的梯度 RNN反向传播梯度爆炸_循环神经网络_37RNN反向传播梯度爆炸_依赖关系_38RNN反向传播梯度爆炸_RNN反向传播梯度爆炸_39。根据图中的依赖关系,我们可以按照其中箭头所指的反方向依次计算并存储梯度。

首先,目标函数 RNN反向传播梯度爆炸_依赖关系_40 有关各时间步输出层变量的梯度 RNN反向传播梯度爆炸_依赖关系_41 很容易计算:
RNN反向传播梯度爆炸_依赖关系_42
下面,我们可以计算目标函数关于模型参数 RNN反向传播梯度爆炸_深度学习_36 的梯度 RNN反向传播梯度爆炸_依赖关系_44。根据模型计算图可知,RNN反向传播梯度爆炸_RNN反向传播梯度爆炸_17 通过 RNN反向传播梯度爆炸_深度学习_46 依赖 RNN反向传播梯度爆炸_深度学习_36。依据链式法则,
RNN反向传播梯度爆炸_深度学习_48
又因为 RNN反向传播梯度爆炸_深度学习_49,于是就有 RNN反向传播梯度爆炸_反向传播_50,所以
RNN反向传播梯度爆炸_循环神经网络_51
这样写好像不是那么严谨,所以如果从一个数学家的角度出发,他可能考虑将上式写成累加的形式,也就是如下的形式:
RNN反向传播梯度爆炸_深度学习_52
到此为止,我们已经求出了一个模型参数 RNN反向传播梯度爆炸_RNN反向传播梯度爆炸_39 的梯度更新公式了。那我们继续来推导剩余的两个模型参数的梯度更新公式吧!前方高能,可能在没推导之前,读者应该意识到了这两个参数并不像前面那个参数的计算梯度那样好求,因为涉及到了前后时间步的隐藏状态之间也存在依赖关系。所以我们先要建立递推关系式

在建立递推关系式之前,我们先找到递归的出口,那就是最后的时间步 RNN反向传播梯度爆炸_依赖关系_54

目标函数 RNN反向传播梯度爆炸_RNN反向传播梯度爆炸_17 只通过 RNN反向传播梯度爆炸_深度学习_56 依赖最终时间步 RNN反向传播梯度爆炸_依赖关系_16 的隐藏状态 RNN反向传播梯度爆炸_RNN反向传播梯度爆炸_58。因此,我们先计算目标函数有关最终时间步隐藏状态的梯度 RNN反向传播梯度爆炸_反向传播_59。依据链式法则,就可以得到:
RNN反向传播梯度爆炸_反向传播_60
又因为 RNN反向传播梯度爆炸_依赖关系_61,所以就有 RNN反向传播梯度爆炸_反向传播_62
RNN反向传播梯度爆炸_RNN反向传播梯度爆炸_63
接下来对于时间步 RNN反向传播梯度爆炸_深度学习_64。因为 RNN反向传播梯度爆炸_深度学习_65,而且在前向传播的时候,RNN反向传播梯度爆炸_循环神经网络_66,所以我们说 RNN反向传播梯度爆炸_RNN反向传播梯度爆炸_17 通过 RNN反向传播梯度爆炸_循环神经网络_68RNN反向传播梯度爆炸_RNN反向传播梯度爆炸_69 依赖 RNN反向传播梯度爆炸_循环神经网络_70。依据链式法则,目标函数有关时间步 RNN反向传播梯度爆炸_依赖关系_71 的隐藏状态的梯度RNN反向传播梯度爆炸_反向传播_72 需要按照时间步从大到小依次计算:
RNN反向传播梯度爆炸_依赖关系_73
根据上面的递推公式和递推关系出口,对于任意时间步 RNN反向传播梯度爆炸_反向传播_74,我们可以得到目标函数有关隐藏状态梯度的通项公式(这一步如何得到,我们不加证明的直接给出):
RNN反向传播梯度爆炸_RNN反向传播梯度爆炸_75
由上式中的指数项可见,当时间步数 RNN反向传播梯度爆炸_依赖关系_16 较大或者时间步 RNN反向传播梯度爆炸_反向传播_02 较小时,目标函数有关隐藏状态的梯度较容易出现衰减和爆炸。这也会影响其他包含 RNN反向传播梯度爆炸_深度学习_78 项的梯度,例如隐藏层中模型参数的梯度 RNN反向传播梯度爆炸_循环神经网络_37RNN反向传播梯度爆炸_依赖关系_38RNN反向传播梯度爆炸_RNN反向传播梯度爆炸_17 通过 RNN反向传播梯度爆炸_依赖关系_82

RNN反向传播梯度爆炸_依赖关系_83

每次迭代中,我们在依次计算完以上各个梯度后,会将它们存储起来,从而避免重复计算,例如,由于隐藏状态梯度 RNN反向传播梯度爆炸_深度学习_78 被计算和存储,之后的模型参数梯度 RNN反向传播梯度爆炸_反向传播_85RNN反向传播梯度爆炸_依赖关系_86 的计算可以直接读取 RNN反向传播梯度爆炸_RNN反向传播梯度爆炸_87

此外,反向传播中的梯度计算可能会依赖变量的当前值,它们正是通过正向传播计算出来的,举例来说,参数梯度 RNN反向传播梯度爆炸_RNN反向传播梯度爆炸_88 的计算需要依赖隐藏状态在时间步 RNN反向传播梯度爆炸_循环神经网络_89的当前值 RNN反向传播梯度爆炸_循环神经网络_70

小结:
本节主要介绍了循环神经网络中通过时间反向传播来更新模型参数的推导过程,要求不在于能自己手动推导每一步,而在于理解循环神经网络的梯度为什么会出现衰减或者爆炸,而不是人云亦云或道听途说,后来引入的梯度裁剪、具有门控单元的循环神经网络都是基于这个原因。所以推导不要求,而重在理解问题本身,因为现在也不需要由研究人员手动去计算每一步了,都交给计算机或现有的模块来完成了。

  1. 通过时间反向传播是反向传播在循环神经网络中的具体应用;
  2. 当总的时间步数较大或者当前时间步较小时,循环神经网络的梯度较容易出现衰减或爆炸;