一、加载与Model中参数不一致的预训练模型
我们在构造好了一个模型后,可能要加载一些训练好的模型参数。举例子如下:
假设 trained.pth 是一个训练好的网络的模型参数存储
model = Net()是我们刚刚生成的一个新模型,我们希望model将trained.pth中的参数加载加载进来,但是model中多了一些trained.pth中不存在的参数,如果使用下面的命令:
state_dict = torch.load('trained.pth')
model.load_state_dict(state_dict)
会报错,说key对应不上,因为model你强人所难,我堂堂trained.pth没有你的那些个零碎玩意,你非要向我索取,我上哪给你弄去。但是model不干,说既然你不能完全满足我的需要,那么你有什么我就拿什么吧,怎么办呢?下面的指令代码就行了:
model.load_state_dict(state_dict, strict=False)
二、复制训练好的模型参数
net_path = 'PatAdaAttn-epoch20.pth'
checkpoint = torch.load(net_path)
state_dict = {}
for k, v in checkpoint['model'].items():
if 'smoother' not in k:
state_dict.update({k: v})
此时state_dict已经复制了PatAdaAttn-epoch20.pth中的'model'的参数。
三、保存加载自定义模型
上面保存加载的 ‘PatAdaAttn-epoch20.pth’ 其实一个字典,通常包含如下内容:
1)网络结构:输入尺寸、输出尺寸以及隐藏层信息,以便能够在加载时重建模型。
2)模型的权重参数:包含各网络层训练后的可学习参数,可以在模型实例上调用 state_dict() 方法来获取,比如前面介绍只保存模型权重参数时用到的 model.state_dict()。
3)优化器参数:有时保存模型的参数需要稍后接着训练,那么就必须保存优化器的状态和所其使用的超参数,也是在优化器实例上调用 state_dict() 方法来获取这些参数。
4)其他信息:有时我们需要保存一些其他的信息,比如 epoch,batch_size 等超参数。
知道了这些,那么我们就可以自定义需要保存的内容,比如:
1 # saving a checkpoint assuming the network class named ClassNet
2 checkpoint = {'model': ClassNet(),
3 'model_state_dict': model.state_dict(),
4 'optimizer_state_dict': optimizer.state_dict(),
5 'epoch': epoch}
6
7 torch.save(checkpoint, 'checkpoint.pkl')
上面的 checkpoint 是个字典,里面有4个键值对,分别表示网络模型的不同信息。
然后我们要加载上面保存的自定义的模型:
1 def load_checkpoint(filepath):
2 checkpoint = torch.load(filepath)
3 model = checkpoint['model'] # 提取网络结构
4 model.load_state_dict(checkpoint['model_state_dict']) # 加载网络权重参数
5 optimizer = TheOptimizerClass()
6 optimizer.load_state_dict(checkpoint['optimizer_state_dict']) # 加载优化器参数
7
8 for parameter in model.parameters():
9 parameter.requires_grad = False
10 model.eval()
11
12 return model
13
14 model = load_checkpoint('checkpoint.pkl')
如果加载模型只是为了进行推理测试,则将每一层的 requires_grad 置为 False,即固定这些权重参数;还需要调用 model.eval() 将模型置为测试模式,主要是将 dropout 和 batch normalization 层进行固定,否则模型的预测结果每次都会不同。
如果希望继续训练,则调用 model.train(),以确保网络模型处于训练模式。
state_dict() 也是一个Python字典对象,model.state_dict() 将每一层的可学习参数映射为参数矩阵,其中只包含具有可学习参数的层(卷积层、全连接层等)。
比如下面这个例子:
1 # Define model
2 class TheModelClass(nn.Module):
3 def __init__(self):
4 super(TheModelClass, self).__init__()
5 self.conv1 = nn.Conv2d(3, 8, 5)
6 self.bn = nn.BatchNorm2d(8)
7 self.conv2 = nn.Conv2d(8, 16, 5)
8 self.pool = nn.MaxPool2d(2, 2)
9 self.fc1 = nn.Linear(16 * 5 * 5, 120)
10 self.fc2 = nn.Linear(120, 10)
11
12 def forward(self, x):
13 x = self.pool(F.relu(self.conv1(x)))
14 x = self.bn(x)
15 x = self.pool(F.relu(self.conv2(x)))
16 x = x.view(-1, 16 * 5 * 5)
17 x = F.relu(self.fc1(x))
18 x = F.relu(self.fc2(x))
19 x = self.fc3(x)
20 return x
21
22 # Initialize model
23 model = TheModelClass()
24
25 # Initialize optimizer
26 optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
27
28 print("Model's state_dict:")
29 for param_tensor in model.state_dict():
30 print(param_tensor, "\t", model.state_dict()[param_tensor].size())
31
32 print("Optimizer's state_dict:")
33 for var_name in optimizer.state_dict():
34 print(var_name, "\t", optimizer.state_dict()[var_name])
输出为:
Model's state_dict:
conv1.weight torch.Size([8, 3, 5, 5])
conv1.bias torch.Size([8])
bn.weight torch.Size([8])
bn.bias torch.Size([8])
bn.running_mean torch.Size([8])
bn.running_var torch.Size([8])
bn.num_batches_tracked torch.Size([])
conv2.weight torch.Size([16, 8, 5, 5])
conv2.bias torch.Size([16])
fc1.weight torch.Size([120, 400])
fc1.bias torch.Size([120])
fc2.weight torch.Size([10, 120])
fc2.bias torch.Size([10])
Optimizer's state_dict:
state {}
param_groups [{'lr': 0.001, 'momentum': 0.9, 'dampening': 0,.....]
可以看到 model.state_dict() 保存了卷积层,BatchNorm层和最大池化层的信息;而 optimizer.state_dict() 则保存的优化器的状态和相关的超参数。