论文地址:

https://arxiv.org/pdf/1503.02531.pdf


读论文《Distilling the Knowledge in a Neural Network》——蒸馏网络 —— 蒸馏算法 —— 知识蒸馏 中的温度系数到底怎么用, temperature怎么用?_git



蒸馏网络的重要公式:

读论文《Distilling the Knowledge in a Neural Network》——蒸馏网络 —— 蒸馏算法 —— 知识蒸馏 中的温度系数到底怎么用, temperature怎么用?_数据集_02

读论文《Distilling the Knowledge in a Neural Network》——蒸馏网络 —— 蒸馏算法 —— 知识蒸馏 中的温度系数到底怎么用, temperature怎么用?_数据集_03

其中,\(p^g\)为Teacher网络,\(q\)为Student网络。


个体神经网络(CNN模型):

CNN层 + 全连接层(输出的是logits) + softmax层(输出的是预测值概率P) + 交叉熵损失函数


蒸馏算法:

第一步:使用训练数据集训练Teacher网络,这时候的logits是不使用Temperature参数调控的,和正常算法流程一致;

第二步:使用Teacher网络的\(p^{g}\)和Student网络的\(q\)使用\(KL(p^{g}, q)\)来训练Student网络,需要注意这时的\(p^{g}\)和\(q\)都是使用在对各自的logits使用Temerature系数之后的,并且需要注意这里的Temperature可以视作为一个超参数,并且在使用Teacher网络训练Student网络时使用的训练数据集和单独训练Teacher网络的数据集一致;

第三步:完成Student网络训练后进行测试,注意,这时的Student网络是不需要对logits使用Temperature参数调控的,也就是说测试Student网络时是和普通算法流程一致的,是不使用Temperature参数的。



可以说,在蒸馏算法中这个Temperature是一个超参数形式的存在,并且只存在于使用Teacher网络训练Student网络的时候,其主要原因是这时候如果只是使用概率P进行训练则难以解决概率分布不均衡的情况,并且也无法识别不同logits得到相同P的情况,如果只使用logits则也无法计算时的单位不统一的问题,为此论文中提出对logits加入Temperature系数调整,并用KL散度进行Student网络的训练。

很多人不理解这个知识蒸馏算法,其主要障碍就是不理解这个Temperature系数上,其实这个系数只是一个超参,并且在实际训练时也只是起到一个调节的作用,虽然这个Temperature系统对算法的最终performance影响很大,但是却并不神秘。