(深度学习)构造属于你自己的Pytorch数据集
1.综述
2.实现原理
3.代码细节
4.详细代码
综述
Pytorch可以说是一个非常便利的深度学习库,它甚至在torchvision.datasets
中拥有许多一步到位完成数据集下载、解析、读取的类——然鹅,这样也就养成了我们懒惰依赖的心理。当我们需要用到torchvision.datasets
中不曾拥有的数据集时,我们可能就会不知所措。
这篇文章中,我将以CIFAR-10数据集为例(虽然有torchvision.datasets.CIFAR10
了),摆脱对torchvision.datasets
的依赖,构建一个自己的数据集。
在开始之前,首先你要有CIFAR-10数据集,直接去官网上下载可能较慢(再次感谢我国著名建筑师方斌新院士 )=
数据集解压后目录情况如下:
实现原理
首先,torch.utils.data.DataLoader
不仅生成迭代数据非常方便,而且它也是经过优化的,效率十分之高(肯定比我们自己写一个要高多了),因此我们最好不要舍弃。
因此,我们的目标是根据CIFAR-10数据集构造一个Dataset的子类,使之能够作为torch.utils.data.DataLoader
的参数,从而使数据集能被我们用于生成迭代数据进行训练:
cifar10 = MyCIFAR10.MyCIFAR10('./data/cifar-10-batches-py', train=True)
train_loader = torch.utils.data.DataLoader(dataset=cifar10, batch_size=batch_size, shuffle=True)
要构造Dataset的子类,就必须要实现两个方法:
- _getitem_(self, index):根据index来返回数据集中标号为index的元素及其标签。
- _len_(self):返回数据集的长度。
因此,实质上我们主要是要通过__init__初始化之时读取数据集,再实现这两个函数便轻而易举。
代码细节
- _init_:
- root是存放解压后的数据集的根目录,根据上图我这里是
'./data/cifar-10-batches-py'
。 - X的类型是numpy数组,Y的类型是List;由于X作为数据要送入网络中,因此最后需要将其累加值从numpy数组转为Tensor。
def __init__(self, root, train=True, transform=None, target_transform=None):
super(MyCIFAR10, self).__init__()
self.transform = transform
self.target_transform = target_transform
self.imgs = None
self.labels = []
# 根据CIFAR-10官网上下载的数据,训练集分为5个batch文件,每个里有10000张32*32的图片;测试集只有1个batch文件,里面有10000张32*32的图片
train_lists = ['data_batch_1',
'data_batch_2',
'data_batch_3',
'data_batch_4',
'data_batch_5']
test_lists = ['test_batch']
# 根据train是否为True来选择测试集或训练集
if train:
lists = train_lists
else:
lists = test_lists
# 读取数据集,构造类中的图像集和标签
for list in lists:
filename = os.path.join(root, list)
with open(filename, 'rb') as f: # 这里需要'rb' + 'latin1'才能读取
datadict = pickle.load(f, encoding='latin1')
X = datadict['data'].reshape(-1, 3, 32, 32)
Y = datadict['labels']
if self.imgs is None:
self.imgs = np.vstack(X).reshape(-1, 3, 32, 32)
else:
self.imgs = np.vstack((self.imgs, X)).reshape(-1, 3, 32, 32)
self.labels = self.labels + Y
self.imgs = torch.from_numpy(self.imgs).type(torch.FloatTensor) # 最后需要将numpy数组转为Tensor
- _getitem_:
较为简单,直接给出:
def __getitem__(self, index):
img, label = self.imgs[index], self.labels[index]
if self.transform is not None:
img = self.transform(img)
if self.target_transform is not None:
label = self.target_transform(label)
return img, label
- _len_:
极其简单,直接给出:
def __len__(self):
return len(self.imgs)
详细代码
class MyCIFAR10(Dataset):
"""
根据CIFAR-10定义的个人数据集类
继承自Dataset类,因此能够被torch.utils.data.DataLoader使用,从而更高效地在训练和测试中迭代
"""
def __init__(self, root, train=True, transform=None, target_transform=None):
super(MyCIFAR10, self).__init__()
self.transform = transform
self.target_transform = target_transform
self.imgs = None
self.labels = []
# 根据CIFAR-10官网上下载的数据,训练集分为5个batch文件,每个里有10000张32*32的图片;测试集只有1个batch文件,里面有10000张32*32的图片
train_lists = ['data_batch_1',
'data_batch_2',
'data_batch_3',
'data_batch_4',
'data_batch_5']
test_lists = ['test_batch']
# 根据train是否为True来选择测试集或训练集
if train:
lists = train_lists
else:
lists = test_lists
# 读取数据集,构造类中的图像集和标签
for list in lists:
filename = os.path.join(root, list)
with open(filename, 'rb') as f: # 这里需要'rb' + 'latin1'才能读取
datadict = pickle.load(f, encoding='latin1')
X = datadict['data'].reshape(-1, 3, 32, 32)
Y = datadict['labels']
if self.imgs is None:
self.imgs = np.vstack(X).reshape(-1, 3, 32, 32)
else:
self.imgs = np.vstack((self.imgs, X)).reshape(-1, 3, 32, 32)
self.labels = self.labels + Y
self.imgs = torch.from_numpy(self.imgs).type(torch.FloatTensor) # 最后需要将numpy数组转为Tensor
# 继承的Dataset类需要实现两个方法之一:__getitem__(self, index)
def __getitem__(self, index):
img, label = self.imgs[index], self.labels[index]
if self.transform is not None:
img = self.transform(img)
if self.target_transform is not None:
label = self.target_transform(label)
return img, label
# 继承的Dataset类需要实现两个方法之一:__len__(self)
def __len__(self):
return len(self.imgs)