torch.utils.data.DataLoader

DataLoader(dataset, batch_size=1, shuffle=False, sampler=None,
           batch_sampler=None, num_workers=0, collate_fn=None,
           pin_memory=False, drop_last=False, timeout=0,
           worker_init_fn=None)

官方文档的链接 它是PyTorch中数据读取的一个重要接口,该接口的目的:将自定义的Dataset根据batch size大小、是否shuffle等封装成一个Batch Size大小的Tensor,用于后面的训练。

DataLoader的所有参数如上所示,接下来依次对每个参数介绍

Dataset Types

DataLoader构造函数最重要的参数是dataset,它表示要从中加载数据的dataset对象。PyTorch支持两种不同类型的数据集:

  • map-style datasets
    这种类型的数据集可以实现 __getitem__() and __len__()方法,并且表示为从indices/keys到数据样本的映射。例如,当使用dataset[idx]访问这样的数据集时,可以从磁盘上的文件夹中读取idx-th图像及其对应的标签。
  • iterable-style datasets.
    该类数据集是IterableDataset子类的一个实例,可以实现__iter__()方法,表示对数据样本进行迭代。这种类型的数据集特别适合这样的情况,即随机读取非常昂贵,甚至是不可能的,并且batch size取决于获取的数据。例如,当iter(dataset)时,可以返回从数据库、远程服务器甚至实时生成的日志读取的数据流。

batch_size (python:int, optional)

每批加载多少个样本(默认值:1)。

shuffle (bool, optional)

设置为True,以便在每个epoch重新洗牌数据(默认为False)。

sampler (Sampler, optional)

定义从数据集提取样本的策略。如果指定,则shuffle必须为False。

batch_sampler (Sampler, optional)

类似于sampler,但一次返回一批索引。与batch_size, shuffle, sampler, and drop_last.相互排斥

num_workers (python:int, optional)

要使用多少子进程来加载数据。0表示将在主进程中加载数据。(默认值:0)。

collate_fn (callable, optional)

将一个样本列表合并成一个张量的小批量。当使用批量加载从map样式的数据集时使用。

pin_memory (bool, optional)

如果是,数据加载器将把张量复制到CUDA固定内存中,然后再返回它们。

drop_last (bool, optional)

如果数据集大小不能被批处理大小整除,则将其设置为True以删除最后一个未完成的批处理。如果为False且数据集的大小不能被批处理大小整除,则最后一批数据将更小。(默认值:False)

timeout (numeric, optional)

超时,默认为0。是用来设置数据读取的超时时间的,超过这个时间还没读取到数据的话就会报错。 所以,数值必须大于等于0。

worker_init_fn (callable, optional)