前言:前面的系列文章之第一篇已经基本上说明了DataSet类和DataLoader类的用法,但是鉴于DataLoader类中有一个参数collate_fn使用起来比较复杂,所以本次的第二篇文章还专门说一下这个函数的功能。第一篇文章请参考:

(第一篇)pytorch数据预处理三剑客之——Dataset,DataLoader,Transform

collate_fn,中单词collate的含义是:核对,校勘,对照,整理。顾名思义,这就是一个对每一组样本数据进行一遍“核对和重新整理”,现在可能更好理解一些。

一、本次案例

本次为了更加方便的演示整个过程,假设有20组训练样本,输入的样本x是【1,2,3,4,,,,18,19,20】,

输出的标签y是 【100,200,300,,,,1800,1900,2000】

现在我不是用collate_fn参数,我要将数据随机打乱,并且batch_size等于3,简单的实现如下:

import numpy as np
import torch
from torch.utils.data import Dataset,DataLoader

x=range(1,21,1) 
y=range(100,2100,100)

class XYDataSet(Dataset):
    def __init__(self,x,y):
        self.x_list=x
        self.y_list=y
        assert len(self.x_list)==len(self.y_list)

    def __len__(self):
        return len(self.x_list)

    def __getitem__(self,index):
        x_one=self.x_list[index]
        y_one=self.y_list[index]
        return (x_one,y_one)

# 第一步:构造dataset
dataset=XYDataSet(x,y)
# 第二步:构造dataloader
dataloader=DataLoader(dataset,batch_size=3,shuffle=True)

# 第三步:对dataloader进行迭代
for epoch in range(2): # 只查看两个epoch
    for x_train,y_train in dataloader:
        print(x_train)
        print(y_train)
        print("-----------------------------------")
'''
tensor([ 3, 12, 16])
tensor([ 300, 1200, 1600])
-----------------------------------
tensor([13, 15,  9])
tensor([1300, 1500,  900])
-----------------------------------
tensor([ 8, 19,  7])
tensor([ 800, 1900,  700])
-----------------------------------
tensor([18,  1, 14])
tensor([1800,  100, 1400])
-----------------------------------
tensor([ 2,  5, 20])
tensor([ 200,  500, 2000])
-----------------------------------
tensor([11, 17,  6])
tensor([1100, 1700,  600])
-----------------------------------
tensor([10,  4])
tensor([1000,  400])
-----------------------------------    # 这里是第一个epoch结束了,会进行一次混洗
tensor([14,  2,  3])
tensor([1400,  200,  300])
-----------------------------------
tensor([11,  1, 15])
tensor([1100,  100, 1500])
-----------------------------------
tensor([10,  6, 20])
tensor([1000,  600, 2000])
-----------------------------------
tensor([13,  5,  8])
tensor([1300,  500,  800])
-----------------------------------
tensor([18, 19, 12])
tensor([1800, 1900, 1200])
-----------------------------------
tensor([ 4, 16,  7])
tensor([ 400, 1600,  700])
-----------------------------------
tensor([ 9, 17])
tensor([ 900, 1700])
-----------------------------------
'''

但是现在有一个问题,我希望对于原来的样本数据重新你处理一下,将每一组样本的x加上0.5,将每一组样本的y加上50,然后重新组成样本,我当然可以这么做,即在定义DataSet的__getitem__里面去实现,只需要简单的更改__getitem__即可,如下:

def __getitem__(self,index):
        x_one=self.x_list[index]+0.5  # 每一个x加上0.5
        y_one=self.y_list[index]+50   # 每一个y加上50
        return x_one,y_one

二、通过自定义collate_fn函数来实现

这里整个DataSet的实现完全不变,定义的函数如下:

import numpy as np
import torch
from torch.utils.data import Dataset,DataLoader
from torch.utils.data.dataloader import default_collate  # 导入这个函数,这个函数其实就是pytorch默认给这个collate_fn的默认实现

def collate_fn(batch):
    """
    batch :是一个列表,列表的长度是 batch_size
           列表的每一个元素是 (x,y) 这样的元组tuple,元祖的两个元素分别是x,y
    """
    new_batch=[]
    for index in range(len(batch)):
        x_=batch[index][0]+0.5  # 每一个样本x加上0.5
        y_=batch[index][1]+50   # 没一个样本y加上50

        new_batch.append((x_,y_))  # 将改变之后的x,y重新组成一个batch

    return default_collate(new_batch)
# 第一步:构造dataset
dataset=XYDataSet(x,y)
# 第二步:构造dataloader,这里需要传递自定义的collate_fn函数
dataloader=DataLoader(dataset,batch_size=3,shuffle=True,collate_fn=collate_fn)

# 第三步:对dataloader进行迭代
for epoch in range(2): # 只查看两个epoch
    for x_train,y_train in dataloader:
        print(x_train)
        print(y_train)
        print("-----------------------------------")
'''
tensor([18.5000,  1.5000, 20.5000], dtype=torch.float64)
tensor([1850,  150, 2050])
-----------------------------------
tensor([19.5000, 14.5000,  4.5000], dtype=torch.float64)
tensor([1950, 1450,  450])
-----------------------------------
tensor([6.5000, 2.5000, 5.5000], dtype=torch.float64)
tensor([650, 250, 550])
-----------------------------------
tensor([17.5000,  9.5000, 15.5000], dtype=torch.float64)
tensor([1750,  950, 1550])
-----------------------------------
tensor([12.5000,  8.5000,  7.5000], dtype=torch.float64)
tensor([1250,  850,  750])
-----------------------------------
tensor([ 3.5000, 13.5000, 16.5000], dtype=torch.float64)
tensor([ 350, 1350, 1650])
-----------------------------------
tensor([11.5000, 10.5000], dtype=torch.float64)
tensor([1150, 1050])
-----------------------------------
tensor([ 2.5000,  7.5000, 18.5000], dtype=torch.float64)
tensor([ 250,  750, 1850])
-----------------------------------
tensor([14.5000, 13.5000,  9.5000], dtype=torch.float64)
tensor([1450, 1350,  950])
-----------------------------------
tensor([10.5000, 12.5000, 17.5000], dtype=torch.float64)
tensor([1050, 1250, 1750])
-----------------------------------
tensor([20.5000, 15.5000,  8.5000], dtype=torch.float64)
tensor([2050, 1550,  850])
-----------------------------------
tensor([ 6.5000, 11.5000, 19.5000], dtype=torch.float64)
tensor([ 650, 1150, 1950])
-----------------------------------
tensor([5.5000, 1.5000, 4.5000], dtype=torch.float64)
tensor([550, 150, 450])
-----------------------------------
tensor([16.5000,  3.5000], dtype=torch.float64)
tensor([1650,  350])
-----------------------------------
'''

三、collate_fn函数的一般定义格式

从上面的例子中可以更加清楚的理解“collate的校对、整理”的含义,这个函数自定义实现的时候有一个大致的模板:

from torch.utils.data.dataloader import default_collate  # 导入这个函数

def collate_fn(batch):
    """
    params:
        batch :是一个列表,列表的长度是 batch_size
               列表的每一个元素是 (x,y) 这样的元组tuple,元祖的两个元素分别是x,y
               大致的格式如下 [(x1,y1),(x2,y2),(x3,y3)...(xn,yn)]
    returns:
        整理之后的新的batch
    """
     
    # 这一部分是对 batch 进行重新 “校对、整理”的代码

    return default_collate(batch) #返回校对之后的batch,一般就直接推荐使用default_collate进行包装,因为它里面有很多功能,比如将numpy转化成tensor等操作,这是必须的。