一、知识蒸馏原理
一般较深的复杂的模型往往会有较好的结果(例如Bert预训练模型),但是有时候为了控制时间及硬件成本,我们不得不选择较小的模型,那如何让小的模型逼近复杂的模型,则是通过知识蒸馏的方式实现。
(1)采用传统方式训练一个教师网络。
(2)建立学生网络模型,模型的输出采用传统的softmax函数,拟合目标为one-hot形式的训练集输出,它们之间的距离记为loss1。
(3)将训练完成的教师网络的softmax分类器加入温度参数T,作为具有相同温度参数softmax分类器的学生网络的拟合目标,他们之间的距离记为loss2。
(4)引入参数alpha,将loss1×(1-alpha)+loss2×alpha作为网络训练时使用的loss,训练网络。
二、代码实现
参考了下
#建立student模型 以及初始化超参数
model = studentNet()
criterion = nn.CrossEntropyLoss()
criterion2 = nn.KLDivLoss()
optimizer = optim.Adam(model.parameters(),lr = 0.0001)
correct_ratio = []
alpha = 0.5
#开始训练
for epoch in range(200):
loss_sigma = 0.0
correct = 0.0
total = 0.0
for i, data in enumerate(trainload):
inputs, labels = data
#inputs = inputs.cuda()
#labels = labels.cuda()
labels = labels.squeeze().long()
optimizer.zero_grad()
outputs = model(inputs.float())
loss1 = criterion(outputs, labels)
teacher_outputs = teach_model(inputs.float())
T = 2
outputs_S = F.log_softmax(outputs/T,dim=1)
outputs_T = F.softmax(teacher_outputs/T,dim=1)
loss2 = criterion2(outputs_S,outputs_T)*T*T
loss = loss1*(1-alpha) + loss2*alpha
# loss = loss1
loss.backward()
optimizer.step()
_, predicted = torch.max(outputs.data, dim = 1)
total += labels.size(0)
correct += (predicted.cpu()==labels.cpu()).squeeze().sum().numpy()
loss_sigma += loss.item()
if i% 100 == 0:
loss_avg = loss_sigma/10
loss_sigma = 0.0
print('loss_avg:{:.2} Acc:{:.2%}'.format(loss_avg, correct/total))
print("Train Epoch: {} [{}/{} ({:0f}%)]\tLoss: {:.6f}".format(
epoch, i * len(data), len(trainload.dataset),
100. * i / len(trainload), loss.item()
))
三、特别注意的点
(1)这里会看到,output_S和output_T采用了不同的softmax函数
criterion2 = nn.KLDivLoss()
outputs_S = F.log_softmax(outputs/T,dim=1)
outputs_T = F.softmax(teacher_outputs/T,dim=1)loss2 = criterion2(outputs_S,outputs_T)*T*T
这是因为求KL散度时,前一个值序列必须取log,要不会出现KL散度为负数的情况(一般来说KL散度大于等于0,但是pytorch实现该函数有点问题)。
(2)有时候会出现teacher模型和student模型对数据集的处理格式不一致,因此不能用一个dataloader,我这里贴出我的代码以供参考
for epoch in range(config.num_epochs):
print('Epoch [{}/{}]'.format(epoch + 1, config.num_epochs))
train1=iter(train_iter) #teach模型数据
train2=iter(train_iter2) #student模型数据
for i in range(len(train_iter)): #这里不用enumerate(dataloadr)
try:
trains, labels = next(train1) #用next进行寻找下一个batch
except: #当所有的batch被穷尽,开始新的epoch时,需要重新创建迭代器
del train1
train1 = iter(train_iter)
trains, labels = next(train1)
#student模型数据同理
try:
trains2, labels2 = next(train2)
except:
del train2
train2 = iter(train_iter2)
trains2, labels2 = next(train2)
(3)关于KL散度的碎碎念
KL散度的概念以及pytorch实现函数可参见:
https://www.jianshu.com/p/98ec08ea3bec
pytorch中通过torch.nn.KLDivLoss
类实现,也可以直接调用F.kl_div
函数,代码中的size_average
与reduce
已经弃用。reduction有四种取值mean
,batchmean
, sum
, none
一般默认为mean,具体区别也可参照上面链接
(4)温度的选择
在我们正常的训练过程中,我们只会关注概率最高结果与正确结果的差别。这种相似性完全是通过足够数量的样本构建的。
因此Hinton的想法是:如何充分利用大网络中的这种结果?如果只是构建所有类的传统损失函数的话,小概率结果对损失函数的贡献微乎其微。解决的方法无非是:在计算损失函数时放大其他类的概率值所对应的损失值。
Hinton用一个简单的方法解决了这一问题:加入温度系数T。
具体例子可参见:
目前看到T可以设置为2~20
(5)alpha的选择
因为目标就是为了学习教师网络的的知识,因此在loss上可以分配更多的权重,例如一开始就设置为接近1的值,如0.95。然后逐步调低,查看分类准确率
四、其他
还有一些关于知识蒸馏的综述,有很多好的思路,多教师蒸馏,GAN蒸馏,以后尝试了再进行补充