在深度学习网络训练中,交叉熵损失是一种经常使用的损失函数,这篇文章里我们来推导一下交叉熵损失关于网络输出z的导数,由于二分类是多分类的特殊情况,我们直接介绍多分类的推导过程。
一、Softmax交叉熵损失求导
基于softmax的多分类交叉熵公式为
其中表示类别总数,包含背景类别,
通过
计算得到,
是网络的输出。
是真实标签,通常由one-hot形式编码,单独一个样本的标签如下:
表示这个样本属于
类。
我们拿1个属于c类的样本来举例,网络输出为z,因为总共有类,所以网络有
个
值,
,然后经过Softmax激活得到
个和为1的概率值
,该样本的真实标签
只有
,其余都为0,每一类的损失是:-1x标签xlog(概率值),最后求和得到总损失。
可以知道,类样本的标签编码中除了
=1外,其他值
都为0,所以这个样本对应的其他类的交叉熵都为0,总损失可以化简为:
下面我们来计算一下损失对每个
的导数。当
,该类对应的损失为0,求导时无用,但是由于激活函数是Softmax,计算
时
被用到(分母),所以不管
是否为0,对
求导时,都需要考虑
类对应的概率值
。
对求导需要用到链式求导法则,即
当时,
代入得
当时
代入,
所以:
二、Sigmoid交叉熵损失求导
sigmoid一般是用在二分类问题中,二分类时,网络只有一个输出值,经过sigmoid函数得到该样本是正样本的概率值。损失函数如下:
使用Sigmoid函数做多分类时,相当于把每一个类看成是独立的二分类问题,类之间不会相互影响。真实标签只表示j类的二分类情况。
基于sigmoid的多分类交叉熵公式如下:
其中通过
计算得到,即sigmoid函数,表达式如下:
sigmoid函数的导数如下:
我们拿1个属于c类的样本来举例,网络输出为z,因为总共有类,所以网络有
个
值,
,然后经过sigmoid激活得到
个独立的概率值
,该样本的真实标签
只有
,其余都为0。每一类都是一个单独的二分类问题,通过二分类交叉熵来计算损失,最后把所有类的损失相加。
现在我们计算损失关于网络输出
的导数
,这里需要用到链式法则,在计算Loss对
的导数时,只需要考虑该类对应的
即可,因为其他类的概率值跟
没有关系。
当时,
:
当时,
:
所以
三、总结
不管是使用sigmoid还是softmax作为最后的分类器,损失函数关于网络输出z的导数的形式是一样的。