BP神经网络
1 前向传播
BP神经网络的结构如图所示
参数的定义如下:
(1)
:表示网络层数
(2)
:表示第
层,
是输入层,
是输出层,其他为隐含层
(3)
:表示第
层第
个单元与第
层第
个单元的连接权重
(4)
:表示第
层第
个单元的偏置值(激活阈值)
(5)
:表示第
层第
个单元的权重累计
(6)
:表示第
层第
个单元的激活值
(7)
:表示最后的输出值
(8)
:表示第
层的神经元个数
(9)样本个数为
,特征个数为
多层感知机中,输入信号通过各个网络层的隐节点产生输出的过程称为前向传播
- 当 时,
- 当 时,计算过程为:
- 当 时,网络输出值为:
拓展为向量的表达形式为:
依此类推,给定第
层的激活值
,则第
层的激活值
的计算为:
2 反向传播
2.1 损失函数
在网络训练中,前向传播最终产生一个标量损失函数,反向传播算法(Back Propagation)则将损失函数的信息沿网络层向后传播用以计算梯度,达到优化网络参数的目的。
对于单个样本,平方误差损失函数为:
对于全部样本,平方误差损失函数为:
第一项为均方差项(经验风险),第二项为正则项,在功能上可称作权重衰减项, 目的是减小权重的幅度,达到结构风险最小化(SRM),防止模型过拟合。该项之前的系数
为权重衰减参数,用于控制损失函数中两项的相对权重。
以二分类场景下,交叉熵损失函数为:
其中,第一项衡量了预测
与真实类别
之间的交叉熵,当
与
相等时,熵最大,也就是损失函数最小
在多分类场景下,交叉熵损失函数为:
其中,
代表第
个样本的预测属于类别
的概率,
表示实际的概率,如果第
个样本真是类别为
,则
,否则为0
2.2 损失函数的求导(平方误差)
可以看出,损失函数是关于所有权重
和偏置
的方程,通过最小化损失函数来求得最佳的权重
和偏置
,可以采用梯度下降法迭代求得:
现在就是要求各权重
和偏置
的偏导,直接求上述式子会比较困难,而BP反向传播算法就是一种方便的求解偏导数的方法,可以认为是一种从后往前找规律的方法。
对于第
层的参数,可以容易得到:
现在的问题转化为求每个样本情况下的损失函数关于各权重
和偏置
的偏导
是关于上一层的
和
的函数,假设使用平方误差损失函数,对于每一个样本均有:
从最后一层开始算起:
按照上述定义,对于
层的神经网络,
和
的上标只能取到
,并且有:
根据链式法则:
此时设:
对于最后一层,由于激活函数,
和
都是已知的,所以
是可以确定的
对于倒数第二层:
只需要求
即可,使用链式法则:
表示当前层的第
个神经元,而第
个神经元的
受到前一层的所有的第
个
影响,故有:
表示
层的第
个神经元,而
是
层的第
个神经元。经上述求导,其实就是
层的第
个神经元和
层的第
个神经元的权重
与
的乘积,故:
可以进一步归纳出:
从而得到:
最终的梯度下降方程为:
2.3 BP神经网络的算法流程(平方误差)
(1)进行前向传导计算,得到
的激活值
(2)对于最后一层即
层,计算误差
(3)对
层计算误差
(4)权重和偏置更新,当
的时候,
就是输入
如果考虑正则化,则权重的更新方程为:
3 交叉熵损失函数分析
已知bp神经网络每层的输出为:
对于平方误差损失函数:
最后一层的误差
为:
对于交叉熵损失函数:
在分类问题中,
仅在一个类别
时取值为1,其余为0,设实际的类别为
,则:
最后一层的误差
为:
当
取
Sigmoid函数时:
所以有:
一般来说,平方损失函数更适合输出为连续,并且最后一层不含Sigmoid或Softmax激活函数的神经网络。交叉熵损失更适合二分类或者多分类的场景
为什么平方误差损失函数不适合最后一层含Sigmoid或Softmax激活函数的神经网络呢?
当使用交叉熵损失函数时,最后一层的误差为
,其中最后一项为
,为激活函数的导数。当激活函数为Sigmoid函数时,如果
的值非常大,函数的梯度趋于饱和,即
的绝对值非常小,导致
的取值也非常小,使得基于梯度的学习速度非常缓慢
当使用平方误差损失函数时,最后一层的误差为
,此时导数是线性的,因此不存在学习速度过慢的问题
说明:
(1)引入交叉熵损失函数目的是解决一些实例在刚开始训练时学习得非常慢的问题,其主要针对激活函数为Sigmod 函数,如果在输出神经元是S型神经元时,交叉熵一般都是更好的选择,交叉熵无法改善隐藏层中神经元发生的学习缓慢,交叉熵损失函数只对网络输出明显背离预期时发生的学习缓慢有改善效果, 交叉熵损失函数并不能改善或避免神经元饱和,而是当输出层神经元发生饱和时,能够避免其学习缓慢的问题
(2)如果采用一种不会出现饱和状态的激活函数,那么可以继续使用平方误差作为损失函数,输出神经元是线性的那么二次代价函数不再会导致学习速度下降的问题,在此情形下,平方误差损失函数就是一种合适的选择
4 参考资料:
[1] 神经网络与机器学习
[2] 深度眸机器学习课程
[3] 百面机器学习
[4] Python自然语言处理实战:核心技术与算法