Pytorch dataloader加载数据卡住的解决方案
- 问题描述
- Dataloader运行原理
- 原因一: num_workers设置过大,导致内存爆炸
- 原因二: num_workers设置过小,CPU对GPU供不应求
- 原因三: 使用了 iterable-style datasets
- 原因四:python对象在多进程中的COW
- 原因五:自定义的Dataset中出现了“坏数据”
问题描述
通过Dataloader加载自定义的Dataset类时,每次运行到某个数据的时候程序会卡住,输出如下
Dataloader运行原理
Dataset 类一次只能检索一个样本及其标签,Dataloader 则可以让我们从 Dataset 中一次获取一个 minibatch 的样本,在每个 epoch 打乱样本数据以减少模型过拟合,还可以通过多进程加速处理。
Dataloader 开始工作时会一次性创建 num_workers 个进程,通过batch_sampler(一次性返回一个batch的索引)指定 batch 给 worker,这些 worker 就负责把分配到的batch加载进RAM;Dataloader 在每次迭代时就可以直接从 RAM 中寻找到需要的batch。
原因一: num_workers设置过大,导致内存爆炸
num_workers 的数量不要超过 cpu 的逻辑个数,可以通过指令 cat /proc/cpuinfo| grep “processor”| wc -l 查看自己 cpu 的逻辑个数。
也可以尝试在每轮迭代结束后使用 torch.cuda.empty_cache() 。
原因二: num_workers设置过小,CPU对GPU供不应求
将 Dataloader 中的 pin_memory 设为 True,注意此时 dataset 的数据类型只能是 tensors,maps 或包含 tensor 的可迭代对象。
原因三: 使用了 iterable-style datasets
通常我们自己定义的 Dataset 都是 Map-style datasets,继承了 torch.utils.data.Dataset 并实现了 _ _ getitem _ _() 和 _ _ len _ _ () 方法。
如果要处理的数据难以随机访问(如流数据)时,或需要动态 batch size时,通常使用继承了 torch.utils.data.IterableDataset 的 Iterable-style Dataset 并实现 _ _ iter _ _()方法。由于该方法返回的是包含样本的迭代器,搭配 Dataloader 和多进程(即 num_workers > 0)使用时,每个进程都会拥有一个 dataset 对象的拷贝,这就使大量重复的数据造成内存挤占。
可以在 _ _ iter _ _() 方法中使用 get_worker_info() 来单独管理每个进程。
原因四:python对象在多进程中的COW
多进程编程时,所有子进程只是创建了一个指向主进程的内存映射,只有子进程修改内容时,才会获得一份属于自己的私有拷贝,这就是 linux 内核的 copy-on-write 机制。值得注意的是, Python 内置的可变类型对象 (如 list,dict) 附带一个 reference counter(引用计数器)。可以用 sys.getrefcount(variable) 语句查看对象的引用个数。
如果 Dataset 中涉及很多此类的数据,那么在迭代访问的过程中会由于 python in-build object 计数器的更改触发 copy-on-write,每个 worker 进程都会随着迭代次数的增加占据越来越多的内存,导致内存溢出。因此我们可以把这些 objects 换成 non-refcounted 对象(如 pandas, numpy, pyarrow, torch tensor),避免不必要的拷贝。
- 修改 list 对象的数据类型, 以输入 Dataset 的文件列表 filenames 为例:
from torch.utils.data import Dataset
from cv2 import imread
class MyDataset(Dataset):
def __init__(self, filenames_byte):
self.filenames = filenames_byte
def __getitem__(self, idx):
filename = str(self.filenames[idx], encoding='utf-8')
# or
# filename = self.filenames[idx].astype(str)
return imread(filename)
filenames = ['path1', 'path2']
filenames_byte = np.array(filenames).astype(np.string_) # dtype='|S5'
myDataset = MyDataset(filenames_byte)
- 用 Manager 对象来管理 list
from torch.utils.data import Dataset
from multiprocessing import Manager
from cv2 import imread
class MyDataset(Dataset):
def __init__(self, filenames):
manager = Manager()
self.filenames = manager.list(filenames)
def __getitem__(self, idx):
filename = self.filenames[idx]
return imread(filename)
def __len__(self):
return len(self.data)
原因五:自定义的Dataset中出现了“坏数据”
如加载的数据损坏,数据类型管理混乱造成 for 循环时间过长,都有可能导致 dataloader “卡住”的结果,需要逐步输出排查异常。
参考链接:
【1】https://pytorch.org/docs/stable/data.html
【2】https://icode.best/i/44561339126672
【3】https://zhuanlan.zhihu.com/p/366595260