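The dataset parameter specifies the source dataset to load. There are two kinds of dataset: “map-style dataset” and “iterable-style dataset”. A map-style dataset can be understood as “every sample can be fetched by an index key”; an iterable-style dataset can be understood as “samples are stored sequentially in a container, with no index key”. The sampler and shuffle parameters control the order in which the dataset is loaded, and both apply only to map-style datasets. With shuffle=True a RandomSampler is created automatically; with shuffle=False a SequentialSampler is created automatically. A sampler can also be supplied explicitly through the sampler parameter, in which case shuffle must be left unset.
In practice, the training stage commonly uses the combination “DataLoader, TensorDataset, RandomSampler”, and the inference stage uses the combination “DataLoader, TensorDataset, SequentialSampler”.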
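The training and inference combinations described above can be sketched as follows. This is a minimal illustration on made-up toy tensors; the variable names and the batch size are assumptions chosen for the example, not part of the original text.

```python
import torch
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler, TensorDataset

# Toy data (hypothetical): 10 samples with 3 features each, plus integer labels.
features = torch.arange(30, dtype=torch.float32).reshape(10, 3)
labels = torch.arange(10)
dataset = TensorDataset(features, labels)

# Training: RandomSampler visits every index once per epoch, in random order.
train_loader = DataLoader(dataset, sampler=RandomSampler(dataset), batch_size=4)

# Inference: SequentialSampler visits indices 0 .. len(dataset) - 1 in order.
eval_loader = DataLoader(dataset, sampler=SequentialSampler(dataset), batch_size=4)

train_labels = torch.cat([y for _, y in train_loader])
eval_labels = torch.cat([y for _, y in eval_loader])
```

Passing shuffle=True instead of sampler=RandomSampler(dataset) should give an equivalent training loader, since DataLoader builds the RandomSampler itself in that case.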
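A dataset class that implements the __getitem__() and __len__() protocols is called a “map-style dataset”. “torch.utils.data.Dataset” is the typical “map-style” type. To define a custom map-style dataset class, subclass torch.utils.data.Dataset and override __getitem__ (and, optionally, __len__). Part of the source code of Dataset is shown below: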
class Dataset(Generic[T_co]):
    r"""An abstract class representing a :class:`Dataset`.

    All datasets that represent a map from keys to data samples should subclass
    it. All subclasses should overwrite :meth:`__getitem__`, supporting fetching a
    data sample for a given key. Subclasses could also optionally overwrite
    :meth:`__len__`, which is expected to return the size of the dataset by many
    :class:`~torch.utils.data.Sampler` implementations and the default options
    of :class:`~torch.utils.data.DataLoader`.

    .. note::
      :class:`~torch.utils.data.DataLoader` by default constructs a index
      sampler that yields integral indices. To make it work with a map-style
      dataset with non-integral indices/keys, a custom sampler must be provided.
    """
    functions: Dict[str, Callable] = {}

    def __getitem__(self, index) -> T_co:
        raise NotImplementedError

    def __add__(self, other: 'Dataset[T_co]') -> 'ConcatDataset[T_co]':
        return ConcatDataset([self, other])

    # No `def __len__(self)` default?
    # See NOTE [ Lack of Default `__len__` in Python Abstract Base Classes ]
    # in pytorch/torch/utils/data/sampler.py

    def __getattr__(self, attribute_name):
        if attribute_name in Dataset.functions:
            function = functools.partial(Dataset.functions[attribute_name], self)
            return function
        else:
            raise AttributeError

    @classmethod
    def register_function(cls, function_name, function):
        cls.functions[function_name] = function
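As an illustration of subclassing Dataset, here is a minimal sketch of a custom map-style dataset. The class name SquaresDataset and its contents are hypothetical, made up purely for this example.

```python
from torch.utils.data import Dataset


class SquaresDataset(Dataset):
    """Hypothetical map-style dataset: sample i is the pair (i, i * i)."""

    def __init__(self, size: int) -> None:
        self.size = size

    def __getitem__(self, index: int):
        # Raise IndexError for out-of-range keys, as sequence-like objects do.
        if not 0 <= index < self.size:
            raise IndexError(index)
        return index, index * index

    def __len__(self) -> int:
        return self.size


squares = SquaresDataset(5)
# Dataset.__add__ returns a ConcatDataset that chains the two datasets.
combined = squares + SquaresDataset(3)
```

Defining __len__ is optional for Dataset itself, but samplers such as SequentialSampler and RandomSampler call len() on the dataset, so in practice it is almost always worth implementing.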
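Dataset is an abstract base class. If you do not want to define a custom map-style dataset, the usual choice is TensorDataset, a commonly used map-style dataset class. Its source code is as follows: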
class TensorDataset(Dataset[Tuple[Tensor, ...]]):
    r"""Dataset wrapping tensors.

    Each sample will be retrieved by indexing tensors along the first dimension.

    Args:
        *tensors (Tensor): tensors that have the same size of the first dimension.
    """
    tensors: Tuple[Tensor, ...]

    def __init__(self, *tensors: Tensor) -> None:
        assert all(tensors[0].size(0) == tensor.size(0) for tensor in tensors), "Size mismatch between tensors"
        self.tensors = tensors

    def __getitem__(self, index):
        """In essence: index each wrapped tensor along its first dimension with the given key, then pack all the values into a tuple."""
        return tuple(tensor[index] for tensor in self.tensors)

    def __len__(self):
        return self.tensors[0].size(0)
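A short usage sketch may make the indexing behaviour concrete. The tensors below are made-up values used only for illustration.

```python
import torch
from torch.utils.data import TensorDataset

x = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
y = torch.tensor([0, 1, 0])
ds = TensorDataset(x, y)

# Indexing returns a tuple with one entry per wrapped tensor: (x[1], y[1]).
sample = ds[1]

# The key is applied to each tensor directly, so a slice works as well.
first_two = ds[0:2]
```

Because the constructor asserts that all tensors share the same first dimension, wrapping tensors of different lengths fails at construction time rather than at iteration time.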
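A dataset class that implements the __iter__() protocol is an iterable-style dataset. This form is commonly used when samples cannot be fetched by index, or when fetching by index is expensive. To define a custom iterable-style dataset, subclass IterableDataset and implement __iter__. The source code of IterableDataset is shown below: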
class IterableDataset(Dataset[T_co], metaclass=_DataPipeMeta):
    r"""An iterable Dataset.

    All datasets that represent an iterable of data samples should subclass it.
    Such form of datasets is particularly useful when data come from a stream.

    All subclasses should overwrite :meth:`__iter__`, which would return an
    iterator of samples in this dataset.

    When a subclass is used with :class:`~torch.utils.data.DataLoader`, each
    item in the dataset will be yielded from the :class:`~torch.utils.data.DataLoader`
    iterator. When :attr:`num_workers > 0`, each worker process will have a
    different copy of the dataset object, so it is often desired to configure
    each copy independently to avoid having duplicate data returned from the
    workers. :func:`~torch.utils.data.get_worker_info`, when called in a worker
    process, returns information about the worker. It can be used in either the
    dataset's :meth:`__iter__` method or the :class:`~torch.utils.data.DataLoader` 's
    :attr:`worker_init_fn` option to modify each copy's behavior.
    """
    functions: Dict[str, Callable] = {}
    reduce_ex_hook: Optional[Callable] = None
    getstate_hook: Optional[Callable] = None

    def __iter__(self) -> Iterator[T_co]:
        raise NotImplementedError

    def __add__(self, other: Dataset[T_co]):
        return ChainDataset([self, other])

    # No `def __len__(self)` default? Subclasses raise `TypeError` when needed.
    # See NOTE [ Lack of Default `__len__` in Python Abstract Base Classes ]

    def __getattr__(self, attribute_name):
        if attribute_name in IterableDataset.functions:
            function = functools.partial(IterableDataset.functions[attribute_name], self)
            return function
        else:
            raise AttributeError

    def __getstate__(self):
        if IterableDataset.getstate_hook is not None:
            return IterableDataset.getstate_hook(self)
        return self.__dict__

    def __reduce_ex__(self, *args, **kwargs):
        if IterableDataset.reduce_ex_hook is not None:
            try:
                return IterableDataset.reduce_ex_hook(self)
            except NotImplementedError:
                pass
        return super().__reduce_ex__(*args, **kwargs)
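The docstring's advice about configuring each worker's copy can be sketched as below. RangeStream is a hypothetical class invented for this example; it splits an integer range across workers with get_worker_info, and falls back to the full range when loading happens in the main process.

```python
import math

from torch.utils.data import DataLoader, IterableDataset, get_worker_info


class RangeStream(IterableDataset):
    """Hypothetical iterable-style dataset streaming the integers in [start, end)."""

    def __init__(self, start: int, end: int) -> None:
        self.start = start
        self.end = end

    def __iter__(self):
        info = get_worker_info()
        if info is None:
            # Single-process loading: this copy yields the whole range.
            lo, hi = self.start, self.end
        else:
            # Multi-process loading: give each worker a disjoint slice.
            per_worker = math.ceil((self.end - self.start) / info.num_workers)
            lo = self.start + info.id * per_worker
            hi = min(lo + per_worker, self.end)
        return iter(range(lo, hi))


stream = RangeStream(0, 7)
# num_workers defaults to 0, so everything runs in the main process here.
batches = [batch.tolist() for batch in DataLoader(stream, batch_size=3)]
```

Without the sharding branch, running with num_workers > 0 would make every worker yield the full range, so each sample would appear once per worker.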
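For an iterable-style dataset, samples are emitted one after another, in the order determined by the dataset's __iter__ method. For a map-style dataset, a Sampler's __iter__ method generates the sequence of index keys into the source dataset. Note that Sampler subclasses can only be applied to map-style datasets. The source code is shown below: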
class Sampler(Generic[T_co]):
    r"""Base class for all Samplers.

    Every Sampler subclass has to provide an :meth:`__iter__` method, providing a
    way to iterate over indices of dataset elements, and a :meth:`__len__` method
    that returns the length of the returned iterators.

    .. note:: The :meth:`__len__` method isn't strictly required by
              :class:`~torch.utils.data.DataLoader`, but is expected in any
              calculation involving the length of a :class:`~torch.utils.data.DataLoader`.
    """

    def __init__(self, data_source: Optional[Sized]) -> None:
        pass

    def __iter__(self) -> Iterator[T_co]:
        raise NotImplementedError

class SequentialSampler(Sampler[int]):
    r"""Samples elements sequentially, always in the same order.

    Args:
        data_source (Dataset): dataset to sample from
    """
    data_source: Sized

    def __init__(self, data_source: Sized) -> None:
        self.data_source = data_source

    def __iter__(self) -> Iterator[int]:
        return iter(range(len(self.data_source)))

    def __len__(self) -> int:
        return len(self.data_source)

class RandomSampler(Sampler[int]):
    r"""Samples elements randomly. If without replacement, then sample from a shuffled dataset.
    If with replacement, then user can specify :attr:`num_samples` to draw.

    Args:
        data_source (Dataset): dataset to sample from
        replacement (bool): samples are drawn on-demand with replacement if ``True``, default=``False``
        num_samples (int): number of samples to draw, default=`len(dataset)`. This argument
            is supposed to be specified only when `replacement` is ``True``.
        generator (Generator): Generator used in sampling.
    """
    data_source: Sized
    replacement: bool

    def __init__(self, data_source: Sized, replacement: bool = False,
                 num_samples: Optional[int] = None, generator=None) -> None:
        self.data_source = data_source
        self.replacement = replacement
        self._num_samples = num_samples
        self.generator = generator

        if not isinstance(self.replacement, bool):
            raise TypeError("replacement should be a boolean value, but got "
                            "replacement={}".format(self.replacement))

        if self._num_samples is not None and not replacement:
            raise ValueError("With replacement=False, num_samples should not be specified, "
                             "since a random permute will be performed.")

        if not isinstance(self.num_samples, int) or self.num_samples <= 0:
            raise ValueError("num_samples should be a positive integer "
                             "value, but got num_samples={}".format(self.num_samples))

    @property
    def num_samples(self) -> int:
        # dataset size might change at runtime
        if self._num_samples is None:
            return len(self.data_source)
        return self._num_samples

    def __iter__(self) -> Iterator[int]:
        n = len(self.data_source)
        if self.generator is None:
            seed = int(torch.empty((), dtype=torch.int64).random_().item())
            generator = torch.Generator()
            generator.manual_seed(seed)
        else:
            generator = self.generator
        if self.replacement:  # sampling with replacement
            for _ in range(self.num_samples // 32):
                yield from torch.randint(high=n, size=(32,), dtype=torch.int64, generator=generator).tolist()
            yield from torch.randint(high=n, size=(self.num_samples % 32,), dtype=torch.int64, generator=generator).tolist()
        else:  # sampling without replacement
            yield from torch.randperm(n, generator=generator).tolist()

    def __len__(self) -> int:
        return self.num_samples
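The samplers above can be exercised on their own, without a DataLoader. The sketch below uses a plain list as a stand-in data source (an assumption for illustration; any object with __len__ should work) and seeded generators to make the random order repeatable.

```python
import torch
from torch.utils.data import RandomSampler, SequentialSampler

data = list(range(6))  # stand-in data source; samplers only need len(data)

# SequentialSampler yields 0, 1, ..., len(data) - 1.
sequential = list(SequentialSampler(data))

# Without replacement: a permutation of all indices.
# Two generators seeded identically produce the same permutation.
perm_a = list(RandomSampler(data, generator=torch.Generator().manual_seed(0)))
perm_b = list(RandomSampler(data, generator=torch.Generator().manual_seed(0)))

# With replacement: exactly num_samples indices, repeats allowed.
with_repl = list(
    RandomSampler(
        data,
        replacement=True,
        num_samples=40,
        generator=torch.Generator().manual_seed(0),
    )
)
```

Note that a sampler yields index keys only; the samples themselves are fetched later by indexing the dataset with those keys.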
For a map-style dataset, batching is equivalent to the following logic:
for indices in batch_sampler:
    yield collate_fn([dataset[i] for i in indices])
For an iterable-style dataset, batching is equivalent to the following logic:
dataset_iter = iter(dataset)
for indices in batch_sampler:
    yield collate_fn([next(dataset_iter) for _ in indices])
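To see the map-style logic in action, the sketch below wraps it in a generator function and feeds it index batches from BatchSampler. The helper iterate_batches and the use of "".join as a collate function are hypothetical, chosen only to keep the example small.

```python
from torch.utils.data import BatchSampler, SequentialSampler

dataset = ["a", "b", "c", "d", "e"]  # a plain list is a valid map-style dataset

# BatchSampler groups the indices produced by another sampler.
keep_last = list(BatchSampler(SequentialSampler(dataset), batch_size=2, drop_last=False))
drop_last = list(BatchSampler(SequentialSampler(dataset), batch_size=2, drop_last=True))


def iterate_batches(dataset, batch_sampler, collate_fn):
    """Hypothetical helper mirroring the map-style logic: fetch by index, then collate."""
    for indices in batch_sampler:
        yield collate_fn([dataset[i] for i in indices])


# Here the "collate" step simply concatenates the strings of one batch.
batches = list(iterate_batches(dataset, keep_last, collate_fn="".join))
```

With drop_last=True the trailing incomplete batch of indices is discarded, which is why drop_last has one batch fewer than keep_last.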
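- Adds a new batch dimension at the front (dimension 0).
- Converts NumPy arrays and Python int and float values to tensors.
- Preserves the data structure: for dict, list, and tuple objects the container structure is kept while the values are converted to tensors (a plain tuple comes back as a list, for backwards compatibility).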
def default_collate(batch):
    r"""Puts each data field into a tensor with outer dimension batch size"""

    elem = batch[0]
    elem_type = type(elem)
    if isinstance(elem, torch.Tensor):
        out = None
        if torch.utils.data.get_worker_info() is not None:
            # If we're in a background process, concatenate directly into a
            # shared memory tensor to avoid an extra copy
            numel = sum(x.numel() for x in batch)
            storage = elem.storage()._new_shared(numel)
            out = elem.new(storage)
        return torch.stack(batch, 0, out=out)
    elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \
            and elem_type.__name__ != 'string_':
        if elem_type.__name__ == 'ndarray' or elem_type.__name__ == 'memmap':
            # array of string classes and object
            if np_str_obj_array_pattern.search(elem.dtype.str) is not None:
                raise TypeError(default_collate_err_msg_format.format(elem.dtype))

            return default_collate([torch.as_tensor(b) for b in batch])
        elif elem.shape == ():  # scalars
            return torch.as_tensor(batch)
    elif isinstance(elem, float):
        return torch.tensor(batch, dtype=torch.float64)
    elif isinstance(elem, int):
        return torch.tensor(batch)
    elif isinstance(elem, string_classes):
        return batch
    elif isinstance(elem, collections.abc.Mapping):
        try:
            return elem_type({key: default_collate([d[key] for d in batch]) for key in elem})
        except TypeError:
            # The mapping type may not support `__init__(iterable)`.
            return {key: default_collate([d[key] for d in batch]) for key in elem}
    elif isinstance(elem, tuple) and hasattr(elem, '_fields'):  # namedtuple
        return elem_type(*(default_collate(samples) for samples in zip(*batch)))
    elif isinstance(elem, collections.abc.Sequence):
        # check to make sure that the elements in batch have consistent size
        it = iter(batch)
        elem_size = len(next(it))
        if not all(len(elem) == elem_size for elem in it):
            raise RuntimeError('each element in list of batch should be of equal size')
        transposed = list(zip(*batch))  # It may be accessed twice, so we use a list.

        if isinstance(elem, tuple):
            return [default_collate(samples) for samples in transposed]  # Backwards compatibility.
        else:
            try:
                return elem_type([default_collate(samples) for samples in transposed])
            except TypeError:
                # The sequence type may not support `__init__(iterable)` (e.g., `range`).
                return [default_collate(samples) for samples in transposed]

    raise TypeError(default_collate_err_msg_format.format(elem_type))
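A small experiment shows what the default collation does to structured samples. The sample dictionaries below are made up for illustration; the collation is triggered implicitly by a DataLoader that is given no collate_fn.

```python
import torch
from torch.utils.data import DataLoader

# Hypothetical samples: each one is a dict mixing a tensor, an int, a float and a string.
samples = [
    {"x": torch.tensor([1.0, 2.0]), "y": 0, "w": 0.5, "name": "a"},
    {"x": torch.tensor([3.0, 4.0]), "y": 1, "w": 1.5, "name": "b"},
]

# A plain list works as a map-style dataset: it has __getitem__ and __len__.
loader = DataLoader(samples, batch_size=2)
batch = next(iter(loader))

x_shape = tuple(batch["x"].shape)  # tensors are stacked along a new leading dimension
y_values = batch["y"].tolist()     # Python ints become one integer tensor
w_dtype = batch["w"].dtype         # Python floats become a float64 tensor
names = batch["name"]              # strings are passed through as a list
```

The dict structure survives collation: the result is again a dict with the same keys, and only the values have been merged across the batch.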
class DataLoader(Generic[T_co]):
    r"""
    Data loader. Combines a dataset and a sampler, and provides an iterable over
    the given dataset.

    The :class:`~torch.utils.data.DataLoader` supports both map-style and
    iterable-style datasets with single- or multi-process loading, customizing
    loading order and optional automatic batching (collation) and memory pinning.

    See :py:mod:`torch.utils.data` documentation page for more details.

    Args:
        dataset (Dataset): dataset from which to load the data.
        batch_size (int, optional): how many samples per batch to load
            (default: ``1``).
        shuffle (bool, optional): set to ``True`` to have the data reshuffled
            at every epoch (default: ``False``).
        sampler (Sampler or Iterable, optional): defines the strategy to draw
            samples from the dataset. Can be any ``Iterable`` with ``__len__``
            implemented. If specified, :attr:`shuffle` must not be specified.
        batch_sampler (Sampler or Iterable, optional): like :attr:`sampler`, but
            returns a batch of indices at a time. Mutually exclusive with
            :attr:`batch_size`, :attr:`shuffle`, :attr:`sampler`,
            and :attr:`drop_last`.
        num_workers (int, optional): how many subprocesses to use for data
            loading. ``0`` means that the data will be loaded in the main process.
            (default: ``0``)
        collate_fn (callable, optional): merges a list of samples to form a
            mini-batch of Tensor(s). Used when using batched loading from a
            map-style dataset.
        pin_memory (bool, optional): If ``True``, the data loader will copy Tensors
            into CUDA pinned memory before returning them. If your data elements
            are a custom type, or your :attr:`collate_fn` returns a batch that is a custom type,
            see the example below.
        drop_last (bool, optional): set to ``True`` to drop the last incomplete batch,
            if the dataset size is not divisible by the batch size. If ``False`` and
            the size of dataset is not divisible by the batch size, then the last batch
            will be smaller. (default: ``False``)
        timeout (numeric, optional): if positive, the timeout value for collecting a batch
            from workers. Should always be non-negative. (default: ``0``)
        worker_init_fn (callable, optional): If not ``None``, this will be called on each
            worker subprocess with the worker id (an int in ``[0, num_workers - 1]``) as
            input, after seeding and before data loading. (default: ``None``)
        generator (torch.Generator, optional): If not ``None``, this RNG will be used
            by RandomSampler to generate random indexes and multiprocessing to generate
            `base_seed` for workers. (default: ``None``)
        prefetch_factor (int, optional, keyword-only arg): Number of samples loaded
            in advance by each worker. ``2`` means there will be a total of
            2 * num_workers samples prefetched across all workers. (default: ``2``)
        persistent_workers (bool, optional): If ``True``, the data loader will not shutdown
            the worker processes after a dataset has been consumed once. This allows to
            maintain the workers `Dataset` instances alive. (default: ``False``)

    .. warning:: If the ``spawn`` start method is used, :attr:`worker_init_fn`
                 cannot be an unpicklable object, e.g., a lambda function. See
                 :ref:`multiprocessing-best-practices` on more details related
                 to multiprocessing in PyTorch.

    .. warning:: ``len(dataloader)`` heuristic is based on the length of the sampler used.
                 When :attr:`dataset` is an :class:`~torch.utils.data.IterableDataset`,
                 it instead returns an estimate based on ``len(dataset) / batch_size``, with proper
                 rounding depending on :attr:`drop_last`, regardless of multi-process loading
                 configurations. This represents the best guess PyTorch can make because PyTorch
                 trusts user :attr:`dataset` code in correctly handling multi-process
                 loading to avoid duplicate data.

                 However, if sharding results in multiple workers having incomplete last batches,
                 this estimate can still be inaccurate, because (1) an otherwise complete batch can
                 be broken into multiple ones and (2) more than one batch worth of samples can be
                 dropped when :attr:`drop_last` is set. Unfortunately, PyTorch can not detect such
                 cases in general.

                 See `Dataset Types`_ for more details on these two types of datasets and how
                 :class:`~torch.utils.data.IterableDataset` interacts with
                 `Multi-process data loading`_.

    .. warning:: See :ref:`reproducibility`, and :ref:`dataloader-workers-random-seed`, and
                 :ref:`data-loading-randomness` notes for random seed related questions.
    """
    dataset: Dataset[T_co]
    batch_size: Optional[int]
    num_workers: int
    pin_memory: bool
    drop_last: bool
    timeout: float
    sampler: Union[Sampler, Iterable]
    prefetch_factor: int
    _iterator : Optional['_BaseDataLoaderIter']
    __initialized = False

    def __init__(self, dataset: Dataset[T_co], batch_size: Optional[int] = 1,
                 shuffle: bool = False, sampler: Union[Sampler, Iterable, None] = None,
                 batch_sampler: Union[Sampler[Sequence], Iterable[Sequence], None] = None,
                 num_workers: int = 0, collate_fn: Optional[_collate_fn_t] = None,
                 pin_memory: bool = False, drop_last: bool = False,
                 timeout: float = 0, worker_init_fn: Optional[_worker_init_fn_t] = None,
                 multiprocessing_context=None, generator=None,
                 *, prefetch_factor: int = 2,
                 persistent_workers: bool = False):
        if num_workers < 0:
            raise ValueError('num_workers option should be non-negative; '
                             'use num_workers=0 to disable multiprocessing.')

        if timeout < 0:
            raise ValueError('timeout option should be non-negative')

        if num_workers == 0 and prefetch_factor != 2:
            raise ValueError('prefetch_factor option could only be specified in multiprocessing.'
                             'let num_workers > 0 to enable multiprocessing.')
        assert prefetch_factor > 0

        if persistent_workers and num_workers == 0:
            raise ValueError('persistent_workers option needs num_workers > 0')

        self.dataset = dataset
        self.num_workers = num_workers
        self.prefetch_factor = prefetch_factor
        self.pin_memory = pin_memory
        self.timeout = timeout
        self.worker_init_fn = worker_init_fn
        self.multiprocessing_context = multiprocessing_context

        # Arg-check dataset related before checking samplers because we want to
        # tell users that iterable-style datasets are incompatible with custom
        # samplers first, so that they don't learn that this combo doesn't work
        # after spending time fixing the custom sampler errors.
        if isinstance(dataset, IterableDataset):
            self._dataset_kind = _DatasetKind.Iterable
            # NOTE [ Custom Samplers and IterableDataset ]
            #
            # `IterableDataset` does not support custom `batch_sampler` or
            # `sampler` since the key is irrelevant (unless we support
            # generator-style dataset one day...).
            #
            # For `sampler`, we always create a dummy sampler. This is an
            # infinite sampler even when the dataset may have an implemented
            # finite `__len__` because in multi-process data loading, naive
            # settings will return duplicated data (which may be desired), and
            # thus using a sampler with length matching that of dataset will
            # cause data lost (you may have duplicates of the first couple
            # batches, but never see anything afterwards). Therefore,
            # `Iterabledataset` always uses an infinite sampler, an instance of
            # `_InfiniteConstantSampler` defined above.
            #
            # A custom `batch_sampler` essentially only controls the batch size.
            # However, it is unclear how useful it would be since an iterable-style
            # dataset can handle that within itself. Moreover, it is pointless
            # in multi-process data loading as the assignment order of batches
            # to workers is an implementation detail so users can not control
            # how to batchify each worker's iterable. Thus, we disable this
            # option. If this turns out to be useful in future, we can re-enable
            # this, and support custom samplers that specify the assignments to
            # specific workers.
            if shuffle is not False:
                raise ValueError(
                    "DataLoader with IterableDataset: expected unspecified "
                    "shuffle option, but got shuffle={}".format(shuffle))
            elif sampler is not None:
                # See NOTE [ Custom Samplers and IterableDataset ]
                raise ValueError(
                    "DataLoader with IterableDataset: expected unspecified "
                    "sampler option, but got sampler={}".format(sampler))
            elif batch_sampler is not None:
                # See NOTE [ Custom Samplers and IterableDataset ]
                raise ValueError(
                    "DataLoader with IterableDataset: expected unspecified "
                    "batch_sampler option, but got batch_sampler={}".format(batch_sampler))
        else:
            self._dataset_kind = _DatasetKind.Map

        if sampler is not None and shuffle:
            raise ValueError('sampler option is mutually exclusive with '
                             'shuffle')

        if batch_sampler is not None:
            # auto_collation with custom batch_sampler
            if batch_size != 1 or shuffle or sampler is not None or drop_last:
                raise ValueError('batch_sampler option is mutually exclusive '
                                 'with batch_size, shuffle, sampler, and '
                                 'drop_last')
            batch_size = None
            drop_last = False
        elif batch_size is None:
            # no auto_collation
            if drop_last:
                raise ValueError('batch_size=None option disables auto-batching '
                                 'and is mutually exclusive with drop_last')

        if sampler is None:  # give default samplers
            if self._dataset_kind == _DatasetKind.Iterable:
                # See NOTE [ Custom Samplers and IterableDataset ]
                sampler = _InfiniteConstantSampler()
            else:  # map-style
                if shuffle:
                    sampler = RandomSampler(dataset, generator=generator)
                else:
                    sampler = SequentialSampler(dataset)

        if batch_size is not None and batch_sampler is None:
            # auto_collation without custom batch_sampler
            batch_sampler = BatchSampler(sampler, batch_size, drop_last)

        self.batch_size = batch_size
        self.drop_last = drop_last
        self.sampler = sampler
        self.batch_sampler = batch_sampler
        self.generator = generator

        if collate_fn is None:
            if self._auto_collation:
                collate_fn = _utils.collate.default_collate
            else:
                collate_fn = _utils.collate.default_convert

        self.collate_fn = collate_fn
        self.persistent_workers = persistent_workers

        self.__initialized = True
        self._IterableDataset_len_called = None  # See NOTE [ IterableDataset and __len__ ]

        self._iterator = None

        torch.set_vital('Dataloader', 'enabled', 'True')  # type: ignore[attr-defined]

    def __iter__(self) -> '_BaseDataLoaderIter':
        # When using a single worker the returned iterator should be
        # created everytime to avoid reseting its state
        # However, in the case of a multiple workers iterator
        # the iterator is only created once in the lifetime of the
        # DataLoader object so that workers can be reused
        if self.persistent_workers and self.num_workers > 0:
            if self._iterator is None:
                self._iterator = self._get_iterator()
            else:
                self._iterator._reset(self)
            return self._iterator
        else:
            return self._get_iterator()

    def _get_iterator(self) -> '_BaseDataLoaderIter':
        if self.num_workers == 0:
            return _SingleProcessDataLoaderIter(self)
        else:
            return _MultiProcessingDataLoaderIter(self)

    @property
    def _index_sampler(self):
        # The actual sampler used for generating indices for `_DatasetFetcher`
        # (see _utils/fetch.py) to read data at each time. This would be
        # `.batch_sampler` if in auto-collation mode, and `.sampler` otherwise.
        # We can't change `.sampler` and `.batch_sampler` attributes for BC
        # reasons.
        if self._auto_collation:
            return self.batch_sampler
        else:
            return self.sampler

    def __len__(self) -> int:
        if self._dataset_kind == _DatasetKind.Iterable:
            # NOTE [ IterableDataset and __len__ ]
            #
            # For `IterableDataset`, `__len__` could be inaccurate when one naively
            # does multi-processing data loading, since the samples will be duplicated.
            # However, no real use case should be actually using that behavior, so
            # it should count as a user error. We should generally trust user
            # code to do the proper thing (e.g., configure each replica differently
            # in `__iter__`), and give us the correct `__len__` if they choose to
            # implement it (this will still throw if the dataset does not implement
            # a `__len__`).
            #
            # To provide a further warning, we track if `__len__` was called on the
            # `DataLoader`, save the returned value in `self._len_called`, and warn
            # if the iterator ends up yielding more than this number of samples.

            # Cannot statically verify that dataset is Sized
            length = self._IterableDataset_len_called = len(self.dataset)  # type: ignore[assignment, arg-type]
            if self.batch_size is not None:  # IterableDataset doesn't allow custom sampler or batch_sampler
                from math import ceil
                if self.drop_last:
                    length = length // self.batch_size
                else:
                    length = ceil(length / self.batch_size)
            return length
        else:
            return len(self._index_sampler)
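A few of the rules enforced by the constructor and by __len__ can be checked directly. The sketch below is a minimal illustration on a made-up ten-element dataset; the variable names are assumptions for the example.

```python
import torch
from torch.utils.data import DataLoader, SequentialSampler, TensorDataset

ds = TensorDataset(torch.arange(10))

# For a map-style dataset, len(loader) is the length of the batch sampler.
n_keep = len(DataLoader(ds, batch_size=4))                  # ceil(10 / 4)
n_drop = len(DataLoader(ds, batch_size=4, drop_last=True))  # 10 // 4

# sampler and shuffle=True are mutually exclusive, so this raises ValueError.
try:
    DataLoader(ds, sampler=SequentialSampler(ds), shuffle=True)
    conflict = None
except ValueError as err:
    conflict = str(err)

# batch_size=None disables automatic batching: samples come back one at a time.
unbatched = [x.item() for (x,) in DataLoader(ds, batch_size=None)]
```

In the unbatched case each item is still a one-element container because TensorDataset returns a tuple per sample; only the stacking into batches is skipped.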