Pytorch 保存加载模型时的坑

  • Pytorch 保存加载模型时的坑
  • 方法1:保存模型的参数和结构信息
  • 方法二:官方推荐的方法,只保存和恢复模型中的参数
  • 填坑
  • 总结


Pytorch 保存加载模型时的坑

在说Pytorch保存加载模型时的坑之前,先介绍一下pytorch对训练好的模型如何进行保存和加载。

方法1:保存模型的参数和结构信息

保存:

model=MobileNetV2(n_class=2)#加载模型
############进行训练##########
 model = torch.nn.DataParallel(model, device_ids=[int(i) for i in args.gpus.strip().split(',')])#用多gpus 训练×××关键
############进行训练##########
torch.save(model, os.path.join(args.save_path, "epoch_" + str(epoch) + ".pth.tar"))#保存模型

恢复:

model=torch.load(args.load_path)#

这种方法会出现一个问题:当利用pytorch 1.0.0 保存好了模型后,加载时利用pytorch1.1.0 进行load() 时回报错,所以官方推荐使用第二种方法进行加载

方法二:官方推荐的方法,只保存和恢复模型中的参数

一个完整的例子:
迁移学习加载模型(此时 checkpoint 字典只有 state_dict ):

model=MobileNetV2(n_class=2)#加载模型结构
model_dict =  model.state_dict()#获取模型参数(未加载保存的模型参数 )
if args.resume:#模型路径
    if os.path.isfile(args.resume):
        print(("=> loading checkpoint '{}'".format(args.resume)))
        checkpoint = torch.load(args.resume)#获取模型参数
         #因为我修改网络模型进行迁移学习,这一步是在checkpoint里获取没有修改的模型参数state_dict
        state_dict = {k: v for k, v in checkpoint.items() if k in model_dict.keys()}
        model_dict.update(state_dict)#更新已经保存的参数至model_dict
        model.load_state_dict(model_dict)#加载模型参数
    else:
        print(("=> no checkpoint found at '{}'".format(args.resume)))

保存:–这里有坑

torch.save({"epoch":epoch, #一共训练的epoch
                   "model_state_dict":model.module.state_dict(), #保存模型参数×××××这里埋个坑××××
                   'epoch_acc': epoch_acc, #一共训练的epoch
                   "optimizer":optimizer.state_dict() }#优化器好像也在保存,这样可以继续加载模型进行训练
                   ,os.path.join(args.save_path,"checkpoints_epoch_" + str(epoch) + ".tar"))

再加载:

print("start loading cls model")
model=MobileNetV2(n_class=2)
if os.path.isfile(args.load_path):
    state_dict=torch.load(args.load_path)
    print(state_dict['epoch'])#获取保存的参数 对应key值的参数
    print(state_dict['epoch_acc'])
    params=state_dict["model_state_dict"] 
    for param_tensor in params:#打印参数信息
         print(param_tensor,"\t",params[param_tensor].size())
    model.load_state_dict(params)
    print("load cls model successfully")

填坑

这段保存模型参数的代码

torch.save({"epoch":epoch, #一共训练的epoch
                   "model_state_dict":model.module.state_dict(), #保存模型参数
                   'epoch_acc': epoch_acc, #一共训练的epoch
                   "optimizer":optimizer.state_dict() }#优化器好像也在保存,这样可以继续加载模型进行训练
                   ,os.path.join(args.save_path,"checkpoints_epoch_" + str(epoch) + ".tar")) 

torch.save({"epoch":epoch, #一共训练的epoch
                   "model_state_dict":model.state_dict(), #保存模型参数
                   'epoch_acc': epoch_acc, #一共训练的epoch
                   "optimizer":optimizer.state_dict() }#优化器好像也在保存,这样可以继续加载模型进行训练
                   ,os.path.join(args.save_path,"checkpoints_epoch_" + str(epoch) + ".tar"))

与这段的不同在于model.module.state_dict()与model.state_dict()的区别
现在来打印一下

model=MobileNetV2(n_class=2)#加载模型结构
model_dict =  model.state_dict()#获取模型参数(未加载保存的模型参数 )
model_dict----------model.module.state_dict()---------model.state_dict()三者参数的对应的名称(这里只打印几个)
model_dict:
features.0.0.weight 	 torch.Size([32, 3, 3, 3])
features.0.1.weight 	 torch.Size([32])
features.0.1.bias 	 torch.Size([32])
features.0.1.running_mean 	 torch.Size([32])
features.0.1.running_var 	 torch.Size([32])
features.0.1.num_batches_tracked 	 torch.Size([])

model.module.state_dict():
features.0.0.weight 	 torch.Size([32, 3, 3, 3])
features.0.1.weight 	 torch.Size([32])
features.0.1.bias 	 torch.Size([32])
features.0.1.running_mean 	 torch.Size([32])
features.0.1.running_var 	 torch.Size([32])
features.0.1.num_batches_tracked 	 torch.Size([])

model.state_dict():
module.features.0.0.weight 	 torch.Size([32, 3, 3, 3])
module.features.0.1.weight 	 torch.Size([32])
module.features.0.1.bias 	 torch.Size([32])
module.features.0.1.running_mean 	 torch.Size([32])
module.features.0.1.running_var 	 torch.Size([32])
module.features.0.1.num_batches_tracked 	 torch.Size([])

用多gpus进行训练后直接用model.state_dict()进行保存的模型,每个层参数的名称前面会加上module,这时候再用单卡 gpu model_dict加载model.state_dict()参数时会出现名称不匹配的情况。
因此保存模型时注意使用model.module.state_dict():

总结

1.多gpus训练 用model.state_dict() 保存前面会加上网络参数名称前会加上 module
2.单gpus加载模型,需要去掉网络参数名称前加上的module
两种方法:
(1) 用model.module.state_dict()保存
(2) 去掉网络参数名称前会加上的module再加载模型
3.推荐多gpus训练使用model.module.state_dict()保存,然后单gpu加载,
此时如果还需要多gpu训练可以在加载模型参数后使用torch.nn.DataParallel进行训练