文章用于记录一下本人对于较大数据集加载的问题的一些解决办法和思考。(比较口水话)

本文解决的问题适用领域:

  1. 数据集很大,无法一次加载到内存
  2. 纯文本类型的数据

我之前训练某个特定任务,习惯于把数据提前预处理为dataset保存起来,然后每次训练的时候直接加载这个文件。我这样做的目的是,方便调代码,使用很小量的数据先把代码调通,当出现一些小问题时不至于数据处理很久。

但是最近在使用自己的一些语料微调一个语言模型, 语料纯文本大小大概在8G左右,使用预处理脚本CPU满速处理后使用torch.save()保存成二进制文件,大概也8G左右。

这个时候问题来了,我使用了8张卡跑代码,这样的话就是8个进程,每个进程都要torch.load()加载这个二进制文件,256G的内存直接爆掉了,留下我一脸懵逼。

好,遇到了问题,开始动手解决。

内存不足?那多少够用呢?我先使用python窗口,单进程加载了这个文件,内存使用了60+G,算下来的话8卡是要500G的内存起步,不太现实,那么就从代码和数据的角度来解决吧。

使用原代码的处理方法?他处理后的存储sample应该也会占用60G内存,并且前期特别耗时,大概要四五小时,一旦发现后边的代码出了问题,相当于白处理了。

然后百度查相关的文档,看看有没有大佬解决这个问题。

有人说把数据分成若干份,每加载一个文件train起来,但是作者说这样做,内存依然会一直上升。有人说分块读取,但是这么大的数据量一直放在内存里,占用应该和前边一样也是特别大的,不一定够。

想来想去还是按照torch提供的dataset来处理,其实这是使用最广泛的方法,在getitem方法里写上处理逻辑就行了,返回相应的一条数据(sample)。但是现在的问题是,我读取了10行纯文本,每一行可能都会处理成若干条sample,如何优雅地使用getitem来返回合适的数据呢?

答案是:生成器。边生成边训练,没有某些数据常驻内存,内存占用就可以下去了。

关于生成器是啥自行搜索。废话不多说,上demo 代码:
【update】使用了生成器之后,dataloader估计不能多线程去取数据了,如果设置了num_workers > 1可能反而降低效率。

class GenDataset(torch.utils.data.Dataset):
    def __init__(self, data_path,):
     
        with open(data_path, "r") as f:
            self.data = f.readlines()
            # 如果这里都爆内存的话,
            # 看起来只能使用文件指针,在getitem里边逐行读取了
            # 得到的data是 list[str]
        random.shuffle(self.data)
        self.data_gen = self.get_data()
    
    def get_data(self):

        for doc in self.data:
            # 每个doc是一行文本,可能因为过长处理成为多个samples
            batch_samples = []
            # 巴拉巴拉
            # 经过处理得到了batch_samples

            while len(batch_samples) > 0:
                # 逐个把数据返回,每次只返回一条
                yield batch_samples.pop()

    def __len__(self):
        # 这里返回长度是用于tqdm进度条显示用的
        # 我这里乘以4是我之前预处理的时候看得到总量大概是文档数目的4倍
        # 你也可以设定一个很大的数字,当dataloader提取不到数据的时候就会停止
        return len(self.data  * 4) 

    def __getitem__(self, idx):
        # 每次使用next函数返回生成器生成的一条数据,此处的idx用不到了
        return next(self.data_gen)

Contact me : jianshu[AT]std.uestc.edu.cn