2.1 前向传播

        前向传播(forward propagation或forward pass)指的是:按顺序(从输入层到输出层)计算和存储神经网络中每层的结果。假设输入样本是 x∈R(d), 并且我们的隐藏层不包括偏置项。 这里的中间变量是:

机器学习前向传播python代码解读 前向传播是什么_前端

        其中W(1)∈R(h×d)是隐藏层的权重参数。 将中间变量z∈R(h)通过激活函数ϕ后, 我们得到长度为h的隐藏激活向量:

机器学习前向传播python代码解读 前向传播是什么_机器学习前向传播python代码解读_02

隐藏变量h也是一个中间变量。 假设输出层的参数只有权重W(2)∈R(q×h),我们可以得到输出层变量,它是一个长度为q的向量: 

机器学习前向传播python代码解读 前向传播是什么_java_03

        假设损失函数为l,样本标签为y,我们可以计算单个数据样本的损失项,

机器学习前向传播python代码解读 前向传播是什么_java_04

        根据L2正则化的定义,给定超参数λ,正则化项为 :

机器学习前向传播python代码解读 前向传播是什么_数据库_05

        其中矩阵的Frobenius范数是将矩阵展平为向量后应用的L2范数。 最后,模型在给定数据样本上的正则化损失为:

机器学习前向传播python代码解读 前向传播是什么_反向传播_06

        根据上述描述得到下面的计算图( 其中正方形表示变量,圆圈表示操作符):

机器学习前向传播python代码解读 前向传播是什么_反向传播_07

 

2.2 反向传播

        反向传播(backward propagation或backpropagation)指的是计算神经网络参数梯度的方法。该方法根据微积分中的链式规则,按相反的顺序从输出层到输入层遍历网络。 该算法存储了计算某些参数梯度时所需的任何中间变量(偏导数)。

        假设我们有函数Y=f(X)和Z=g(Y), 其中输入和输出X,Y,Z是任意形状的张量。 利用链式法则,我们可以计算Z关于X的导数:(使用prod运算符在执行必要的操作,如换位和交换输入位置,后将其参数相乘)

机器学习前向传播python代码解读 前向传播是什么_机器学习前向传播python代码解读_08

        如上图所示,反向传播的目的是计算梯度∂J/∂W(1)和 

机器学习前向传播python代码解读 前向传播是什么_反向传播_09

,应用链式法则,依次计算每个中间变量和参数的梯度。 计算的顺序与前向传播中执行的顺序相反,因为需要从计算图的结果开始,并朝着参数的方向努力。

        第一步是计算目标函数J=L+s相对于损失项L和正则项s的梯度。

机器学习前向传播python代码解读 前向传播是什么_机器学习前向传播python代码解读_10

 

        目标函数关于输出层变量o的梯度:

 

机器学习前向传播python代码解读 前向传播是什么_前端_11

        计算正则化项相对于两个参数的梯度: 

 

机器学习前向传播python代码解读 前向传播是什么_java_12

        计算最接近输出层的模型参数的梯度

机器学习前向传播python代码解读 前向传播是什么_机器学习前向传播python代码解读_13


机器学习前向传播python代码解读 前向传播是什么_前端_14

        关于隐藏层输出的梯度∂J/∂h∈Rh由下式给出:

机器学习前向传播python代码解读 前向传播是什么_数据库_15

 

        由于激活函数ϕ是按元素计算的, 计算中间变量z的梯度∂J/∂z∈Rh 需要使用按元素乘法运算符,我们用⊙表示 :

机器学习前向传播python代码解读 前向传播是什么_java_16

        最后,我们可以得到最接近输入层的模型参数的梯度 ∂J/∂W(1)∈Rh×d。 根据链式法则,我们得到:

机器学习前向传播python代码解读 前向传播是什么_java_17

 

2.3 小结

        在训练神经网络时,前向传播和反向传播相互依赖。 对于前向传播,我们沿着依赖的方向遍历计算图并计算其路径上的所有变量。 然后将这些用于反向传播,其中计算顺序与计算图的相反。 

        因此,在训练神经网络时,在初始化模型参数后, 我们交替使用前向传播和反向传播,利用反向传播给出的梯度来更新模型参数。 注意,反向传播重复利用前向传播中存储的中间值,以避免重复计算。 

        带来的影响之一是我们需要保留中间值,直到反向传播完成。 这也是训练比单纯的预测需要更多的内存(显存)的原因之一。 此外,这些中间值的大小与网络层的数量和批量的大小大致成正比。 因此,使用更大的批量来训练更深层次的网络更容易导致内存不足(out of memory)错误。