RuntimeError: each element in list of batch should be of equal size
I defined my own Dataset class to return the fields I needed, and hit the following error:

RuntimeError: each element in list of batch should be of equal size

Searching online, the most common suggestion is to simply set batch_size to 1, and sure enough, that clears the error. But I am trying to train a model, not just silence an error message; with batch_size set to 1, training is hardly practical, so I decided to dig into the actual cause.

Original Traceback (most recent call last):
  File "/home/cv/anaconda3/lib/python3.8/site-packages/torch/utils/data/_utils/worker.py", line 202, in _worker_loop
    data = fetcher.fetch(index)
  File "/home/cv/anaconda3/lib/python3.8/site-packages/torch/utils/data/_utils/fetch.py", line 47, in fetch
    return self.collate_fn(data)
  File "/home/cv/anaconda3/lib/python3.8/site-packages/torch/utils/data/_utils/collate.py", line 83, in default_collate
    return [default_collate(samples) for samples in transposed]
  File "/home/cv/anaconda3/lib/python3.8/site-packages/torch/utils/data/_utils/collate.py", line 83, in <listcomp>
    return [default_collate(samples) for samples in transposed]
  File "/home/cv/anaconda3/lib/python3.8/site-packages/torch/utils/data/_utils/collate.py", line 83, in default_collate
    return [default_collate(samples) for samples in transposed]
  File "/home/cv/anaconda3/lib/python3.8/site-packages/torch/utils/data/_utils/collate.py", line 83, in <listcomp>
    return [default_collate(samples) for samples in transposed]
  File "/home/cv/anaconda3/lib/python3.8/site-packages/torch/utils/data/_utils/collate.py", line 81, in default_collate
    raise RuntimeError('each element in list of batch should be of equal size')

From the traceback, the error originates in collate.py, specifically in the default_collate() function. This function is the DataLoader's default method for assembling a batch: if you do not pass a collate_fn argument when constructing the DataLoader, the code below is what gets called. If you see the error above, it was raised by the RuntimeError on the fourth line from the end of this function:

def default_collate(batch):
    r"""Puts each data field into a tensor with outer dimension batch size"""

    elem = batch[0]
    elem_type = type(elem)
    if isinstance(elem, torch.Tensor):
        out = None
        if torch.utils.data.get_worker_info() is not None:
            # If we're in a background process, concatenate directly into a
            # shared memory tensor to avoid an extra copy
            numel = sum([x.numel() for x in batch])
            storage = elem.storage()._new_shared(numel)
            out = elem.new(storage)
        return torch.stack(batch, 0, out=out)
    elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \
            and elem_type.__name__ != 'string_':
        if elem_type.__name__ == 'ndarray' or elem_type.__name__ == 'memmap':
            # array of string classes and object
            if np_str_obj_array_pattern.search(elem.dtype.str) is not None:
                raise TypeError(default_collate_err_msg_format.format(elem.dtype))

            return default_collate([torch.as_tensor(b) for b in batch])
        elif elem.shape == ():  # scalars
            return torch.as_tensor(batch)
    elif isinstance(elem, float):
        return torch.tensor(batch, dtype=torch.float64)
    elif isinstance(elem, int_classes):
        return torch.tensor(batch)
    elif isinstance(elem, string_classes):
        return batch
    elif isinstance(elem, container_abcs.Mapping):
        return {key: default_collate([d[key] for d in batch]) for key in elem}
    elif isinstance(elem, tuple) and hasattr(elem, '_fields'):  # namedtuple
        return elem_type(*(default_collate(samples) for samples in zip(*batch)))
    elif isinstance(elem, container_abcs.Sequence):
        # check to make sure that the elements in batch have consistent size
        it = iter(batch)
        elem_size = len(next(it))
        if not all(len(elem) == elem_size for elem in it):
            raise RuntimeError('each element in list of batch should be of equal size')
        transposed = zip(*batch)
        return [default_collate(samples) for samples in transposed]

    raise TypeError(default_collate_err_msg_format.format(elem_type))

This function receives a batch: a sequence whose elements are whatever your Dataset class's __getitem__() method returns, with length equal to your batch_size. The iterable the DataLoader ultimately yields, however, has each field assembled by stacking that field from all batch_size samples. So on the first call, execution reaches the second-to-last line, return [default_collate(samples) for samples in transposed]: zip(*batch) transposes the batch into per-field groups, and default_collate() is called recursively on each group. On that second call, if the field's element type matches one of the branches listed above, the field is collated and the dataset contents are returned correctly.
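The transposition step can be seen in isolation with plain Python (the sample values here are made up for illustration):

```python
# Each tuple is one sample as returned by __getitem__: (image, label)
batch = [("img0", 0), ("img1", 1), ("img2", 2)]

# zip(*batch) transposes the batch: one tuple per field,
# each containing that field from every sample.
transposed = list(zip(*batch))
print(transposed)  # [('img0', 'img1', 'img2'), (0, 1, 2)]

# default_collate is then called recursively on each field tuple;
# e.g. the labels (0, 1, 2) hit the int branch and become one tensor.
```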
If the batch is processed in that order, the error does not occur. But if, after the second recursive call, the elements still do not match any of the listed types, a third recursion follows; even if data comes back without an error at that depth, it is usually not what you want, and in practice the RuntimeError tends to be raised at the third recursion or deeper. To fix the error, carefully check the types of the fields your Dataset class returns. It can also help to print the batch contents before and after processing inside default_collate(), so you can trace exactly how the function handles your data and spot which returned field has the wrong type.
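To see the failure concretely, here is a toy Dataset (hypothetical, for illustration only) whose __getitem__ returns Python lists of unequal length. With batch_size greater than 1, the Sequence branch above detects the mismatched lengths and raises the error:

```python
import torch
from torch.utils.data import Dataset, DataLoader

class VarLenDataset(Dataset):
    """Toy dataset: each sample is a list whose length grows with the index."""
    def __len__(self):
        return 4

    def __getitem__(self, idx):
        # Lists of unequal length: [0], [0, 1], [0, 1, 2], ...
        return list(range(idx + 1))

loader = DataLoader(VarLenDataset(), batch_size=2)
try:
    next(iter(loader))
except RuntimeError as e:
    # -> each element in list of batch should be of equal size
    print(e)
```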
A friendly reminder: do not modify default_collate() in the installed source file. Instead, copy this code out, define your own collate_fn() function, and pass it via the collate_fn argument when constructing the DataLoader.
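A minimal sketch of that approach, assuming the samples contain variable-length 1-D tensors (the dataset and names here are illustrative; pad_sequence comes from torch.nn.utils.rnn). The custom collate_fn pads the sequences to a common length instead of letting default_collate reject them:

```python
import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import Dataset, DataLoader

class SeqDataset(Dataset):
    """Toy dataset returning (sequence, label) with varying sequence lengths."""
    def __len__(self):
        return 4

    def __getitem__(self, idx):
        return torch.arange(idx + 1, dtype=torch.float32), idx

def my_collate_fn(batch):
    # batch is a list of (sequence, label) tuples, one per sample
    seqs, labels = zip(*batch)
    # Pad every sequence to the length of the longest one in this batch
    padded = pad_sequence(seqs, batch_first=True)
    return padded, torch.tensor(labels)

loader = DataLoader(SeqDataset(), batch_size=4, collate_fn=my_collate_fn)
padded, labels = next(iter(loader))
print(padded.shape)  # torch.Size([4, 4])
```

Because collate_fn receives the raw list of samples, it has full control over how fields are combined, which is the supported extension point for exactly this situation.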
Good luck squashing the bug and getting your model to run!