搭建神经网络以前还需要载入、构建数据。PyTorch 提供了非常方便的模块 torch.utils.data 来完成相关的任务。

1. 总览

构建一个可以被 PyTorch 利用的数据集分两步:

  1. 划分数据集、数据采样器(可选),构建 PyTorch 数据集(可选)
  2. 构建数据集的读取器

PyTorch 支持下面两种数据集:

  • map-style datasets(映射风格的数据集)
    需要重写 __getitem__()__len__() 两个方法。
  • iterable-style datasets(遍历风格的数据集)
    IterableDataset 的子类,重写了 __iter__() 方法。

2. 划分数据集

这里主要有三个函数:

  • torch.utils.data.ConcatDataset
  • torch.utils.data.Subset
  • torch.utils.data.random_split

这里仅仅讨论一下 torch.utils.data.random_split

torch.utils.data.random_split(dataset, lengths)

类似于 sklearn 里的 train_test_split,输入一个目标数据集和划分长度的列表。

# 
train_size = int(0.8 * len(full_dataset))
test_size = len(full_dataset) - train_size
train_dataset, test_dataset = torch.utils.data.random_split(full_dataset, [train_size, test_size])

3. 数据采样器 torch.utils.data.Sampler

对于 iterable-style datasets,数据读取的顺序完全取决于用户定义的读取顺序。

而对于 map-style datasets,使用 torch.utils.data.Sampler 来指定数据读取过程中的索引/值的顺序。可以在 DataLoader 中指定 shuffle 来指定顺序读取或随机读取, 也可以指定 sampler 来定制读取。

目前 PyTorch 中已有的 sampler 有:

torch.utils.data.SequentialSampler(data_source)
torch.utils.data.RandomSampler(data_source, replacement=False, num_samples=None)
torch.utils.data.SubsetRandomSampler(indices)
torch.utils.data.WeightedRandomSampler(weights, num_samples, replacement=True)
torch.utils.data.BatchSampler(sampler, batch_size, drop_last)

它们都继承于 torch.utils.data.Sampler。注意,这些类返回的都是索引值。这里以 torch.utils.data.SequentialSampler 为例。

>>> a = [1,5,78,9,68]
>>> b = torch.utils.data.SequentialSampler(a)
>>> for x in b:
...     print(x)
0
1
2
3
4

再来说一个比较有意思的函数,torch.utils.data.WeightedRandomSampler。它根据权重来随机选择数据。

# 
# 狗的图片被取出的概率是猫的概率的两倍
# 两类图片被取出的概率与weights的绝对大小无关,只和比值有关
>>> weights = [2 if label == 1 else 1 for data, label in dataset]
>>> print(weights)
[2, 2, 1, 1, 2, 1, 1, 2]
>>> sampler = WeightedRandomSampler(weights, num_samples=9, replacement=True)
>>> dataloader = DataLoader(dataset, batch_size=3, sampler=sampler)
...

可以看到,WeightedRandomSampler 生成的实例会作为 DataLoader 的一个变量。

另外,BatchSampler 与其他 Sampler 的主要区别是它需要将 Sampler 作为参数进行打包,进而每次迭代返回以 batch size 为大小的 index 列表。也就是说在后面的读取数据过程中使用的都是 batch sampler。

3. 数据集构造器 torch.utils.data.Dataset

使用 torch.utils.data.Dataset 需要将其继承,并重写 __getitem__()__len__() 两个方法:

class NewDataSet(DataSet):
	def __init__(self, x):
		...
	def __getitem__(self, index):
		return self.x[index]
	def __len__(self):
		return len(self.x)

这样就根据已有的数据集构建了一个 PyTorch 使用的数据集。

4. 数据集读取器 torch.utils.data.DataLoader

torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None, num_workers=0, collate_fn=None, pin_memory=False, drop_last=False, timeout=0, worker_init_fn=None, multiprocessing_context=None)

这个类的常用属性有:

  • dataset:目标数据集,应该是一个 torch.utils.data.Dataset 实例;
  • batch_size:每一批处理的个数;
  • shuffle:每一个 epoch 内数据是否打乱顺序;
  • sampler:指定采样器获得索引,与 shuffle 互斥;
  • batch_sampler:见采样器部分的 BatchSampler。与 batch_sizeshufflesamplerdrop_last 互斥。
  • num_workers:读取数据使用的线程数;
  • drop_last:如果最后一组数据不足 batch_size 个,是否保留这个 batch。

总结一下,大体流程可以是:torch.utils.data.Dataset 构造数据集 -> 划分数据集(可选) -> 给数据集添加索引(可选) -> 使用 torch.utils.data.DataLoader 分批喂给训练模型。生成索引既可以显式添加索引,也可以隐式在 torch.utils.data.DataLoader 自动完成。