文章目录

  • 相关文章
  • 前言
  • 一、反向传播算法
  • 1.1 什么是反向传播算法?
  • 1.2 更泛化的例子
  • 二、计算图
  • 2.1 什么是传播图?
  • 2.2 一个简单的例子
  • 总结



前言

  本文总结了关于反向传播算法以及计算图的相关内容以及原理,并通过举例说明整个运算过程。下面就是本篇博客的全部内容!


一、反向传播算法

1.1 什么是反向传播算法?

  假设现在有如下图的一个过程:

反向传播算法代码 java 反向传播算法流程图_反向传播算法代码 java


图1:只有一个神经元的前向传播过程


反向传播算法代码 java 反向传播算法流程图_反向传播算法_02去拟合数据,初始输入值为反向传播算法代码 java 反向传播算法流程图_反向传播算法代码 java_03,参数为反向传播算法代码 java 反向传播算法流程图_算法_04,得到预测值反向传播算法代码 java 反向传播算法流程图_反向传播算法_05,其真实值为反向传播算法代码 java 反向传播算法流程图_算法_06,那么可以得到损失函数为:
反向传播算法代码 java 反向传播算法流程图_人工智能_07
  现在我们的目的是通过更新参数反向传播算法代码 java 反向传播算法流程图_算法_04使得损失值最小,为了更直观的演示,我们先给参数赋值,其中反向传播算法代码 java 反向传播算法流程图_计算图_09表示学习率:

  • 反向传播算法代码 java 反向传播算法流程图_计算图_10
  • 反向传播算法代码 java 反向传播算法流程图_算法_11
  • 反向传播算法代码 java 反向传播算法流程图_计算图_12
  • 反向传播算法代码 java 反向传播算法流程图_算法_13

反向传播算法代码 java 反向传播算法流程图_人工智能_14,那么通过赋的初始值,可以得到:

  • 反向传播算法代码 java 反向传播算法流程图_反向传播算法代码 java_15
  • 反向传播算法代码 java 反向传播算法流程图_计算图_16

反向传播算法代码 java 反向传播算法流程图_反向传播算法代码 java_17对参数反向传播算法代码 java 反向传播算法流程图_算法_18的偏导数:
反向传播算法代码 java 反向传播算法流程图_反向传播算法_19
  然后更新参数反向传播算法代码 java 反向传播算法流程图_算法_18
反向传播算法代码 java 反向传播算法流程图_计算图_21
  另一个参数反向传播算法代码 java 反向传播算法流程图_算法_22也是一样的过程,首先计算损失函数反向传播算法代码 java 反向传播算法流程图_反向传播算法代码 java_17对参数反向传播算法代码 java 反向传播算法流程图_算法_22的偏导数:
反向传播算法代码 java 反向传播算法流程图_算法_25
  然后更新参数反向传播算法代码 java 反向传播算法流程图_算法_22
反向传播算法代码 java 反向传播算法流程图_反向传播算法代码 java_27
  上面的过程看似很简单,但是如何求得反向传播算法代码 java 反向传播算法流程图_反向传播算法_28反向传播算法代码 java 反向传播算法流程图_反向传播算法代码 java_29呢?以参数反向传播算法代码 java 反向传播算法流程图_算法_18为例,很明显损失函数反向传播算法代码 java 反向传播算法流程图_反向传播算法代码 java_17是关于参数反向传播算法代码 java 反向传播算法流程图_反向传播算法_05的函数,而反向传播算法代码 java 反向传播算法流程图_反向传播算法_05又是关于参数反向传播算法代码 java 反向传播算法流程图_算法_18的偏导,所以根据高等数学中学过的链式求导法则可以得到:
反向传播算法代码 java 反向传播算法流程图_算法_35
  所以首先我们要求出损失函数反向传播算法代码 java 反向传播算法流程图_反向传播算法代码 java_17对参数反向传播算法代码 java 反向传播算法流程图_反向传播算法_05的偏导:
反向传播算法代码 java 反向传播算法流程图_算法_38
  然后再求出反向传播算法代码 java 反向传播算法流程图_反向传播算法_05反向传播算法代码 java 反向传播算法流程图_算法_18的偏导:
反向传播算法代码 java 反向传播算法流程图_反向传播算法_41
  再将其代入可以得到:
反向传播算法代码 java 反向传播算法流程图_反向传播算法代码 java_42
  同理我们也可以求出反向传播算法代码 java 反向传播算法流程图_反向传播算法代码 java_29
反向传播算法代码 java 反向传播算法流程图_反向传播算法_44
  然后再将预设好的数据值代入,就可以得到损失函数反向传播算法代码 java 反向传播算法流程图_反向传播算法代码 java_17关于参数反向传播算法代码 java 反向传播算法流程图_算法_04的偏导值:
反向传播算法代码 java 反向传播算法流程图_反向传播算法_47

反向传播算法代码 java 反向传播算法流程图_反向传播算法_48

反向传播算法代码 java 反向传播算法流程图_算法_04的公式,就可以根据梯度下降算法更新参数得到:

  • 反向传播算法代码 java 反向传播算法流程图_计算图_50
  • 反向传播算法代码 java 反向传播算法流程图_人工智能_51
  • 反向传播算法代码 java 反向传播算法流程图_人工智能_52
  • 反向传播算法代码 java 反向传播算法流程图_算法_53

反向传播算法代码 java 反向传播算法流程图_算法_04,就要从后向前依次求偏导才能得到结果,这种利用梯度从后向前更新参数的方法,也被称为反向传播算法(Back-Propagation,BR)。

1.2 更泛化的例子

  刚才介绍的例子参数比较少,计算过程也比较简单,那如果我们的情况更泛化,也更复杂呢?假设有如下图的一种情况:

反向传播算法代码 java 反向传播算法流程图_反向传播算法代码 java_55


图2:有两个神经元的前向传播过程


反向传播算法代码 java 反向传播算法流程图_算法_56,第一次计算的结果反向传播算法代码 java 反向传播算法流程图_反向传播算法_57又当作参数传入下一个神经元,经过与参数反向传播算法代码 java 反向传播算法流程图_反向传播算法代码 java_58的计算得到最终的结果反向传播算法代码 java 反向传播算法流程图_算法_59,然后用反向传播算法代码 java 反向传播算法流程图_反向传播算法代码 java_60表示经过两次拟合后的结果与真实值的误差。

反向传播算法代码 java 反向传播算法流程图_算法_56最重要的就是求出反向传播算法代码 java 反向传播算法流程图_计算图_62反向传播算法代码 java 反向传播算法流程图_反向传播算法代码 java_63,我们先计算反向传播算法代码 java 反向传播算法流程图_计算图_62,根据函数关系,我们可以得到如下链式求导公式:
反向传播算法代码 java 反向传播算法流程图_反向传播算法_65
  可以看到,整个计算过程是从后向前依次计算的,这也符合反向传播算法对此过程的描 述。当我们得到反向传播算法代码 java 反向传播算法流程图_计算图_62后,就可以利用参数更新公式来更新参数:
反向传播算法代码 java 反向传播算法流程图_计算图_67
  同理,我们也可以得到关于参数反向传播算法代码 java 反向传播算法流程图_反向传播算法代码 java_68的链式求导公式:
反向传播算法代码 java 反向传播算法流程图_人工智能_69
  然后再使用参数更新公式来更新参数即可:
反向传播算法代码 java 反向传播算法流程图_反向传播算法代码 java_70
  以上过程就是反向传播算法的全部过程,当然,在深度学习的应用中,反向传播算法需要计算的神经元个数是非常多的,但是原理与此无异,就是链式求导法则的一个应用。而在链式求导公式中的许多参数已经在上一步计算好了,所以不需要重复计算,这就使在神经网络训练过程中,节省很多不必要的计算,故反向传播算法就是神经网络中加速计算参数梯度值的方法。

二、计算图

2.1 什么是传播图?

反向传播算法代码 java 反向传播算法流程图_算法_59,通过与真实值的比较得到损失函数反向传播算法代码 java 反向传播算法流程图_反向传播算法代码 java_60,整个过程如下图所示:

反向传播算法代码 java 反向传播算法流程图_反向传播算法代码 java_73


图3:有两个神经元的前向传播过程


  从左往右的计算过程也叫前向传播,这个很好理解,就是一级一级的向下计算传播。那么计算机中如何表示这个过程呢?计算机会将每个运算小步骤保存下来,记录为一个参数,等待下次运算,具体的运算过程可见下图:

反向传播算法代码 java 反向传播算法流程图_算法_74


图4:有两个神经元的前向传播计算图


反向传播算法代码 java 反向传播算法流程图_反向传播算法代码 java_75就是为了存储参数反向传播算法代码 java 反向传播算法流程图_反向传播算法代码 java_03反向传播算法代码 java 反向传播算法流程图_算法_77运算小步骤的结果,然后等待下次运算的时候,直接将其作为参数进行计算即可,后面的参数同理,很明显这样看起来更“舒服”,符合计算机逐步运算的逻辑,经过运算最终可以得到最终的预测值反向传播算法代码 java 反向传播算法流程图_算法_59和损失函数反向传播算法代码 java 反向传播算法流程图_人工智能_79,这种模块化的计算过程图就称为计算图(Computation Graphs)。

反向传播算法代码 java 反向传播算法流程图_算法_59和损失函数反向传播算法代码 java 反向传播算法流程图_人工智能_79了,那么计算机又如何利用计算图进行反向传播呢?也就是说,计算机如何利用计算图计算出损失函数反向传播算法代码 java 反向传播算法流程图_反向传播算法代码 java_17对于各个参数的梯度呢?对于这个计算过程,可见下图:

反向传播算法代码 java 反向传播算法流程图_反向传播算法代码 java_83


图5:有两个神经元的反向传播计算图


反向传播算法代码 java 反向传播算法流程图_反向传播算法代码 java_17对于各个参数的梯度仍和之前的计算方法一致,只是由于引入了中间变量反向传播算法代码 java 反向传播算法流程图_反向传播算法_85,所以每次计算关于损失函数反向传播算法代码 java 反向传播算法流程图_反向传播算法代码 java_17对于反向传播算法代码 java 反向传播算法流程图_反向传播算法_87中参数的偏导数时,要先计算损失函数反向传播算法代码 java 反向传播算法流程图_反向传播算法代码 java_17对于反向传播算法代码 java 反向传播算法流程图_反向传播算法_87的偏导数,其余计算过程并没有变化,其中需要注意:

  • 黄色的变量:通过此步骤的之前步骤得到的运算结果
  • 绿色的变量:通过前向传播得到的已知的数据

反向传播算法代码 java 反向传播算法流程图_反向传播算法代码 java_17关于参数反向传播算法代码 java 反向传播算法流程图_计算图_91的偏导数,然后就可以利用之前介绍的梯度下降算法进行参数的优化更新了。

2.2 一个简单的例子

  我们现在已经明白什么是计算图了,那么计算机如何利用计算图的原理去进行有关深度学习的计算呢?我们以Pytorch中的乘法运算为例,其运算图如下所示:

反向传播算法代码 java 反向传播算法流程图_人工智能_92


图6:神经网络中乘法运算的计算图


  1. 前向传播
    Pytorch中乘法运算的前向传播代码如下所示:
class Multiply(torch.autograd.Function):
	@staticmethod
	def forward(ctx, x, y):
	ctx.save_for_backward(x,y)
	z = x * y
	return z

可以看到,整段代码的运算过程恰如乘法运算计算图中绿色所示部分,直接获取到关于反向传播算法代码 java 反向传播算法流程图_计算图_93反向传播算法代码 java 反向传播算法流程图_反向传播算法代码 java_94的参数,然后进行相乘得到反向传播算法代码 java 反向传播算法流程图_算法_95,最后返回反向传播算法代码 java 反向传播算法流程图_算法_95即可。

  1. 反向传播
    Pytorch中乘法运算的反向传播代码如下所示:
class Multiply(torch.autograd.Function):
	@staticmethod
	def backward(ctx, grad_z):
	x, y = ctx.saved_tensors
	grad_x = grad_z * y
	grad_y = grad_z * x
	return grad_x, grad_y

这个就和前向传播的代码有所不同,因为其计算需要求损失函数反向传播算法代码 java 反向传播算法流程图_算法_97关于各个参数的偏导数,所以此段代码对应乘法运算计算图中黄色所示部分。其中,反向传播算法代码 java 反向传播算法流程图_人工智能_98,而我们需要求得损失函数反向传播算法代码 java 反向传播算法流程图_算法_97分别对反向传播算法代码 java 反向传播算法流程图_算法_100的偏导数,其具体表示为:

  • grad_x = grad_z * y对应:

反向传播算法代码 java 反向传播算法流程图_算法_101

  • grad_y = grad_z * x对应:
    反向传播算法代码 java 反向传播算法流程图_算法_102

  这样就可以通过计算图利用反向传播算法来更新数以亿计的网络参数了。


总结

  以上就是本篇博客的全部内容了,文章内容不算太长,但是有些地方还是不太好理解的,最好有些高等数学的基础,学起来会更“舒服”。本系列还会一直更新,敬请期待!