代码复现需要考虑:

1.随机种子设置;

2.DataLoader设置;

3.CUDA算法随机性;

4.随机数生成器的调用细节;

5.多卡问题。

前半部分来自:https://zhuanlan.zhihu.com/p/532511514

我在刚接触pytorch的时候搜到了这个大佬的文章,当时最后天坑部分没有看的太明白,直到今天我也遇到的相同的问题,特来做一点点补充,方便大家理解。

随机种子

Pytorch复现的入门版本就是官方指南,需要设定好各种随机种子。

https://pytorch.org/docs/stable/notes/randomness.html

import random
import numpy as np
import torch

random.seed(0)  # Python 随机种子
np.random.seed(0)  # Numpy 随机种子
torch.manual_seed(0)  # Pytorch 随机种子
torch.cuda.manual_seed(0)  # CUDA 随机种子
torch.cuda.manual_seed_all(0)  # CUDA 随机种子
2. Dataloader的并行

DataLoader启用多线程时(并行的线程数num_workers 大于1)也会出现随机现象,解决办法:

1. 禁用多线程:num_workers 设置为0。

2. 固定好worker的初始化方式,代码如下:

def seed_worker(worker_id):
    worker_seed = torch.initial_seed() % 2 ** 32
    numpy.random.seed(worker_seed)
    random.seed(worker_seed)

    g = torch.Generator()
    g.manual_seed(0)

    DataLoader(
        train_dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        worker_init_fn=seed_worker,
        generator=g,
    )

不过我自己的代码中,没有对DataLoader进行特殊处理,代码也可以复现。

3.算法的随机性

有些并行算法带有随机性,比如LSTM或者注意力机制,RNN等。

尤其是使用 CUDA Toolkit 10.2 或更高版本构建 cuDNN 库时,cuBLAS 库中新的缓冲区管理和启发式算法会带来随机性。在默认配置中使用两种缓冲区大小(16 KB 和 4 MB)时会发生这种情况。

解决办法就是在代码头部设置环境变量:

os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'

如果是用到CNN的算法,同时要设置以下变量:

torch.backends.cudnn.benchmark = False  # 限制cuDNN算法选择的不确定性
torch.backends.cudnn.deterministic=True  # 固定cuDNN算法

设置完这些,基本99%的情况下都可以复现结果,如果无法复现,那就重启notebook 或者python。

4.随机数生成器细节

如果在一个for 循环内多次运行pytorch训练,就会出现随机性。以下常见方式均无效:强制每次train之前empty_cache/每次循环结束后手动del 变量并且用gc 回收/强制初始化模型的参数/强制设置set_rng_state/重启python文件和notebook;

知乎大佬的解决方案:1.在for 循环的内部设置随机种子。2.nn模型里面的dropout 在for 循环里面有随机性。最好不用或者显式的调用Dropout。

本文作者的实验结果:

需要注意随机数生成器的调用顺序!

例如:调用一次Dataloader就会影响下一个Dataloader的随机数生成。

解释:例如现在有两种模型的训练方式:

  1. 在train后面继续进行下一个Epoch的train。
  2. train后面进行val,再进行下一个Epoch的train。

这两种方式得到的训练结果从第二个Epoch开始就是不同的,且val前后模型的weights没变,那应该就是生成的随机数变了。

另外:在模型代码前加入如下语句,也会改变模型训练的结果,猜测还是调用到pytorch的随机数生成器了,导致后续代码的随机部分也随之改变。

from frame.loss import FocalLoss, LabelSmoothCrossEntropyLoss, TimeWeightedCELoss

这篇文章说的也是这个意思:https://zhuanlan.zhihu.com/p/352833875

5.多卡问题

如果多卡无法精准复现,可以尝试使用单卡。暂时没有找到解决方案,欢迎大佬们留言。

总结

以上就是Pytorch代码的复现终极指南,保险起见的话,先把能加的都加上,然后看能否复现。