pytorch的数据读取
pytorch数据读取的核心是torch.utils.data.DataLoader
类,具有以下特性:
- 支持map-style datasets和iterable-style datasets
- 自定义数据读取顺序
- 自动批量化
- 单线程/多线程读取
- 自动内存锁页
1. 整体流程
DataLoader
的参数如下,主要涉及DataSet
、sample
、collate_fn
、pin_memory
。
DataLoader(dataset, batch_size=1, shuffle=False, sampler=None,
batch_sampler=None, num_workers=0, collate_fn=None,
pin_memory=False, drop_last=False, timeout=0,
worker_init_fn=None, *, prefetch_factor=2,
persistent_workers=False)
pytorch读取数据的整体处理流程如下图:
无论是map-style还是iterable-style dataset整体流程都是
- 先用采样器采样,采样一次得到一个样本的索引(iterable-style dataset无法通过索引取值,所以使用的是一个虚假的采样器,每次生成None)。
- 使用batch_sampler生成长度为batch_size的索引列表(实际是使用sampler采样batch_size次)。
- 使用collate_fn将batch_size长度的列表整理成batch样本(tensor格式)。
2. DataSet Types
pytorch支持两种类型的数据集map-style dataset
和iterable-style dataset
。
Map-style datasets
字典型数据集是指实现了__getitem__()
和__len__()
协议,表示从索引到数据样本的映射。
可以继承抽象类torch.utils.data.DataSet
,并重写__getitem__()
和__len__()
方法。
Note: DataLoader默认构造的采样器返回的都是整数索引,如果dataset的索引不是整数,需要自定义采样器。
Iterable-style datasets
可迭代型数据集是torch.utils.data.IterableDataset
的子类,需要实现__iter__()
协议,表示对数据样本的一轮迭代。
iterable-style dataset类似python的可迭代对象。使用iter()方法会得到一个迭代器,每次调用next()会得到下一个样本。无法使用索引取元素。所以就不能使用采样器采样得到索引,在使用索引得到样本。dataloader的实现中,对于可迭代类型的数据集会使用一个虚假采样器InfiniteConstantSampler。每次调用都返回None。
class _InfiniteConstantSampler(Sampler):
r"""Analogous to ``itertools.repeat(None, None)``.
Used as sampler for :class:`~torch.utils.data.IterableDataset`.
Args:
data_source (Dataset): dataset to sample from
"""
def __init__(self):
super(_InfiniteConstantSampler, self).__init__(None)
def __iter__(self):
while True:
yield None
这个采样器的目的就是为了在batch_sample时控制采样的次数。
3. Sampler
对于IterableDataset来说,数据读取的顺利是由用户定义迭代决定的。回想下python的迭代器,只能通过循环调用next()方法,依次拿到下一个样本。不能改变原有的次序。
对may-style Dataset来说,sampler用来在数据读取时,指定样本索引的顺序。可以指定DataLoader的shuffle参数来指导顺序读取还是乱序读取。如果shuffle=True,会自动构造一个RandomSampler采样器,shuffle=False,会构造SequentialSample采样器。也可以用户自定义一个采样器并使用sample参数指定。自定义采样器每次返回下一个采样的索引。注意采样器返回的都是样本索引,不是样本本身。需要根据索引得到样本。
batch_sampler
如果一个采样器sampler一次返回批量大小的索引列表,那么就叫做batch_sampler。如果指定batch_size和drop_last参数,就会基于sampler(采样器)自动构造一个batch_sampler(批量采样器)。map-style 数据集也可以使用batch_sampler参数指定自定义的批量采样器。
4. collate_fn
collate_fn从字面上看就是整理函数,是对batch_sampler批量采样器返回的长度是batch_size的索引列表进行加工,处理成模型可以使用的batch_size大小的tensor。
这里需要注意采样器sampler/batch_sampler返回的都是样本索引,collate_fn的输入是批量大小的样本列表。所以在传给collate_fn前要根据索引取样本。
如果是may-style数据集,这个操作大概等价于:
for indices in batch_sampler:
yield collate_fn([dataset[i] for i in indices])
如果是iterable-style数据集,大概等价于:
dataset_iter = iter(dataset)
for indices in batch_sampler:
yield collate_fn([next(dataset_iter) for _ in indices])
可以看到iterable-style的索引其实是没用的,只是用来控制采样的个数。同时,发现collate_fn函数接收的参数是样本列表。collate_fn的一个重要功能就是把这个列表加工成pytorch支持的数据格式tensor。通过看pytorch的源码,如果不指定collate_fn,会使用默认的collate_fn函数,这个函数的功能就是将各种类型的数据转化成tensor。也可以自定义collare_fn函数,然后通过collate_fn参数指定,在自定义的函数中增加需要的操作。例如,将每个样本padding到当前batch的最大样本长度。任何想要对批量数据进行的操作都要定义在这个函数中。
pytroch的默认collate_fn实现:
def default_collate(batch):
r"""Puts each data field into a tensor with outer dimension batch size"""
elem = batch[0]
elem_type = type(elem)
if isinstance(elem, torch.Tensor):
out = None
if torch.utils.data.get_worker_info() is not None:
# If we're in a background process, concatenate directly into a
# shared memory tensor to avoid an extra copy
numel = sum([x.numel() for x in batch])
storage = elem.storage()._new_shared(numel)
out = elem.new(storage)
return torch.stack(batch, 0, out=out)
elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \
and elem_type.__name__ != 'string_':
elem = batch[0]
if elem_type.__name__ == 'ndarray':
# array of string classes and object
if np_str_obj_array_pattern.search(elem.dtype.str) is not None:
raise TypeError(default_collate_err_msg_format.format(elem.dtype))
return default_collate([torch.as_tensor(b) for b in batch])
elif elem.shape == (): # scalars
return torch.as_tensor(batch)
elif isinstance(elem, float):
return torch.tensor(batch, dtype=torch.float64)
elif isinstance(elem, int_classes):
return torch.tensor(batch)
elif isinstance(elem, string_classes):
return batch
elif isinstance(elem, container_abcs.Mapping):
return {key: default_collate([d[key] for d in batch]) for key in elem}
elif isinstance(elem, tuple) and hasattr(elem, '_fields'): # namedtuple
return elem_type(*(default_collate(samples) for samples in zip(*batch)))
elif isinstance(elem, container_abcs.Sequence):
transposed = zip(*batch)
return [default_collate(samples) for samples in transposed]
raise TypeError(default_collate_err_msg_format.format(elem_type))
5. pin_memory
pin_memory是指锁页内存,什么是锁页内存?
内存分为锁页和不锁页,锁页内存存的内容在任何情况下都不会与机器的虚拟内存(虚拟内存就是硬盘)进行交换。不锁页内存在主机内存不足时,数据会存放到虚拟内存。
如果pin_memory=True,那么生成的数据都会放在锁页内存上,此时将tensor拷贝到GPU的显存会更快。
6. 自动批量化/非批量化
dataloader默认返回批量的样本(batch_size默认为1)。当参数batch_size和batch_sample均为None时,会关闭自动批量化操作。此时会将采样的单个样本传给collate_fn函数。
参考
[1]pytorch官方文档 TORCH.UTILS.DATA部分