前言:前面的系列文章之第一篇已经基本上说明了DataSet类和DataLoader类的用法,但是鉴于DataLoader类中有一个参数collate_fn使用起来比较复杂,所以本次的第二篇文章还专门说一下这个函数的功能。第一篇文章请参考:
(第一篇)pytorch数据预处理三剑客之——Dataset,DataLoader,Transform
collate_fn,中单词collate的含义是:核对,校勘,对照,整理。顾名思义,这就是一个对每一组样本数据进行一遍“核对和重新整理”,现在可能更好理解一些。
一、本次案例
本次为了更加方便的演示整个过程,假设有20组训练样本,输入的样本x是【1,2,3,4,,,,18,19,20】,
输出的标签y是 【100,200,300,,,,1800,1900,2000】
现在我不是用collate_fn参数,我要将数据随机打乱,并且batch_size等于3,简单的实现如下:
import numpy as np
import torch
from torch.utils.data import Dataset,DataLoader
x=range(1,21,1)
y=range(100,2100,100)
class XYDataSet(Dataset):
def __init__(self,x,y):
self.x_list=x
self.y_list=y
assert len(self.x_list)==len(self.y_list)
def __len__(self):
return len(self.x_list)
def __getitem__(self,index):
x_one=self.x_list[index]
y_one=self.y_list[index]
return (x_one,y_one)
# 第一步:构造dataset
dataset=XYDataSet(x,y)
# 第二步:构造dataloader
dataloader=DataLoader(dataset,batch_size=3,shuffle=True)
# 第三步:对dataloader进行迭代
for epoch in range(2): # 只查看两个epoch
for x_train,y_train in dataloader:
print(x_train)
print(y_train)
print("-----------------------------------")
'''
tensor([ 3, 12, 16])
tensor([ 300, 1200, 1600])
-----------------------------------
tensor([13, 15, 9])
tensor([1300, 1500, 900])
-----------------------------------
tensor([ 8, 19, 7])
tensor([ 800, 1900, 700])
-----------------------------------
tensor([18, 1, 14])
tensor([1800, 100, 1400])
-----------------------------------
tensor([ 2, 5, 20])
tensor([ 200, 500, 2000])
-----------------------------------
tensor([11, 17, 6])
tensor([1100, 1700, 600])
-----------------------------------
tensor([10, 4])
tensor([1000, 400])
----------------------------------- # 这里是第一个epoch结束了,会进行一次混洗
tensor([14, 2, 3])
tensor([1400, 200, 300])
-----------------------------------
tensor([11, 1, 15])
tensor([1100, 100, 1500])
-----------------------------------
tensor([10, 6, 20])
tensor([1000, 600, 2000])
-----------------------------------
tensor([13, 5, 8])
tensor([1300, 500, 800])
-----------------------------------
tensor([18, 19, 12])
tensor([1800, 1900, 1200])
-----------------------------------
tensor([ 4, 16, 7])
tensor([ 400, 1600, 700])
-----------------------------------
tensor([ 9, 17])
tensor([ 900, 1700])
-----------------------------------
'''
但是现在有一个问题,我希望对于原来的样本数据重新你处理一下,将每一组样本的x加上0.5,将每一组样本的y加上50,然后重新组成样本,我当然可以这么做,即在定义DataSet的__getitem__里面去实现,只需要简单的更改__getitem__即可,如下:
def __getitem__(self,index):
x_one=self.x_list[index]+0.5 # 每一个x加上0.5
y_one=self.y_list[index]+50 # 每一个y加上50
return x_one,y_one
二、通过自定义collate_fn函数来实现
这里整个DataSet的实现完全不变,定义的函数如下:
import numpy as np
import torch
from torch.utils.data import Dataset,DataLoader
from torch.utils.data.dataloader import default_collate # 导入这个函数,这个函数其实就是pytorch默认给这个collate_fn的默认实现
def collate_fn(batch):
"""
batch :是一个列表,列表的长度是 batch_size
列表的每一个元素是 (x,y) 这样的元组tuple,元祖的两个元素分别是x,y
"""
new_batch=[]
for index in range(len(batch)):
x_=batch[index][0]+0.5 # 每一个样本x加上0.5
y_=batch[index][1]+50 # 没一个样本y加上50
new_batch.append((x_,y_)) # 将改变之后的x,y重新组成一个batch
return default_collate(new_batch)
# 第一步:构造dataset
dataset=XYDataSet(x,y)
# 第二步:构造dataloader,这里需要传递自定义的collate_fn函数
dataloader=DataLoader(dataset,batch_size=3,shuffle=True,collate_fn=collate_fn)
# 第三步:对dataloader进行迭代
for epoch in range(2): # 只查看两个epoch
for x_train,y_train in dataloader:
print(x_train)
print(y_train)
print("-----------------------------------")
'''
tensor([18.5000, 1.5000, 20.5000], dtype=torch.float64)
tensor([1850, 150, 2050])
-----------------------------------
tensor([19.5000, 14.5000, 4.5000], dtype=torch.float64)
tensor([1950, 1450, 450])
-----------------------------------
tensor([6.5000, 2.5000, 5.5000], dtype=torch.float64)
tensor([650, 250, 550])
-----------------------------------
tensor([17.5000, 9.5000, 15.5000], dtype=torch.float64)
tensor([1750, 950, 1550])
-----------------------------------
tensor([12.5000, 8.5000, 7.5000], dtype=torch.float64)
tensor([1250, 850, 750])
-----------------------------------
tensor([ 3.5000, 13.5000, 16.5000], dtype=torch.float64)
tensor([ 350, 1350, 1650])
-----------------------------------
tensor([11.5000, 10.5000], dtype=torch.float64)
tensor([1150, 1050])
-----------------------------------
tensor([ 2.5000, 7.5000, 18.5000], dtype=torch.float64)
tensor([ 250, 750, 1850])
-----------------------------------
tensor([14.5000, 13.5000, 9.5000], dtype=torch.float64)
tensor([1450, 1350, 950])
-----------------------------------
tensor([10.5000, 12.5000, 17.5000], dtype=torch.float64)
tensor([1050, 1250, 1750])
-----------------------------------
tensor([20.5000, 15.5000, 8.5000], dtype=torch.float64)
tensor([2050, 1550, 850])
-----------------------------------
tensor([ 6.5000, 11.5000, 19.5000], dtype=torch.float64)
tensor([ 650, 1150, 1950])
-----------------------------------
tensor([5.5000, 1.5000, 4.5000], dtype=torch.float64)
tensor([550, 150, 450])
-----------------------------------
tensor([16.5000, 3.5000], dtype=torch.float64)
tensor([1650, 350])
-----------------------------------
'''
三、collate_fn函数的一般定义格式
从上面的例子中可以更加清楚的理解“collate的校对、整理”的含义,这个函数自定义实现的时候有一个大致的模板:
from torch.utils.data.dataloader import default_collate # 导入这个函数
def collate_fn(batch):
"""
params:
batch :是一个列表,列表的长度是 batch_size
列表的每一个元素是 (x,y) 这样的元组tuple,元祖的两个元素分别是x,y
大致的格式如下 [(x1,y1),(x2,y2),(x3,y3)...(xn,yn)]
returns:
整理之后的新的batch
"""
# 这一部分是对 batch 进行重新 “校对、整理”的代码
return default_collate(batch) #返回校对之后的batch,一般就直接推荐使用default_collate进行包装,因为它里面有很多功能,比如将numpy转化成tensor等操作,这是必须的。