# 文章目录
- 0 项目场景
- 1 模型参数
- 1.1 保存
- 1.2 加载
- 2 整个模型
- 2.1 保存
- 2.2 加载
- 3 断点续训
- 3.1 保存
- 3.2 加载
- 4 多个模型
- 4.1 保存
- 4.2 加载
- 5. 迁移学习
- 5.1 保存
- 5.2 加载
- 6 关于设备
- 6.1 GPU保存 & CPU加载
- 6.1.1 GPU保存
- 6.1.2 CPU加载
- 6.2 GPU保存 & GPU加载
- 6.2.1 GPU保存
- 6.2.2 GPU加载
- 6.3 CPU保存 & CPU加载
- 6.3.1 CPU保存
- 6.3.2 CPU加载
- 6.4 CPU保存 & GPU加载
- 6.4.1 CPU保存
- 6.4.2 GPU加载
- 7 引用参考
0 项目场景
pytorch训练完模型后,如何保存与加载?保存/加载有两种方式:一是保存/加载模型参数,二是保存/加载整个模型。
1 模型参数
保存/加载模型参数,官方推荐用这种方式,原因也给了:说这种方式对于日后恢复模型更具灵活性。
1.1 保存
torch.save(model.state_dict(), PATH)
state_dict
里保存有模型的参数,PATH
是保存路径,推荐.pt
或.pth
作为文件拓展名。
1.2 加载
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH))
model.eval()
TheModelClass
是你定义的模型结构,PATH
是保存路径。如果你的模型结构中含有dropout
或batch normalization
层,在测试之前一定要加上model.eval()
(如果没有可以不加),不然会产生错误的输出结果。
2 整个模型
保存/加载整个模型,官方不推荐这种方式,原因也给了:说是在其它项目中使用或重构后,代码可能会中断。
2.1 保存
torch.save(model, PATH)
2.2 加载
# Model class must be defined somewhere
model = torch.load(PATH)
model.eval()
定义模型结构的类必须在代码中出现。这种保存/加载模型的方式从语法上来说更加简洁和直观,但是将模型引入其它项目中使用可能出错,所以只在自己的项目中使用应该没有问题,想将模型引入其它项目中使用还是推荐第一种保存/加载方式。
3 断点续训
顾名思义就是从上次没训练完的地方继续训练,这对高效训练来说具有重要意义。
3.1 保存
torch.save({
'epoch': epoch,
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'loss': loss,
...
}, PATH)
其中model.state_dict()
和optimizer.state_dict()
是必须要保存的,因为这两项会随着模型的训练而更新。epoch
和loss
等是作为记录用的,能让你直观的了解到目前训练到第几轮了,损失是多少。PATH
是保存路径,建议以.tar
为文件拓展名。
3.2 加载
model = TheModelClass(*args, **kwargs)
optimizer = TheOptimizerClass(*args, **kwargs)
checkpoint = torch.load(PATH)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']
model.train()
# - or -
model.eval()
首先初始化模型和优化器,然后加载之前保存的模型和优化器参数。接着你可以选择从上一次结束的地方继续训练或者直接测试。继续训练的话加上model.train()
,测试模型的话加上model.eval()
,如果模型结构中没有dropout
或batch normalization
层,可以不加。
4 多个模型
有时候你可能需要将多个模型保存到一个文件中,比如GAN
4.1 保存
torch.save({
'modelA_state_dict': modelA.state_dict(),
'modelB_state_dict': modelB.state_dict(),
'optimizerA_state_dict': optimizerA.state_dict(),
'optimizerB_state_dict': optimizerB.state_dict(),
...
}, PATH)
4.2 加载
modelA = TheModelAClass(*args, **kwargs)
modelB = TheModelBClass(*args, **kwargs)
optimizerA = TheOptimizerAClass(*args, **kwargs)
optimizerB = TheOptimizerBClass(*args, **kwargs)
checkpoint = torch.load(PATH)
modelA.load_state_dict(checkpoint['modelA_state_dict'])
modelB.load_state_dict(checkpoint['modelB_state_dict'])
optimizerA.load_state_dict(checkpoint['optimizerA_state_dict'])
optimizerB.load_state_dict(checkpoint['optimizerB_state_dict'])
modelA.train()
modelB.train()
# - or -
modelA.eval()
modelB.eval()
5. 迁移学习
有时候我们在训练一个新的模型B时可以用到已有的模型A的参数,比如迁移学习,这样就不用从头开始训了,模型可以很快的收敛,大大地提高了训练效率。
5.1 保存
torch.save(modelA.state_dict(), PATH)
5.2 加载
modelB = TheModelBClass(*args, **kwargs)
modelB.load_state_dict(torch.load(PATH), strict=False)
strict=False
:模型A和模型B是不完全一样的,模型B训练的时候可能只需要A中一部分值,其它不要的值就丢掉,设置strict=False
就是为了匹配需要的那部分值,忽略不需要的那部分值。
6 关于设备
如何在不同的设备,比如CPU或GPU上,保存与加载模型?
6.1 GPU保存 & CPU加载
模型在GPU上训练,但想把它加载到CPU上时,用这种方式
6.1.1 GPU保存
torch.save(model.state_dict(), PATH)
6.1.2 CPU加载
device = torch.device('cpu')
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH, map_location=device))
6.2 GPU保存 & GPU加载
模型在GPU上训练,想把它加载到GPU上时,用这种方式
6.2.1 GPU保存
torch.save(model.state_dict(), PATH)
6.2.2 GPU加载
device = torch.device("cuda")
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH))
model.to(device)
# Make sure to call input = input.to(device) on any input tensors that you feed to the model
6.3 CPU保存 & CPU加载
模型在CPU上训练,想把它加载到CPU上时,用这种方式
6.3.1 CPU保存
torch.save(model.state_dict(), PATH)
6.3.2 CPU加载
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH))
6.4 CPU保存 & GPU加载
模型在CPU上训练,想把它加载到GPU上时,用这种方式
6.4.1 CPU保存
torch.save(model.state_dict(), PATH)
6.4.2 GPU加载
device = torch.device("cuda")
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH, map_location="cuda:0")) # Choose whatever GPU device number you want
model.to(device)
# Make sure to call input = input.to(device) on any input tensors that you feed to the model
7 引用参考
https://pytorch.org/tutorials/beginner/saving_loading_models.html