Pytorch dataloader加载数据卡住的解决方案

  • 问题描述
  • Dataloader运行原理
  • 原因一: num_workers设置过大,导致内存爆炸
  • 原因二: num_workers设置过小,CPU对GPU供不应求
  • 原因三: 使用了 iterable-style datasets
  • 原因四:python对象在多进程中的COW
  • 原因五:自定义的Dataset中出现了“坏数据”

问题描述

通过Dataloader加载自定义的Dataset类时,每次运行到某个数据的时候程序会卡住,输出如下

pytorch dataloader 内存不断上升 pytorch dataloader爆内存_pytorch

Dataloader运行原理

Dataset 类一次只能检索一个样本及其标签,Dataloader 则可以让我们从 Dataset 中一次获取一个 minibatch 的样本,在每个 epoch 打乱样本数据以减少模型过拟合,还可以通过多进程加速处理。

Dataloader 开始工作时会一次性创建 num_workers 个进程,通过batch_sampler(一次性返回一个batch的索引)指定 batch 给 worker,这些 worker 就负责把分配到的batch加载进RAM;Dataloader 在每次迭代时就可以直接从 RAM 中寻找到需要的batch。

原因一: num_workers设置过大,导致内存爆炸

num_workers 的数量不要超过 cpu 的逻辑个数,可以通过指令 cat /proc/cpuinfo| grep “processor”| wc -l 查看自己 cpu 的逻辑个数。
也可以尝试在每轮迭代结束后使用 torch.cuda.empty_cache() 。

原因二: num_workers设置过小,CPU对GPU供不应求

将 Dataloader 中的 pin_memory 设为 True,注意此时 dataset 的数据类型只能是 tensors,maps 或包含 tensor 的可迭代对象。

原因三: 使用了 iterable-style datasets

通常我们自己定义的 Dataset 都是 Map-style datasets,继承了 torch.utils.data.Dataset 并实现了 _ _ getitem _ _() 和 _ _ len _ _ () 方法。

如果要处理的数据难以随机访问(如流数据)时,或需要动态 batch size时,通常使用继承了 torch.utils.data.IterableDataset 的 Iterable-style Dataset 并实现 _ _ iter _ _()方法。由于该方法返回的是包含样本的迭代器,搭配 Dataloader 和多进程(即 num_workers > 0)使用时,每个进程都会拥有一个 dataset 对象的拷贝,这就使大量重复的数据造成内存挤占。

可以在 _ _ iter _ _() 方法中使用 get_worker_info() 来单独管理每个进程。

原因四:python对象在多进程中的COW

多进程编程时,所有子进程只是创建了一个指向主进程的内存映射,只有子进程修改内容时,才会获得一份属于自己的私有拷贝,这就是 linux 内核的 copy-on-write 机制。值得注意的是, Python 内置的可变类型对象 (如 list,dict) 附带一个 reference counter(引用计数器)。可以用 sys.getrefcount(variable) 语句查看对象的引用个数。

如果 Dataset 中涉及很多此类的数据,那么在迭代访问的过程中会由于 python in-build object 计数器的更改触发 copy-on-write,每个 worker 进程都会随着迭代次数的增加占据越来越多的内存,导致内存溢出。因此我们可以把这些 objects 换成 non-refcounted 对象(如 pandas, numpy, pyarrow, torch tensor),避免不必要的拷贝。

  1. 修改 list 对象的数据类型, 以输入 Dataset 的文件列表 filenames 为例:
from torch.utils.data import Dataset
from cv2 import imread

class MyDataset(Dataset):
    def __init__(self, filenames_byte):
        self.filenames = filenames_byte

    def __getitem__(self, idx):
    	filename = str(self.filenames[idx], encoding='utf-8')
    	# or
		# filename = self.filenames[idx].astype(str)
        return imread(filename)
        
filenames = ['path1', 'path2']
filenames_byte = np.array(filenames).astype(np.string_)  # dtype='|S5'
myDataset = MyDataset(filenames_byte)
  1. 用 Manager 对象来管理 list
from torch.utils.data import Dataset
from multiprocessing import Manager
from cv2 import imread

class MyDataset(Dataset):
    def __init__(self, filenames):
        manager = Manager()
        self.filenames = manager.list(filenames)

    def __getitem__(self, idx):
        filename = self.filenames[idx]
        return imread(filename)
    
    def __len__(self):
        return len(self.data)

原因五:自定义的Dataset中出现了“坏数据”

如加载的数据损坏,数据类型管理混乱造成 for 循环时间过长,都有可能导致 dataloader “卡住”的结果,需要逐步输出排查异常。

参考链接:
【1】https://pytorch.org/docs/stable/data.html
【2】https://icode.best/i/44561339126672
【3】https://zhuanlan.zhihu.com/p/366595260