pytorch断点续传

  • 前言
  • 一、断点续传的作用?
  • 二、具体步骤
  • 1.保存断点
  • 2.加载断点
  • 三、其他需注意的地方

前言

当在模型训练过程中遇到下面的情况时我们就需要断点续传的技巧了

  • 本地训练到一半但由于有其他事情或事故必须主动或被动中断正在训练的模型等待后续再继续训练
  • 云端训练模型时由于平台的不稳定性导致训练中断,例如colab等。

一、断点续传的作用?

断点续传会在模型训练到一定时期时保存一次当前训练的数据,保存下的数据是以字典的形式序列化存储的,后续再通过pytorch反序列化读取即可。

二、具体步骤

1.保存断点

首先需要设置一个保存周期变量checkpoint_interval,具体的值可以自定义,值过小的话保存次数过多训练时间就会增强,过大就容易导致马上就要达到一个保存周期时训练中断,整个周期几乎是重新训练。具体代码如下:

checkpoint_interval = 3
    if (epoch + 1) % checkpoint_interval == 0:
        checkpoint = {"model_state_dict": model.state_dict(),
                      "optimizer_state_dict": optimizer.state_dict(),
                      "epoch": epoch}
        path_checkpoint = "./checkpoint_{}_epoch.pkl".format(epoch)
        torch.save(checkpoint, path_checkpoint)

这里存储了模型的参数,优化器的参数以及当前epoch。存储优化器的参数是因为优化器中存储了当前权重更新状态的相关参数。

2.加载断点

这里设置了一个start_epoch,它的值来自断点中存储的epoch值,代表当前要继续的epoch值。resume是个布尔值,代表是否继续训练,若要继续训练则手动设置为True。另外在读取断点时还设置了schedulerepoch,这是因为现在的scheduler的更新策略往往跟当前的epoch是有关系的,例如随着epoch的增加学习率的梯度越来越小。

start_epoch = 0
resume = False
path_checkpoint = "checkpointfirst_7_epoch.pkl"#断点路径
if resume:
	checkpoint = torch.load(path_checkpoint)#加载断点
    model.load_state_dict(checkpoint['model_state_dict'])#加载模型可学习参数
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])#加载优化器参数
    start_epoch = checkpoint['epoch']#设置开始的epoch
    scheduler.last_epoch = start_epoch#设置学习率的last_epoch

三、其他需注意的地方

  • 在我们使用预训练模型微调时,会先将预训练模型的前几层冻结着重训练后面自己添加的层。例如使用resnet101模型做微调时,先将前5层进行冻结,只训练最后一层全连接层。冻结时会使用下面的代码(这里以resnet101举例):
for child in model.children():
    ct += 1
    # print(ct,child)
    if ct < 5:
        for param in child.parameters():
            param.requires_grad = False

但当模型的参数存入断点文件时,是不会存储参数requires_grad 的。因此若要设置某些层不更新参数则需要在读取断点后执行相应设置。这样无论是重新训练还是继续训练某些层的requires_grad 都是满足需求的。

  • 即使在保存断点时这些参数是在GPU上,读取时仍然默认在CPU中,因此还是需要添加model = model.to(device)

无论是上述哪个点,都只要把相应操作的代码放在模型加载之后即可