代码复现需要考虑:
1.随机种子设置;
2.DataLoader设置;
3.CUDA算法随机性;
4.随机数生成器的调用细节;
5.多卡问题。
前半部分来自:https://zhuanlan.zhihu.com/p/532511514
我在刚接触pytorch的时候搜到了这个大佬的文章,当时最后天坑部分没有看的太明白,直到今天我也遇到的相同的问题,特来做一点点补充,方便大家理解。
随机种子
Pytorch复现的入门版本就是官方指南,需要设定好各种随机种子。
https://pytorch.org/docs/stable/notes/randomness.html
import random
import numpy as np
import torch
random.seed(0) # Python 随机种子
np.random.seed(0) # Numpy 随机种子
torch.manual_seed(0) # Pytorch 随机种子
torch.cuda.manual_seed(0) # CUDA 随机种子
torch.cuda.manual_seed_all(0) # CUDA 随机种子
2. Dataloader的并行
DataLoader启用多线程时(并行的线程数num_workers 大于1)也会出现随机现象,解决办法:
1. 禁用多线程:num_workers 设置为0。
2. 固定好worker的初始化方式,代码如下:
def seed_worker(worker_id):
worker_seed = torch.initial_seed() % 2 ** 32
numpy.random.seed(worker_seed)
random.seed(worker_seed)
g = torch.Generator()
g.manual_seed(0)
DataLoader(
train_dataset,
batch_size=batch_size,
num_workers=num_workers,
worker_init_fn=seed_worker,
generator=g,
)
不过我自己的代码中,没有对DataLoader进行特殊处理,代码也可以复现。
3.算法的随机性
有些并行算法带有随机性,比如LSTM或者注意力机制,RNN等。
尤其是使用 CUDA Toolkit 10.2 或更高版本构建 cuDNN 库时,cuBLAS 库中新的缓冲区管理和启发式算法会带来随机性。在默认配置中使用两种缓冲区大小(16 KB 和 4 MB)时会发生这种情况。
解决办法就是在代码头部设置环境变量:
os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'
如果是用到CNN的算法,同时要设置以下变量:
torch.backends.cudnn.benchmark = False # 限制cuDNN算法选择的不确定性
torch.backends.cudnn.deterministic=True # 固定cuDNN算法
设置完这些,基本99%的情况下都可以复现结果,如果无法复现,那就重启notebook 或者python。
4.随机数生成器细节
如果在一个for 循环内多次运行pytorch训练,就会出现随机性。以下常见方式均无效:强制每次train之前empty_cache/每次循环结束后手动del 变量并且用gc 回收/强制初始化模型的参数/强制设置set_rng_state/重启python文件和notebook;
知乎大佬的解决方案:1.在for 循环的内部设置随机种子。2.nn模型里面的dropout 在for 循环里面有随机性。最好不用或者显式的调用Dropout。
本文作者的实验结果:
需要注意随机数生成器的调用顺序!
例如:调用一次Dataloader就会影响下一个Dataloader的随机数生成。
解释:例如现在有两种模型的训练方式:
- 在train后面继续进行下一个Epoch的train。
- train后面进行val,再进行下一个Epoch的train。
这两种方式得到的训练结果从第二个Epoch开始就是不同的,且val前后模型的weights没变,那应该就是生成的随机数变了。
另外:在模型代码前加入如下语句,也会改变模型训练的结果,猜测还是调用到pytorch的随机数生成器了,导致后续代码的随机部分也随之改变。
from frame.loss import FocalLoss, LabelSmoothCrossEntropyLoss, TimeWeightedCELoss
这篇文章说的也是这个意思:https://zhuanlan.zhihu.com/p/352833875
5.多卡问题
如果多卡无法精准复现,可以尝试使用单卡。暂时没有找到解决方案,欢迎大佬们留言。
总结
以上就是Pytorch代码的复现终极指南,保险起见的话,先把能加的都加上,然后看能否复现。