Batch Normalization原理解析


目录

  • Batch Normalization原理解析
  • 前言
  • 1.1梯度消失和梯度爆炸
  • 2.1内部协方差转移
  • 3.1Batch Normalization原理


前言

本文章是自己参考一些书籍和博客整理的一些Batch Normalization相关资料,通篇是基于自己的理解进行的整理,以作为日后参考使用。参考资料在文后贴出。

Batch Normalization可以用于解决梯度消失和梯度爆炸问题,也包括原论文里提到的内部协方差转移(Internal Covariate Shift),所以本文章先整理了一些梯度消失和梯度爆炸以及内部协方差转移出现的原理,然后再进行Batch Normalization原理的解析。

1.1梯度消失和梯度爆炸

在一些论文(比如resnet那篇)和技术书籍中,Batch Normalization被提到可以用于解决梯度消失和梯度爆炸,在此参考《深入浅出Pytorch》这本书,给出梯度消失和梯度爆炸出现的原理。

RNN梯度爆炸matmul batch normalization梯度爆炸_深度学习


其中RNN梯度爆炸matmul batch normalization梯度爆炸_RNN梯度爆炸matmul_02为第RNN梯度爆炸matmul batch normalization梯度爆炸_机器学习_03层神经元的输入,RNN梯度爆炸matmul batch normalization梯度爆炸_RNN梯度爆炸matmul_04为第RNN梯度爆炸matmul batch normalization梯度爆炸_机器学习_03层神经元的权重,而RNN梯度爆炸matmul batch normalization梯度爆炸_RNN梯度爆炸matmul_06为该层的输出,即作为下层的输入,理论上RNN梯度爆炸matmul batch normalization梯度爆炸_人工智能_07,加上激活函数后RNN梯度爆炸matmul batch normalization梯度爆炸_机器学习_08

根据微积分里的链式法则,RNN梯度爆炸matmul batch normalization梯度爆炸_深度学习_09RNN梯度爆炸matmul batch normalization梯度爆炸_深度学习_10的求导为:
RNN梯度爆炸matmul batch normalization梯度爆炸_RNN梯度爆炸matmul_11
我们假设最后的损失函数是RNN梯度爆炸matmul batch normalization梯度爆炸_batch_12,是输出层神经元的函数,对两边求导,根据链式法则:
RNN梯度爆炸matmul batch normalization梯度爆炸_RNN梯度爆炸matmul_13
RNN梯度爆炸matmul batch normalization梯度爆炸_人工智能_14
其中RNN梯度爆炸matmul batch normalization梯度爆炸_batch_15式可以看成损失函数对数据的导数,即数据梯度;而RNN梯度爆炸matmul batch normalization梯度爆炸_人工智能_16为损失函数对权重的导数,即权重梯度。从公式二可以看出,数据梯度和权重有关,权重梯度和数据有关,而前一层的数据梯度和权重梯度都和后一层的数据梯度有关。

接下来就可以解释梯度消失和梯度爆炸了:

  • 梯度消失:当构建的神经网络非常深时,不同的层学习的速度差异很大,表现为网络中靠近输出的层学习的情况很好,靠近输入的层学习的很慢。造成该问题的原因有很多,比如权重初始化不当,或者激活函数使用不当,拿激活函数举例更容易理解。

    若使用Sigmoid或Tanh作为激活函数,它们的特点为梯度小于1。那意味反向传播时每次往下一级传播时激活函数对数据的导数都小于1,即RNN梯度爆炸matmul batch normalization梯度爆炸_batch_17小于1。而每次往前一级传播数据梯度都会乘以RNN梯度爆炸matmul batch normalization梯度爆炸_batch_17,所以传播得越深,最后的数据梯度越小,对应的权重梯度也越小,就造成了梯度消失。所以在构建网络时,我们通常会用ReLU函数作为激活函数,因为它梯度为1。若权重初始化不当,比如一些权重过小,也会造成该问题。
  • 梯度爆炸:如果权重初始化把一些权重取值太大,那么在反向传播时,每向前传播一级,数据梯度都会变大,对应的权重梯度也会叠加变大,所以造成靠近输入层的权重梯度过大。

综上,权重初始化和激活函数是造成梯度消失和梯度爆炸的主要原因,所以权重初始化时尽量将权重初始化值分布在1附近。

2.1内部协方差转移

内部协方差转移是在Batch Normalization这篇论文里提到的。上文说到深度神经网络涉及到很多层的叠加,每一层的参数更新会导致上层的输入数据分布发生变化,通过层层叠加,高层的输入分布变化会非常剧烈,这就使得高层需要不断去重新适应底层的参数更新。

也就是说我们输入的数据,经过网络的每一层都会进行一次非线性变换,一直到最后一层,此时的输入数据的分布已经被改变了,但ground truth是不会变的,这就造成了网络中靠后的神经元需要不断适应更新参数适应新的数据分布,并且每一层的更新都会影响下一层的变化,所以在优化器参数设置上需要非常谨慎。

3.1Batch Normalization原理

下面是关于Batch Normalization原理的分析,为了解决内部协方差转移,必须让网络的每一层的输入都满足独立同分布才行,而这就是Batch Normalization的作法。

拿卷积神经网络举例。假设我们网络的某一层有RNN梯度爆炸matmul batch normalization梯度爆炸_机器学习_19个神经元,它的前一层有RNN梯度爆炸matmul batch normalization梯度爆炸_机器学习_03个神经元,则第RNN梯度爆炸matmul batch normalization梯度爆炸_机器学习_03层的输出即为[B,j,H1,W1],其中B为Batch_Size,j为该层输出的通道数。第RNN梯度爆炸matmul batch normalization梯度爆炸_机器学习_03层的输出作为输入传递给第RNN梯度爆炸matmul batch normalization梯度爆炸_机器学习_19层,而第RNN梯度爆炸matmul batch normalization梯度爆炸_机器学习_19层有RNN梯度爆炸matmul batch normalization梯度爆炸_机器学习_19个神经元,相当于该层的输出通道数为RNN梯度爆炸matmul batch normalization梯度爆炸_机器学习_19个,即第RNN梯度爆炸matmul batch normalization梯度爆炸_机器学习_19层的每一个神经元的权重维度为[j,S,S],S为卷积核大小,每个神经元的权重与输入[B,j,H1,W1]进行卷积操作得到维度[B,1,H2,W2],而一共有RNN梯度爆炸matmul batch normalization梯度爆炸_机器学习_19个这样的神经元,所以第RNN梯度爆炸matmul batch normalization梯度爆炸_机器学习_19层输出的整体维度为[B,k,H2,W2]。可以看下图加深理解:

RNN梯度爆炸matmul batch normalization梯度爆炸_机器学习_30


上图第RNN梯度爆炸matmul batch normalization梯度爆炸_机器学习_03层输出为[B,4,H1,W1],作为输入传递给第RNN梯度爆炸matmul batch normalization梯度爆炸_机器学习_19层,第RNN梯度爆炸matmul batch normalization梯度爆炸_机器学习_19层有两个神经元,每个神经元的权重维度为[4,S,S],但个神经元的权重与输入作卷积操作,得到的结果为[B,1,H2,W2],那么两个神经元的结果进行cat操作,得到整体结果[B,2,H2,W2]。Batch Normalization就是作用在第RNN梯度爆炸matmul batch normalization梯度爆炸_机器学习_19层的输出上的,继续假设第RNN梯度爆炸matmul batch normalization梯度爆炸_机器学习_19层有RNN梯度爆炸matmul batch normalization梯度爆炸_机器学习_19个神经元,Batch_Size 为RNN梯度爆炸matmul batch normalization梯度爆炸_batch_37,表示RNN梯度爆炸matmul batch normalization梯度爆炸_batch_37个数据,所以第RNN梯度爆炸matmul batch normalization梯度爆炸_机器学习_19层输出的维度为RNN梯度爆炸matmul batch normalization梯度爆炸_batch_40,相当于一共RNN梯度爆炸matmul batch normalization梯度爆炸_batch_37个数据,每个数据有RNN梯度爆炸matmul batch normalization梯度爆炸_机器学习_19个通道,每个通道为RNN梯度爆炸matmul batch normalization梯度爆炸_人工智能_43的矩阵,而Batch Normalization就是对RNN梯度爆炸matmul batch normalization梯度爆炸_batch_37个数据的每一个维度作正则化,如下图:

RNN梯度爆炸matmul batch normalization梯度爆炸_机器学习_45


以上是我们使用BN时把它添加的位置,一般一个Conv层后就要接一个BN层,然后再接ReLU等激活层。下面再看一下BN的具体公式。

继续用上面提到的例子,RNN梯度爆炸matmul batch normalization梯度爆炸_batch_37个数据经过第RNN梯度爆炸matmul batch normalization梯度爆炸_机器学习_19层得到了维度为RNN梯度爆炸matmul batch normalization梯度爆炸_batch_40的输出,即RNN梯度爆炸matmul batch normalization梯度爆炸_batch_37个数据,每个数据有RNN梯度爆炸matmul batch normalization梯度爆炸_机器学习_19个通道,每个通道为RNN梯度爆炸matmul batch normalization梯度爆炸_人工智能_43的矩阵。对该输出进行Batch Normalization,就是把RNN梯度爆炸matmul batch normalization梯度爆炸_batch_37个数据的每一个通道提出来,进行正则化:
RNN梯度爆炸matmul batch normalization梯度爆炸_batch_53
其中RNN梯度爆炸matmul batch normalization梯度爆炸_人工智能_54表示整个Batch的第一个通道,RNN梯度爆炸matmul batch normalization梯度爆炸_RNN梯度爆炸matmul_55表示第RNN梯度爆炸matmul batch normalization梯度爆炸_RNN梯度爆炸matmul_56个数据的第一个通道。该操作可以分为两步:

  • Standardization:首先对RNN梯度爆炸matmul batch normalization梯度爆炸_深度学习_57RNN梯度爆炸matmul batch normalization梯度爆炸_深度学习_57进行 Standardization,得到 zero mean unit variance的分布RNN梯度爆炸matmul batch normalization梯度爆炸_人工智能_59;
  • scale and shift:然后再对RNN梯度爆炸matmul batch normalization梯度爆炸_人工智能_59进行scale and shift,缩放并平移到新的分布RNN梯度爆炸matmul batch normalization梯度爆炸_RNN梯度爆炸matmul_61,具有新的均值方差RNN梯度爆炸matmul batch normalization梯度爆炸_batch_62

RNN梯度爆炸matmul batch normalization梯度爆炸_batch_63RNN梯度爆炸matmul batch normalization梯度爆炸_人工智能_64为待学习的scale和shift参数,用于控制RNN梯度爆炸matmul batch normalization梯度爆炸_人工智能_65的方差和均值。