# 文章目录

  • 0 项目场景
  • 1 模型参数
  • 1.1 保存
  • 1.2 加载
  • 2 整个模型
  • 2.1 保存
  • 2.2 加载
  • 3 断点续训
  • 3.1 保存
  • 3.2 加载
  • 4 多个模型
  • 4.1 保存
  • 4.2 加载
  • 5. 迁移学习
  • 5.1 保存
  • 5.2 加载
  • 6 关于设备
  • 6.1 GPU保存 & CPU加载
  • 6.1.1 GPU保存
  • 6.1.2 CPU加载
  • 6.2 GPU保存 & GPU加载
  • 6.2.1 GPU保存
  • 6.2.2 GPU加载
  • 6.3 CPU保存 & CPU加载
  • 6.3.1 CPU保存
  • 6.3.2 CPU加载
  • 6.4 CPU保存 & GPU加载
  • 6.4.1 CPU保存
  • 6.4.2 GPU加载
  • 7 引用参考

0 项目场景

pytorch训练完模型后,如何保存与加载?保存/加载有两种方式:一是保存/加载模型参数,二是保存/加载整个模型。

1 模型参数

保存/加载模型参数,官方推荐用这种方式,原因也给了:说这种方式对于日后恢复模型更具灵活性。

1.1 保存

torch.save(model.state_dict(), PATH)

state_dict里保存有模型的参数,PATH是保存路径,推荐.pt.pth作为文件拓展名。

1.2 加载

model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH))
model.eval()

TheModelClass是你定义的模型结构,PATH是保存路径。如果你的模型结构中含有dropoutbatch normalization层,在测试之前一定要加上model.eval()(如果没有可以不加),不然会产生错误的输出结果。

2 整个模型

保存/加载整个模型,官方不推荐这种方式,原因也给了:说是在其它项目中使用或重构后,代码可能会中断。

2.1 保存

torch.save(model, PATH)

2.2 加载

# Model class must be defined somewhere
model = torch.load(PATH)
model.eval()

定义模型结构的类必须在代码中出现。这种保存/加载模型的方式从语法上来说更加简洁和直观,但是将模型引入其它项目中使用可能出错,所以只在自己的项目中使用应该没有问题,想将模型引入其它项目中使用还是推荐第一种保存/加载方式。

3 断点续训

顾名思义就是从上次没训练完的地方继续训练,这对高效训练来说具有重要意义。

3.1 保存

torch.save({
    'epoch': epoch,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'loss': loss,
    ...
}, PATH)

其中model.state_dict()optimizer.state_dict()是必须要保存的,因为这两项会随着模型的训练而更新。epochloss等是作为记录用的,能让你直观的了解到目前训练到第几轮了,损失是多少。PATH是保存路径,建议以.tar为文件拓展名。

3.2 加载

model = TheModelClass(*args, **kwargs)
optimizer = TheOptimizerClass(*args, **kwargs)

checkpoint = torch.load(PATH)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']

model.train()
# - or -
model.eval()

首先初始化模型和优化器,然后加载之前保存的模型和优化器参数。接着你可以选择从上一次结束的地方继续训练或者直接测试。继续训练的话加上model.train(),测试模型的话加上model.eval(),如果模型结构中没有dropoutbatch normalization层,可以不加。

4 多个模型

有时候你可能需要将多个模型保存到一个文件中,比如GAN

4.1 保存

torch.save({
    'modelA_state_dict': modelA.state_dict(),
    'modelB_state_dict': modelB.state_dict(),
    'optimizerA_state_dict': optimizerA.state_dict(),
    'optimizerB_state_dict': optimizerB.state_dict(),
    ...
}, PATH)

4.2 加载

modelA = TheModelAClass(*args, **kwargs)
modelB = TheModelBClass(*args, **kwargs)
optimizerA = TheOptimizerAClass(*args, **kwargs)
optimizerB = TheOptimizerBClass(*args, **kwargs)

checkpoint = torch.load(PATH)
modelA.load_state_dict(checkpoint['modelA_state_dict'])
modelB.load_state_dict(checkpoint['modelB_state_dict'])
optimizerA.load_state_dict(checkpoint['optimizerA_state_dict'])
optimizerB.load_state_dict(checkpoint['optimizerB_state_dict'])

modelA.train()
modelB.train()
# - or -
modelA.eval()
modelB.eval()

5. 迁移学习

有时候我们在训练一个新的模型B时可以用到已有的模型A的参数,比如迁移学习,这样就不用从头开始训了,模型可以很快的收敛,大大地提高了训练效率。

5.1 保存

torch.save(modelA.state_dict(), PATH)

5.2 加载

modelB = TheModelBClass(*args, **kwargs)
modelB.load_state_dict(torch.load(PATH), strict=False)

strict=False:模型A和模型B是不完全一样的,模型B训练的时候可能只需要A中一部分值,其它不要的值就丢掉,设置strict=False就是为了匹配需要的那部分值,忽略不需要的那部分值。

6 关于设备

如何在不同的设备,比如CPU或GPU上,保存与加载模型?

6.1 GPU保存 & CPU加载

模型在GPU上训练,但想把它加载到CPU上时,用这种方式

6.1.1 GPU保存

torch.save(model.state_dict(), PATH)

6.1.2 CPU加载

device = torch.device('cpu')
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH, map_location=device))

6.2 GPU保存 & GPU加载

模型在GPU上训练,想把它加载到GPU上时,用这种方式

6.2.1 GPU保存

torch.save(model.state_dict(), PATH)

6.2.2 GPU加载

device = torch.device("cuda")
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH))
model.to(device)
# Make sure to call input = input.to(device) on any input tensors that you feed to the model

6.3 CPU保存 & CPU加载

模型在CPU上训练,想把它加载到CPU上时,用这种方式

6.3.1 CPU保存

torch.save(model.state_dict(), PATH)

6.3.2 CPU加载

model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH))

6.4 CPU保存 & GPU加载

模型在CPU上训练,想把它加载到GPU上时,用这种方式

6.4.1 CPU保存

torch.save(model.state_dict(), PATH)

6.4.2 GPU加载

device = torch.device("cuda")
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH, map_location="cuda:0"))  # Choose whatever GPU device number you want
model.to(device)
# Make sure to call input = input.to(device) on any input tensors that you feed to the model

7 引用参考

https://pytorch.org/tutorials/beginner/saving_loading_models.html