一、知识蒸馏原理

一般较深的复杂的模型往往会有较好的结果(例如Bert预训练模型),但是有时候为了控制时间及硬件成本,我们不得不选择较小的模型,那如何让小的模型逼近复杂的模型,则是通过知识蒸馏的方式实现。

(1)采用传统方式训练一个教师网络。

(2)建立学生网络模型,模型的输出采用传统的softmax函数,拟合目标为one-hot形式的训练集输出,它们之间的距离记为loss1。

(3)将训练完成的教师网络的softmax分类器加入温度参数T,作为具有相同温度参数softmax分类器的学生网络的拟合目标,他们之间的距离记为loss2。

(4)引入参数alpha,将loss1×(1-alpha)+loss2×alpha作为网络训练时使用的loss,训练网络。

二、代码实现

参考了下

#建立student模型  以及初始化超参数

model = studentNet()

criterion = nn.CrossEntropyLoss()
criterion2 = nn.KLDivLoss()

optimizer = optim.Adam(model.parameters(),lr = 0.0001)

correct_ratio = []
alpha = 0.5

#开始训练
for epoch in range(200):
    loss_sigma = 0.0
    correct = 0.0
    total = 0.0
    for i, data in enumerate(trainload):
        inputs, labels = data
        #inputs = inputs.cuda()
        #labels = labels.cuda()
        labels = labels.squeeze().long()
        optimizer.zero_grad()
        
        outputs = model(inputs.float())
        loss1 = criterion(outputs, labels)
        
        teacher_outputs = teach_model(inputs.float())
        T = 2
        outputs_S = F.log_softmax(outputs/T,dim=1)
        outputs_T = F.softmax(teacher_outputs/T,dim=1)
        loss2 = criterion2(outputs_S,outputs_T)*T*T
        
        loss = loss1*(1-alpha) + loss2*alpha

#        loss = loss1
        loss.backward()
        optimizer.step()
        
        _, predicted = torch.max(outputs.data, dim = 1)
        total += labels.size(0)
        correct += (predicted.cpu()==labels.cpu()).squeeze().sum().numpy()
        loss_sigma += loss.item()
        if i% 100 == 0:
            loss_avg = loss_sigma/10
            loss_sigma = 0.0
            print('loss_avg:{:.2}   Acc:{:.2%}'.format(loss_avg, correct/total))
            print("Train Epoch: {} [{}/{} ({:0f}%)]\tLoss: {:.6f}".format(
                epoch, i * len(data), len(trainload.dataset), 
                100. * i / len(trainload), loss.item()
            ))

三、特别注意的点

(1)这里会看到,output_S和output_T采用了不同的softmax函数

criterion2 = nn.KLDivLoss()     

outputs_S = F.log_softmax(outputs/T,dim=1)
outputs_T = F.softmax(teacher_outputs/T,dim=1)

loss2 = criterion2(outputs_S,outputs_T)*T*T

这是因为求KL散度时,前一个值序列必须取log,要不会出现KL散度为负数的情况(一般来说KL散度大于等于0,但是pytorch实现该函数有点问题)。

(2)有时候会出现teacher模型和student模型对数据集的处理格式不一致,因此不能用一个dataloader,我这里贴出我的代码以供参考

for epoch in range(config.num_epochs):
        print('Epoch [{}/{}]'.format(epoch + 1, config.num_epochs))

        train1=iter(train_iter)     #teach模型数据
        train2=iter(train_iter2)    #student模型数据

        for i in range(len(train_iter)): #这里不用enumerate(dataloadr)
            try:
                trains, labels = next(train1) #用next进行寻找下一个batch
            except:    #当所有的batch被穷尽,开始新的epoch时,需要重新创建迭代器
                del train1
                train1 = iter(train_iter)
                trains, labels = next(train1)
            #student模型数据同理
            try:
                trains2, labels2 = next(train2)
            except:
                del train2
                train2 = iter(train_iter2)
                trains2, labels2 = next(train2)

(3)关于KL散度的碎碎念

KL散度的概念以及pytorch实现函数可参见:

https://www.jianshu.com/p/98ec08ea3bec

pytorch中通过torch.nn.KLDivLoss类实现,也可以直接调用F.kl_div 函数,代码中的size_averagereduce已经弃用。reduction有四种取值mean,batchmean, sum, none一般默认为mean,具体区别也可参照上面链接

(4)温度的选择

在我们正常的训练过程中,我们只会关注概率最高结果与正确结果的差别。这种相似性完全是通过足够数量的样本构建的。

因此Hinton的想法是:如何充分利用大网络中的这种结果?如果只是构建所有类的传统损失函数的话,小概率结果对损失函数的贡献微乎其微。解决的方法无非是:在计算损失函数时放大其他类的概率值所对应的损失值。

Hinton用一个简单的方法解决了这一问题:加入温度系数T。

具体例子可参见:


目前看到T可以设置为2~20

(5)alpha的选择

因为目标就是为了学习教师网络的的知识,因此在loss上可以分配更多的权重,例如一开始就设置为接近1的值,如0.95。然后逐步调低,查看分类准确率

四、其他

还有一些关于知识蒸馏的综述,有很多好的思路,多教师蒸馏,GAN蒸馏,以后尝试了再进行补充