本文以torch.utils.data中的Dataset类为例进行说明
Dataset的作用是构建自定义的数据集,以方便使用Dataloader进行加载
语法
我们自定义的数据集需要继承自torch.util.data.Dataset抽象类,并重写相应的两个方法:
- len:返回数据集的大小。一般情况而言直接用 len(xxx) 进行实现即可
- getitem:使得 dataset[i] 能够返回数据集中的第i个样本,相应的需要传入一个索引i
原抽象类中相应的定义如下:
def __getitem__(self, index):
raise NotImplementedError
def __len__(self):
raise NotImplementedError
数据
假设我们在解决一个分类问题。那么,在训练集文件夹train中,我们可以这么给图片加上标签:
到时候就可以通过文件名的方式来判断某张图片对应的分类。
例子
我们构造一个FruitDataset来处理这些数据。首先实现init方法:
def __init__(self, root_dir, transform=None):
self.root_dir = root_dir
self.transform = transform
self.images = os.listdir(self.root_dir)
init方法一般会有两个基础的参数,一个是dir,用来表示数据集所在的目录;另一个则是transform,可以传入一些方法,以对图片进行处理(一般是进行数据增强)。
此外,在init方法中还会进行基础的数据读取,例如这里使用listdir来列出目录下的所有文件;而如果是表格形式的数据(如kaggle),那么则可以使用切片方法将标签与数据分离,方便后续的处理。
接下来是len方法,即返回数据集的长度。既然我们刚才已经读出了数据集目录下的所有文件,那么只要返回这个文件夹列表的长度即可:
def __len__(self):
return len(self.images)
最后则是getitem方法。getitem方法返回的是一个字典,表示相应数据所蕴含的其他信息,有了其他信息一个数据才能变成一个样本。在这里,“其他信息”就是图片所对应的标签,即要返回一个{‘image’:img, ‘label’:label}。习惯上,我们会把这个字典记做sample。
img可以通过imread方法读取图片实际内容得到,而label可以通过处理文件名获得:
def __getitem__(self,index):
# 通过路径与索引读图片
image_index = self.images[index]
img_path = os.path.join(self.root_dir, image_index)
img = io.imread(img_path)
# 通过文件名读标签
label = img_path.split('\\')[-1].split('.')[0]
# 组装成字典
sample = {'image':img,'label':label}
if self.transform:
sample = self.transform(sample)
return sample
注意这里的if self.transform也算是一种习惯上的用法,即如果传入了变换方法则进行变换后再返回。
Dataloader
我们通过dataloader来分析刚才构建的数据集。一般来说,训练集与测试集各会对应一个dataloader,这里为了演示方便起见就只拿我们刚才的训练集进行说明。
首先,实例化一个Dataset对象。在这里我们没有变换方法,则只需要传入数据所在的目录即可:
data = FruitDataset(r"data\train", transform=None)
dataset对象可以通过下标来访问其中的各个样本,比如:
print(data[0])
然后利用dataloader进行加载:
dataloader = DataLoader(data, batch_size=2, shuffle=True)
一般而言Dataloader需要传入三个参数:
- dataset:传入Dataset对象,表示需要加载的数据集
- batch_size:“批大小”,表示一次选取的一批中有几个样本。在这里bs为2,即每轮选取2个样本
- shuffle:是否需要将数据打乱。一般来说只需要打乱训练集即可,测试集并不需要打乱
查看dataloader的长度。总共有10张图,一批有2张,因此有5批,长度为5:
# 5
print(len(dataloader))
最后迭代整个数据集:
for i_batch, batch_data in enumerate(dataloader):
print(i_batch)
print(batch_data)
i_batch就是batch的编号,0、1、2、3、4;batch_data就是我们在数据集中定义的sample,在这里两个两个一组出现。
完整代码
# -*- coding: utf-8 -*-
from torch.utils.data import Dataset, DataLoader
from skimage import io
import os
class FruitDataset(Dataset):
def __init__(self, root_dir, transform=None):
self.root_dir = root_dir
self.transform = transform
self.images = os.listdir(self.root_dir)
def __len__(self):
return len(self.images)
def __getitem__(self,index):
image_index = self.images[index]
img_path = os.path.join(self.root_dir, image_index)
img = io.imread(img_path)
label = img_path.split('\\')[-1].split('.')[0]
sample = {'image':img,'label':label}
if self.transform:
sample = self.transform(sample)
return sample
data = FruitDataset(r"data\train", transform=None)
print(data[0])
dataloader = DataLoader(dataset=data, batch_size=2, shuffle=True)
print(len(dataloader))
for i_batch, batch_data in enumerate(dataloader):
print(i_batch)
print(batch_data)