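If we fine-tune a model, for example by changing the number of output nodes of its last fully connected layer, and then try to load the original pretrained weights into it, an error is raised. The model structure has changed, so the weights it expects (stored as key-value pairs in a dictionary) no longer match the saved ones, and loading fails with a mismatch. There are two ways to solve this. The first is to leave the model unchanged, load the original pretrained weights into the original model, and only then modify the structure, for example by changing the number of output nodes of the last fully connected layer. The second is to load only part of the weights; in this example, everything except the last fully connected layer. The code for method 1 and method 2 is given below: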
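Note: when only part of a model's weights is loaded, that part usually includes the low-level layers. Low-level weights tend to be fairly generic, so loading them helps subsequent training, whereas loading only the high-level layers would help much less. Besides, fine-tuning usually changes the last few layers of a model, so in practice we load the pretrained low-level weights and retrain the high-level ones.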
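Keeping only the low-level weights, as the note describes, can be sketched as a key filter over the weight dictionary. This is a minimal illustration rather than part of the original code: the helper name `keep_low_level` and the choice of prefixes are assumptions, and plain strings stand in for the tensors.

```python
def keep_low_level(state_dict, prefixes):
    """Return a new dict holding only entries whose key starts with one of `prefixes`."""
    return {k: v for k, v in state_dict.items() if k.startswith(tuple(prefixes))}


# Stand-in for a checkpoint: real values would be tensors, strings are used for brevity.
pre_weights = {
    "conv1.weight": "w0",
    "bn1.weight": "w1",
    "layer1.0.conv1.weight": "w2",
    "layer4.0.conv1.weight": "w3",
    "fc.weight": "w4",
    "fc.bias": "w5",
}

# Hypothetical choice: treat the stem (conv1, bn1) and layer1 as the "low-level" part.
low_level = keep_low_level(pre_weights, ["conv1.", "bn1.", "layer1."])
print(sorted(low_level))
```

A dictionary filtered this way would then be passed to `load_state_dict` with `strict=False`, since the model still expects the keys that were dropped.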
Method 1:
import os

import torch
import torch.nn as nn

from model import resnet34


def main():
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # load pretrained weights
    # download url: https://download.pytorch.org/models/resnet34-333f7ec4.pth
    model_weight_path = "./resnet34-pre.pth"  # weights pretrained on the ImageNet dataset
    assert os.path.exists(model_weight_path), "file {} does not exist.".format(model_weight_path)

    # Build the unmodified model first, so that every key in the weight file matches
    net = resnet34()
    net.load_state_dict(torch.load(model_weight_path, map_location=device))

    # Change the last fully connected layer so that it has 5 output nodes
    in_channel = net.fc.in_features  # number of input nodes of the last fc layer (ResNet's only fully connected layer)
    # Create a new fc layer: same number of input nodes as above, and as many output
    # nodes as the task has classes (5 in this example)
    net.fc = nn.Linear(in_channel, 5)


if __name__ == '__main__':
    main()
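The load-then-replace order of method 1 can be checked on a toy model without downloading anything. The sketch below is an assumption-laden stand-in: `TinyNet` is a hypothetical model invented for illustration, and an in-memory `state_dict` plays the role of the weight file.

```python
import torch
import torch.nn as nn


class TinyNet(nn.Module):
    """Hypothetical stand-in for resnet34: one conv layer followed by one fc layer."""

    def __init__(self, num_classes=1000):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 8, kernel_size=3, padding=1)
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(8, num_classes)

    def forward(self, x):
        x = self.pool(torch.relu(self.conv1(x)))
        return self.fc(torch.flatten(x, 1))


# Pretend these are the pretrained weights (1000 output classes).
pretrained = TinyNet(num_classes=1000).state_dict()

# Step 1: load into an unmodified model, so every key and shape matches.
net = TinyNet(num_classes=1000)
net.load_state_dict(pretrained)

# Step 2: only now swap the head for a 5-class layer.
net.fc = nn.Linear(net.fc.in_features, 5)

# The conv weights still equal the loaded ones; only the head is freshly initialized.
conv_kept = torch.equal(net.conv1.weight, pretrained["conv1.weight"])
out_shape = tuple(net(torch.randn(2, 3, 16, 16)).shape)
print(conv_kept, out_shape)
```

Because the head is replaced after loading, the default strict loading can stay on, and any unintended mismatch in the rest of the network would still raise an error.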
Method 2:
import os

import torch
import torch.nn as nn

from model import resnet34


def main():
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # load pretrained weights
    # download url: https://download.pytorch.org/models/resnet34-333f7ec4.pth
    model_weight_path = "resnet34-pre.pth"  # weights pretrained on the ImageNet dataset
    assert os.path.exists(model_weight_path), "file {} does not exist.".format(model_weight_path)

    # Set the number of output nodes of the last fully connected layer when instantiating
    # the model: 5 instead of the original 1000. The 1000 comes from ImageNet, the dataset
    # the original ResNet was trained on, which has 1000 classes.
    net = resnet34(num_classes=5)

    # Read the pretrained weights. They are stored as an ordered dict whose keys are
    # the names in the layer structure and whose values are the corresponding weights.
    pre_weights = torch.load(model_weight_path, map_location=device)

    # Find the keys of the last fully connected layer (ResNet's only fully connected layer)
    del_key = []
    for key, _ in pre_weights.items():
        if "fc" in key:
            del_key.append(key)

    # Delete the weights of that layer from the loaded dict
    for key in del_key:
        del pre_weights[key]

    # strict controls whether every part of the pretrained weights must match the model exactly.
    # It defaults to True; since only part of the weights is loaded here, set it to False.
    # The call returns two values: missing_keys holds the weight names that appear in the
    # network net but not in pre_weights, i.e. the weights that were left out.
    # unexpected_keys holds the weight names that appear in pre_weights but not in net.
    missing_keys, unexpected_keys = net.load_state_dict(pre_weights, strict=False)
    print("[missing_keys]:", *missing_keys, sep="\n")
    print("[unexpected_keys]:", *unexpected_keys, sep="\n")


if __name__ == '__main__':
    main()
Output:
[missing_keys]:
fc.weight
fc.bias
[unexpected_keys]:
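The two lists in this output can be understood as differences between two sets of key names. The sketch below is a simplified model of that bookkeeping, not PyTorch's actual implementation; the function name `compare_keys` and the shortened key lists are assumptions made for illustration.

```python
def compare_keys(model_keys, checkpoint_keys):
    """Return (missing, unexpected) for a model and a weight dict, given only their key names."""
    model_set = set(model_keys)
    checkpoint_set = set(checkpoint_keys)
    # Expected by the model but absent from the loaded weights.
    missing = [k for k in model_keys if k not in checkpoint_set]
    # Present in the loaded weights but unknown to the model.
    unexpected = [k for k in checkpoint_keys if k not in model_set]
    return missing, unexpected


# Shortened, hypothetical key lists: the model has an fc layer,
# while the weight dict has already had its fc entries deleted.
model_keys = ["conv1.weight", "bn1.weight", "bn1.bias", "fc.weight", "fc.bias"]
checkpoint_keys = ["conv1.weight", "bn1.weight", "bn1.bias"]

missing, unexpected = compare_keys(model_keys, checkpoint_keys)
print(missing)
print(unexpected)
```

Here `missing` contains the two fc names and `unexpected` is empty, which mirrors the output above: the new fc layer keeps its fresh initialization and everything else comes from the pretrained weights.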
If we do not know the name of each layer in the model, we can list them as follows:
import torchvision.models as models

net = models.resnet34()  # resnet34 is used here as an example
for k, v in net.named_parameters():
    print(k)
Output:
conv1.weight
bn1.weight
bn1.bias
layer1.0.conv1.weight
layer1.0.bn1.weight
layer1.0.bn1.bias
layer1.0.conv2.weight
layer1.0.bn2.weight
layer1.0.bn2.bias
layer1.1.conv1.weight
layer1.1.bn1.weight
layer1.1.bn1.bias
layer1.1.conv2.weight
layer1.1.bn2.weight
layer1.1.bn2.bias
layer1.2.conv1.weight
layer1.2.bn1.weight
layer1.2.bn1.bias
layer1.2.conv2.weight
layer1.2.bn2.weight
layer1.2.bn2.bias
layer2.0.conv1.weight
layer2.0.bn1.weight
layer2.0.bn1.bias
layer2.0.conv2.weight
layer2.0.bn2.weight
layer2.0.bn2.bias
layer2.0.downsample.0.weight
layer2.0.downsample.1.weight
layer2.0.downsample.1.bias
layer2.1.conv1.weight
layer2.1.bn1.weight
layer2.1.bn1.bias
layer2.1.conv2.weight
layer2.1.bn2.weight
layer2.1.bn2.bias
layer2.2.conv1.weight
layer2.2.bn1.weight
layer2.2.bn1.bias
layer2.2.conv2.weight
layer2.2.bn2.weight
layer2.2.bn2.bias
layer2.3.conv1.weight
layer2.3.bn1.weight
layer2.3.bn1.bias
layer2.3.conv2.weight
layer2.3.bn2.weight
layer2.3.bn2.bias
layer3.0.conv1.weight
layer3.0.bn1.weight
layer3.0.bn1.bias
layer3.0.conv2.weight
layer3.0.bn2.weight
layer3.0.bn2.bias
layer3.0.downsample.0.weight
layer3.0.downsample.1.weight
layer3.0.downsample.1.bias
layer3.1.conv1.weight
layer3.1.bn1.weight
layer3.1.bn1.bias
layer3.1.conv2.weight
layer3.1.bn2.weight
layer3.1.bn2.bias
layer3.2.conv1.weight
layer3.2.bn1.weight
layer3.2.bn1.bias
layer3.2.conv2.weight
layer3.2.bn2.weight
layer3.2.bn2.bias
layer3.3.conv1.weight
layer3.3.bn1.weight
layer3.3.bn1.bias
layer3.3.conv2.weight
layer3.3.bn2.weight
layer3.3.bn2.bias
layer3.4.conv1.weight
layer3.4.bn1.weight
layer3.4.bn1.bias
layer3.4.conv2.weight
layer3.4.bn2.weight
layer3.4.bn2.bias
layer3.5.conv1.weight
layer3.5.bn1.weight
layer3.5.bn1.bias
layer3.5.conv2.weight
layer3.5.bn2.weight
layer3.5.bn2.bias
layer4.0.conv1.weight
layer4.0.bn1.weight
layer4.0.bn1.bias
layer4.0.conv2.weight
layer4.0.bn2.weight
layer4.0.bn2.bias
layer4.0.downsample.0.weight
layer4.0.downsample.1.weight
layer4.0.downsample.1.bias
layer4.1.conv1.weight
layer4.1.bn1.weight
layer4.1.bn1.bias
layer4.1.conv2.weight
layer4.1.bn2.weight
layer4.1.bn2.bias
layer4.2.conv1.weight
layer4.2.bn1.weight
layer4.2.bn1.bias
layer4.2.conv2.weight
layer4.2.bn2.weight
layer4.2.bn2.bias
fc.weight
fc.bias
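One caveat about this listing: `named_parameters()` returns learnable parameters only, whereas the keys of a weight dictionary also include buffers such as the running statistics of batch normalization layers. The sketch below shows the difference on a hypothetical two-layer toy model, which is an assumption chosen to keep the example small.

```python
import torch.nn as nn

# Hypothetical toy model: a conv layer (without bias) followed by batch normalization.
toy = nn.Sequential(nn.Conv2d(3, 4, kernel_size=3, bias=False), nn.BatchNorm2d(4))

param_names = [name for name, _ in toy.named_parameters()]
state_names = list(toy.state_dict().keys())

# Keys that are stored in the state dict but are not learnable parameters.
buffer_only = [name for name in state_names if name not in param_names]

print(param_names)
print(buffer_only)
```

When writing key filters like the one in method 2, iterating over `state_dict().keys()` is therefore the safer way to see every name that a weight file may contain.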