在深度学习网络训练中,交叉熵损失是一种经常使用的损失函数,这篇文章里我们来推导一下交叉熵损失关于网络输出z的导数,由于二分类是多分类的特殊情况,我们直接介绍多分类的推导过程。

一、Softmax交叉熵损失求导

基于softmax的多分类交叉熵公式为

交叉熵损失函数 sigmoid 交叉熵损失函数求导_Softmax

其中交叉熵损失函数 sigmoid 交叉熵损失函数求导_Softmax_02表示类别总数,包含背景类别,交叉熵损失函数 sigmoid 交叉熵损失函数求导_Sigmoid_03通过交叉熵损失函数 sigmoid 交叉熵损失函数求导_交叉熵损失_04计算得到,交叉熵损失函数 sigmoid 交叉熵损失函数求导_Softmax_05是网络的输出。交叉熵损失函数 sigmoid 交叉熵损失函数求导_交叉熵损失函数 sigmoid_06是真实标签,通常由one-hot形式编码,单独一个样本的标签如下:

交叉熵损失函数 sigmoid 交叉熵损失函数求导_交叉熵损失_07

交叉熵损失函数 sigmoid 交叉熵损失函数求导_Sigmoid_08表示这个样本属于交叉熵损失函数 sigmoid 交叉熵损失函数求导_Sigmoid_08类。

我们拿1个属于c类的样本来举例,网络输出为z,因为总共有交叉熵损失函数 sigmoid 交叉熵损失函数求导_Sigmoid_10类,所以网络有交叉熵损失函数 sigmoid 交叉熵损失函数求导_Sigmoid_10交叉熵损失函数 sigmoid 交叉熵损失函数求导_Softmax_12值,交叉熵损失函数 sigmoid 交叉熵损失函数求导_交叉熵损失函数 sigmoid_13,然后经过Softmax激活得到交叉熵损失函数 sigmoid 交叉熵损失函数求导_Sigmoid_10个和为1的概率值交叉熵损失函数 sigmoid 交叉熵损失函数求导_交叉熵损失_15,该样本的真实标签交叉熵损失函数 sigmoid 交叉熵损失函数求导_交叉熵损失_16只有交叉熵损失函数 sigmoid 交叉熵损失函数求导_交叉熵损失函数 sigmoid_17,其余都为0,每一类的损失是:-1x标签xlog(概率值),最后求和得到总损失。

交叉熵损失函数 sigmoid 交叉熵损失函数求导_Softmax_18


可以知道,交叉熵损失函数 sigmoid 交叉熵损失函数求导_交叉熵损失_19类样本的标签编码中除了交叉熵损失函数 sigmoid 交叉熵损失函数求导_交叉熵损失函数 sigmoid_20=1外,其他值交叉熵损失函数 sigmoid 交叉熵损失函数求导_Softmax_21都为0,所以这个样本对应的其他类的交叉熵都为0,总损失可以化简为:

交叉熵损失函数 sigmoid 交叉熵损失函数求导_交叉熵损失函数 sigmoid_22

下面我们来计算一下损失交叉熵损失函数 sigmoid 交叉熵损失函数求导_Sigmoid_23对每个交叉熵损失函数 sigmoid 交叉熵损失函数求导_Sigmoid_24的导数。当交叉熵损失函数 sigmoid 交叉熵损失函数求导_交叉熵损失函数 sigmoid_25,该类对应的损失为0,求导时无用,但是由于激活函数是Softmax,计算交叉熵损失函数 sigmoid 交叉熵损失函数求导_求导_26交叉熵损失函数 sigmoid 交叉熵损失函数求导_Sigmoid_24被用到(分母),所以不管交叉熵损失函数 sigmoid 交叉熵损失函数求导_Softmax_21是否为0,对交叉熵损失函数 sigmoid 交叉熵损失函数求导_Sigmoid_24求导时,都需要考虑交叉熵损失函数 sigmoid 交叉熵损失函数求导_交叉熵损失_19类对应的概率值交叉熵损失函数 sigmoid 交叉熵损失函数求导_求导_26

交叉熵损失函数 sigmoid 交叉熵损失函数求导_Sigmoid_24求导需要用到链式求导法则,即
交叉熵损失函数 sigmoid 交叉熵损失函数求导_交叉熵损失_33

交叉熵损失函数 sigmoid 交叉熵损失函数求导_Softmax_34时,
交叉熵损失函数 sigmoid 交叉熵损失函数求导_交叉熵损失_35
代入交叉熵损失函数 sigmoid 交叉熵损失函数求导_Sigmoid_36
交叉熵损失函数 sigmoid 交叉熵损失函数求导_Softmax_37

交叉熵损失函数 sigmoid 交叉熵损失函数求导_求导_38
交叉熵损失函数 sigmoid 交叉熵损失函数求导_交叉熵损失_39
代入交叉熵损失函数 sigmoid 交叉熵损失函数求导_Sigmoid_36
交叉熵损失函数 sigmoid 交叉熵损失函数求导_交叉熵损失函数 sigmoid_41

所以:

交叉熵损失函数 sigmoid 交叉熵损失函数求导_交叉熵损失_42

交叉熵损失函数 sigmoid 交叉熵损失函数求导_Sigmoid_43

二、Sigmoid交叉熵损失求导

sigmoid一般是用在二分类问题中,二分类时,网络只有一个输出值,经过sigmoid函数得到该样本是正样本的概率值。损失函数如下:
交叉熵损失函数 sigmoid 交叉熵损失函数求导_Softmax_44
使用Sigmoid函数做多分类时,相当于把每一个类看成是独立的二分类问题,类之间不会相互影响。真实标签交叉熵损失函数 sigmoid 交叉熵损失函数求导_Softmax_21只表示j类的二分类情况。
基于sigmoid的多分类交叉熵公式如下:
交叉熵损失函数 sigmoid 交叉熵损失函数求导_求导_46

交叉熵损失函数 sigmoid 交叉熵损失函数求导_Softmax_47
其中交叉熵损失函数 sigmoid 交叉熵损失函数求导_交叉熵损失函数 sigmoid_48通过交叉熵损失函数 sigmoid 交叉熵损失函数求导_Sigmoid_49计算得到,即sigmoid函数,表达式如下:
交叉熵损失函数 sigmoid 交叉熵损失函数求导_交叉熵损失_50
sigmoid函数的导数如下:
交叉熵损失函数 sigmoid 交叉熵损失函数求导_交叉熵损失函数 sigmoid_51

我们拿1个属于c类的样本来举例,网络输出为z,因为总共有交叉熵损失函数 sigmoid 交叉熵损失函数求导_Sigmoid_10类,所以网络有交叉熵损失函数 sigmoid 交叉熵损失函数求导_Sigmoid_10交叉熵损失函数 sigmoid 交叉熵损失函数求导_Softmax_12值,交叉熵损失函数 sigmoid 交叉熵损失函数求导_交叉熵损失函数 sigmoid_13,然后经过sigmoid激活得到交叉熵损失函数 sigmoid 交叉熵损失函数求导_Sigmoid_10个独立的概率值交叉熵损失函数 sigmoid 交叉熵损失函数求导_交叉熵损失_15,该样本的真实标签交叉熵损失函数 sigmoid 交叉熵损失函数求导_交叉熵损失_16只有交叉熵损失函数 sigmoid 交叉熵损失函数求导_交叉熵损失函数 sigmoid_17,其余都为0。每一类都是一个单独的二分类问题,通过二分类交叉熵来计算损失,最后把所有类的损失相加。

交叉熵损失函数 sigmoid 交叉熵损失函数求导_交叉熵损失_60


现在我们计算损失交叉熵损失函数 sigmoid 交叉熵损失函数求导_Softmax_61关于网络输出交叉熵损失函数 sigmoid 交叉熵损失函数求导_Softmax_12的导数交叉熵损失函数 sigmoid 交叉熵损失函数求导_交叉熵损失_63,这里需要用到链式法则,在计算Loss对交叉熵损失函数 sigmoid 交叉熵损失函数求导_Sigmoid_24的导数时,只需要考虑该类对应的交叉熵损失函数 sigmoid 交叉熵损失函数求导_交叉熵损失函数 sigmoid_48即可,因为其他类的概率值跟交叉熵损失函数 sigmoid 交叉熵损失函数求导_Sigmoid_24没有关系。

交叉熵损失函数 sigmoid 交叉熵损失函数求导_Softmax_67

交叉熵损失函数 sigmoid 交叉熵损失函数求导_Softmax_68时,交叉熵损失函数 sigmoid 交叉熵损失函数求导_Sigmoid_69:
交叉熵损失函数 sigmoid 交叉熵损失函数求导_交叉熵损失_70

交叉熵损失函数 sigmoid 交叉熵损失函数求导_求导_71时,交叉熵损失函数 sigmoid 交叉熵损失函数求导_交叉熵损失_72:
交叉熵损失函数 sigmoid 交叉熵损失函数求导_交叉熵损失函数 sigmoid_73
所以
交叉熵损失函数 sigmoid 交叉熵损失函数求导_交叉熵损失_74

交叉熵损失函数 sigmoid 交叉熵损失函数求导_Sigmoid_75

三、总结

不管是使用sigmoid还是softmax作为最后的分类器,损失函数关于网络输出z的导数的形式是一样的。
交叉熵损失函数 sigmoid 交叉熵损失函数求导_交叉熵损失_76