搭建神经网络以前还需要载入、构建数据。PyTorch 提供了非常方便的模块 torch.utils.data
来完成相关的任务。
1. 总览
构建一个可以被 PyTorch 利用的数据集分两步:
- 划分数据集、数据采样器(可选),构建 PyTorch 数据集(可选)
- 构建数据集的读取器
PyTorch 支持下面两种数据集:
- map-style datasets(映射风格的数据集)
需要重写__getitem__()
和__len__()
两个方法。 - iterable-style datasets(遍历风格的数据集)
是IterableDataset
的子类,重写了__iter__()
方法。
2. 划分数据集
这里主要有三个函数:
torch.utils.data.ConcatDataset
torch.utils.data.Subset
torch.utils.data.random_split
这里仅仅讨论一下 torch.utils.data.random_split
。
torch.utils.data.random_split(dataset, lengths)
类似于 sklearn
里的 train_test_split
,输入一个目标数据集和划分长度的列表。
#
train_size = int(0.8 * len(full_dataset))
test_size = len(full_dataset) - train_size
train_dataset, test_dataset = torch.utils.data.random_split(full_dataset, [train_size, test_size])
3. 数据采样器 torch.utils.data.Sampler
对于 iterable-style datasets,数据读取的顺序完全取决于用户定义的读取顺序。
而对于 map-style datasets,使用 torch.utils.data.Sampler
来指定数据读取过程中的索引/值的顺序。可以在 DataLoader
中指定 shuffle
来指定顺序读取或随机读取, 也可以指定 sampler
来定制读取。
目前 PyTorch 中已有的 sampler 有:
torch.utils.data.SequentialSampler(data_source)
torch.utils.data.RandomSampler(data_source, replacement=False, num_samples=None)
torch.utils.data.SubsetRandomSampler(indices)
torch.utils.data.WeightedRandomSampler(weights, num_samples, replacement=True)
torch.utils.data.BatchSampler(sampler, batch_size, drop_last)
它们都继承于 torch.utils.data.Sampler
。注意,这些类返回的都是索引值。这里以 torch.utils.data.SequentialSampler
为例。
>>> a = [1,5,78,9,68]
>>> b = torch.utils.data.SequentialSampler(a)
>>> for x in b:
... print(x)
0
1
2
3
4
再来说一个比较有意思的函数,torch.utils.data.WeightedRandomSampler
。它根据权重来随机选择数据。
#
# 狗的图片被取出的概率是猫的概率的两倍
# 两类图片被取出的概率与weights的绝对大小无关,只和比值有关
>>> weights = [2 if label == 1 else 1 for data, label in dataset]
>>> print(weights)
[2, 2, 1, 1, 2, 1, 1, 2]
>>> sampler = WeightedRandomSampler(weights, num_samples=9, replacement=True)
>>> dataloader = DataLoader(dataset, batch_size=3, sampler=sampler)
...
可以看到,WeightedRandomSampler
生成的实例会作为 DataLoader
的一个变量。
另外,BatchSampler
与其他 Sampler 的主要区别是它需要将 Sampler 作为参数进行打包,进而每次迭代返回以 batch size 为大小的 index 列表。也就是说在后面的读取数据过程中使用的都是 batch sampler。
3. 数据集构造器 torch.utils.data.Dataset
使用 torch.utils.data.Dataset
需要将其继承,并重写 __getitem__()
和 __len__()
两个方法:
class NewDataSet(DataSet):
def __init__(self, x):
...
def __getitem__(self, index):
return self.x[index]
def __len__(self):
return len(self.x)
这样就根据已有的数据集构建了一个 PyTorch 使用的数据集。
4. 数据集读取器 torch.utils.data.DataLoader
torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None, num_workers=0, collate_fn=None, pin_memory=False, drop_last=False, timeout=0, worker_init_fn=None, multiprocessing_context=None)
这个类的常用属性有:
-
dataset
:目标数据集,应该是一个torch.utils.data.Dataset
实例; -
batch_size
:每一批处理的个数; -
shuffle
:每一个 epoch 内数据是否打乱顺序; -
sampler
:指定采样器获得索引,与shuffle
互斥; -
batch_sampler
:见采样器部分的BatchSampler
。与batch_size
,shuffle
,sampler
和drop_last
互斥。 -
num_workers
:读取数据使用的线程数; -
drop_last
:如果最后一组数据不足 batch_size 个,是否保留这个 batch。
总结一下,大体流程可以是:torch.utils.data.Dataset
构造数据集 -> 划分数据集(可选) -> 给数据集添加索引(可选) -> 使用 torch.utils.data.DataLoader
分批喂给训练模型。生成索引既可以显式添加索引,也可以隐式在 torch.utils.data.DataLoader
自动完成。