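Preface:
Although PyTorch's torchvision.datasets package wraps a number of common datasets for convenient use, we often need to train on our own images. Building and loading a dataset is usually the first step in training a neural network, so this post walks through how to build and load a custom image dataset and feed it to a network.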
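I. A custom image dataset:
1. File structure of the dataset:
train is the dataset's root directory. One level below it is a folder per class, and each class folder contains a number of images.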
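As a rough illustration of that layout, the sketch below creates a throwaway directory tree and lists it back. The helper names make_demo_layout and describe_layout are hypothetical, the class names cat and dog are only examples, and the .jpg files are empty placeholders rather than real images.

```python
import tempfile
from pathlib import Path


def make_demo_layout(root, classes=("cat", "dog"), per_class=2):
    """Create root/<class>/<n>.jpg placeholder files (empty, not real images)."""
    root = Path(root)
    for name in classes:
        class_dir = root / name
        class_dir.mkdir(parents=True, exist_ok=True)
        for i in range(per_class):
            (class_dir / f"{i}.jpg").touch()
    return root


def describe_layout(root):
    """Map each class folder name to the sorted file names inside it."""
    return {
        d.name: sorted(p.name for p in d.glob("*.jpg"))
        for d in sorted(Path(root).iterdir())
        if d.is_dir()
    }


with tempfile.TemporaryDirectory() as tmp:
    layout = describe_layout(make_demo_layout(Path(tmp) / "train"))

print(layout)
```

Each top-level folder name becomes a class, and every file inside it belongs to that class.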
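2. torch.utils.data.Dataset:
Dataset is the abstract class that represents a dataset. A custom dataset should inherit from Dataset and override the following methods:
1. __getitem__: returns the data sample for a given key (index).
2. __len__: returns the number of samples in the dataset.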
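Before involving any images, the two methods can be seen in isolation with a toy dataset. This is only a minimal sketch: SquaresDataset is a hypothetical class made up for illustration, where sample i is the pair (i, i*i).

```python
import torch
from torch.utils.data import Dataset


class SquaresDataset(Dataset):
    """Toy dataset (hypothetical): sample i is the pair (i, i * i)."""

    def __init__(self, n):
        self.n = n

    def __len__(self):
        # Number of samples; DataLoader uses this to know when to stop
        return self.n

    def __getitem__(self, index):
        # Return one sample for the given index
        if not 0 <= index < self.n:
            raise IndexError(index)
        x = torch.tensor(float(index))
        return x, x * x


toy = SquaresDataset(5)
x, y = toy[3]
print(len(toy), x.item(), y.item())
```

Indexing with toy[3] calls __getitem__, and len(toy) calls __len__; an image dataset follows the same pattern, only with file reads inside __getitem__.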
The custom dataset class is built as follows:
import torch
from torch.utils.data import Dataset, DataLoader
from pathlib import Path
import cv2


def get_images_and_labels(dir_path):
    '''
    Collect the image paths of every class under the dataset root dir_path,
    together with the matching integer labels.
    :param dir_path: root directory of the image dataset
    :return: images_list, labels_list
    '''
    dir_path = Path(dir_path)
    # Class names, sorted so the class-to-label mapping is the same on every run
    # (iterdir() yields entries in arbitrary order)
    classes = sorted(category.name for category in dir_path.iterdir() if category.is_dir())
    images_list = []  # image file paths
    labels_list = []  # integer labels
    for index, name in enumerate(classes):
        class_path = dir_path / name
        # Only .jpg files are collected
        for img_path in class_path.glob('*.jpg'):
            images_list.append(str(img_path))
            labels_list.append(index)
    return images_list, labels_list


class MyDataset(Dataset):
    def __init__(self, dir_path, transform=None):
        self.dir_path = dir_path  # dataset root directory
        self.transform = transform
        self.images, self.labels = get_images_and_labels(self.dir_path)

    def __len__(self):
        # Number of samples in the dataset
        return len(self.images)

    def __getitem__(self, index):
        img_path = self.images[index]
        label = self.labels[index]
        img = cv2.imread(img_path)  # BGR array of shape (H, W, C); None if the file cannot be read
        if img is None:
            raise FileNotFoundError(f'Failed to read image: {img_path}')
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)  # OpenCV loads BGR; convert to RGB
        sample = {'image': img, 'label': label}
        if self.transform:
            sample['image'] = self.transform(sample['image'])
        return sample
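II. Loading the dataset:
1. torch.utils.data.DataLoader:
DataLoader is a loader class for datasets. It provides many convenient options, such as shuffle, batch_size, and drop_last; see the documentation for details.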
if __name__ == '__main__':
    train_dataset = MyDataset(r"C:\Users\admin\Desktop\set100-10\annotated_camera_images")
    # batch_size=64 matches the batch shapes in the output shown below
    dataloader = DataLoader(train_dataset, batch_size=64, shuffle=True)
    for index, batch_data in enumerate(dataloader):
        print(index, batch_data['image'].shape, batch_data['label'].shape)
The output is as follows:
0 torch.Size([64, 224, 224, 3]) torch.Size([64])
1 torch.Size([64, 224, 224, 3]) torch.Size([64])
2 torch.Size([64, 224, 224, 3]) torch.Size([64])
3 torch.Size([64, 224, 224, 3]) torch.Size([64])
4 torch.Size([64, 224, 224, 3]) torch.Size([64])
5 torch.Size([64, 224, 224, 3]) torch.Size([64])
6 torch.Size([64, 224, 224, 3]) torch.Size([64])
7 torch.Size([64, 224, 224, 3]) torch.Size([64])
8 torch.Size([64, 224, 224, 3]) torch.Size([64])
9 torch.Size([64, 224, 224, 3]) torch.Size([64])
10 torch.Size([36, 224, 224, 3]) torch.Size([36])
Process finished with exit code 0
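That completes building and loading the image dataset. Note that without a transform, the batches keep OpenCV's (H, W, C) layout, as the shapes above show.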
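The batch counts printed above (ten batches of 64 followed by one of 36) imply 676 samples in total. The sketch below reproduces that arithmetic with dummy tensors instead of images and shows what drop_last changes; the 676 figure is inferred from the output, and TensorDataset merely stands in for the real dataset.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# 676 dummy samples: 10 full batches of 64 plus a final batch of 36
features = torch.zeros(676, 4)
labels = torch.zeros(676, dtype=torch.long)
dataset = TensorDataset(features, labels)

# Default behaviour keeps the smaller final batch
keep_last = DataLoader(dataset, batch_size=64, shuffle=True)
# drop_last=True discards the final batch when it is incomplete
drop_last = DataLoader(dataset, batch_size=64, shuffle=True, drop_last=True)

sizes_keep = [x.shape[0] for x, _ in keep_last]
sizes_drop = [x.shape[0] for x, _ in drop_last]
print(len(sizes_keep), sizes_keep[-1])
print(len(sizes_drop), sizes_drop[-1])
```

Dropping the last batch can be useful when a layer or loss expects a fixed batch size, at the cost of skipping a few samples each epoch.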
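III. Data augmentation for images:
The custom dataset class takes a transform parameter that has not been used so far.
1. torchvision.transforms:
torchvision.transforms implements many common preprocessing and augmentation operations, such as Resize (Scale is its older, deprecated name), the crop transforms, Normalize, and ColorJitter; browse transforms.py to see them all. Here we simply define a function that returns a transforms.Compose, which applies the listed transforms in order: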
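from torchvision import transforms


def get_transform_for_train():
    transform_list = []
    # The dataset yields NumPy arrays of shape (H, W, C); convert to PIL for the transforms below
    transform_list.append(transforms.ToPILImage())
    transform_list.append(transforms.RandomHorizontalFlip(p=0.3))
    transform_list.append(transforms.ColorJitter(0.1, 0.1, 0.1, 0.1))
    # Convert to a float tensor of shape (C, H, W) with values in [0, 1]
    transform_list.append(transforms.ToTensor())
    # Rescale each channel from [0, 1] to [-1, 1]
    transform_list.append(transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)))
    return transforms.Compose(transform_list)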
Then, when instantiating the custom dataset class, pass the transform in as an argument:
if __name__ == '__main__':
    train_dataset = MyDataset(r"C:\Users\admin\Desktop\set100-10\annotated_camera_images", get_transform_for_train())
    dataloader = DataLoader(train_dataset, batch_size=64, shuffle=True)
    for index, batch_data in enumerate(dataloader):
        print(index, batch_data['image'].shape, batch_data['label'].shape)
The output is:
0 torch.Size([64, 3, 224, 224]) torch.Size([64])
1 torch.Size([64, 3, 224, 224]) torch.Size([64])
2 torch.Size([64, 3, 224, 224]) torch.Size([64])
3 torch.Size([64, 3, 224, 224]) torch.Size([64])
4 torch.Size([64, 3, 224, 224]) torch.Size([64])
5 torch.Size([64, 3, 224, 224]) torch.Size([64])
6 torch.Size([64, 3, 224, 224]) torch.Size([64])
7 torch.Size([64, 3, 224, 224]) torch.Size([64])
8 torch.Size([64, 3, 224, 224]) torch.Size([64])
9 torch.Size([64, 3, 224, 224]) torch.Size([64])
10 torch.Size([36, 3, 224, 224]) torch.Size([36])
Process finished with exit code 0
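Afterword:
That is all for the main post. Update, May 17:
Today I learned about torchvision.datasets.ImageFolder, which makes it easy to define a generic data loader. It expects the images to be arranged as follows, the same structure described at the start of this post: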
'''
root/dog/xxx.png
root/dog/xxy.png
root/dog/xxz.png
root/cat/123.png
root/cat/nsdf3.png
root/cat/asd932_.png
'''
Then you only need to define a transform to load the dataset:
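from torchvision.datasets import ImageFolder

if __name__ == '__main__':
    # ImageFolder's default loader already returns PIL images, so skip the leading ToPILImage()
    transform = transforms.Compose(get_transform_for_train().transforms[1:])
    train_dataset = ImageFolder(root=r'C:\Users\admin\Desktop\set100-10\annotated_camera_images\train', transform=transform)
    dataloader = DataLoader(train_dataset, batch_size=64, shuffle=True)
    for index, batch_data in enumerate(dataloader):
        print(index, batch_data[0].shape, batch_data[1].shape)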
The output is:
0 torch.Size([64, 3, 224, 224]) torch.Size([64])
1 torch.Size([64, 3, 224, 224]) torch.Size([64])
2 torch.Size([64, 3, 224, 224]) torch.Size([64])
3 torch.Size([64, 3, 224, 224]) torch.Size([64])
4 torch.Size([64, 3, 224, 224]) torch.Size([64])
5 torch.Size([64, 3, 224, 224]) torch.Size([64])
6 torch.Size([64, 3, 224, 224]) torch.Size([64])
7 torch.Size([64, 3, 224, 224]) torch.Size([64])
8 torch.Size([64, 3, 224, 224]) torch.Size([64])
9 torch.Size([64, 3, 224, 224]) torch.Size([64])
10 torch.Size([36, 3, 224, 224]) torch.Size([36])
Process finished with exit code 0
It really is convenient. In most cases this is all you need.