前言:

pytorch虽然提供了torchvision.datasets包,封装了一些常用的数据集供我们很方便地调用,但我们经常需要训练自己的图像数据,构建并加载数据集往往是训练神经网络的第一步,本文将介绍如何构建加载自己的图像数据集,并用于神经网络输入。

一. 自定义图像数据集:

1. 数据集的文件结构:

train为数据集根目录,下一级为每个类别的文件夹,分别包含着若干张图像:

pytorch加载npy数据训练模型 pytorch怎么加载自己的数据集_加载

2. torch.utils.data.Dataset:

Dataset是表示数据集的抽象类,当我们自定义数据集时应继承Dataset类,并重写以下方法 :

1. __getitem__: 支持根据给定的key来获取数据样本。

2. __len__: 实现返回数据集的数据数量。

构建的自定义数据集类如下:

import torch
from torch.utils.data import Dataset, DataLoader
from pathlib import Path
import cv2


def get_images_and_labels(dir_path):
    '''
    从图像数据集的根目录dir_path下获取所有类别的图像名列表和对应的标签名列表
    :param dir_path: 图像数据集的根目录
    :return: images_list, labels_list
    '''
    dir_path = Path(dir_path)
    classes = []  # 类别名列表

    for category in dir_path.iterdir():
        if category.is_dir():
            classes.append(category.name)
    images_list = []  # 文件名列表
    labels_list = []  # 标签列表

    for index, name in enumerate(classes):
        class_path = dir_path / name
        if not class_path.is_dir():
            continue
        for img_path in class_path.glob('*.jpg'):
            images_list.append(str(img_path))
            labels_list.append(int(index))
    return images_list, labels_list


class MyDataset(Dataset):
    def __init__(self, dir_path, transform=None):
        self.dir_path = dir_path    # 数据集根目录
        self.transform = transform
        self.images, self.labels = get_images_and_labels(self.dir_path)

    def __len__(self):
        # 返回数据集的数据数量
        return len(self.images)

    def __getitem__(self, index):
        img_path = self.images[index]
        label = self.labels[index]
        img = cv2.imread(img_path)
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        sample = {'image': img, 'label': label}
        if self.transform:
            sample['image'] = self.transform(sample['image'])
        return sample

二. 加载数据集:

1. torch.utils.data.DataLoader:

DataLoader是一个数据集加载器类,提供了很多方便的数据集操作,比如shuffle,batch,drop_last等,详细用法可参考文档。

if __name__ == '__main__':
    train_dataset = MyDataset(r"C:\Users\admin\Desktop\set100-10\annotated_camera_images")
    dataloader = DataLoader(train_dataset, batch_size=16, shuffle=True)
    for index, batch_data in enumerate(dataloader):
        print(index, batch_data['image'].shape, batch_data['label'].shape)

运行的结果如下所示:

0 torch.Size([64, 224, 224, 3]) torch.Size([64])
1 torch.Size([64, 224, 224, 3]) torch.Size([64])
2 torch.Size([64, 224, 224, 3]) torch.Size([64])
3 torch.Size([64, 224, 224, 3]) torch.Size([64])
4 torch.Size([64, 224, 224, 3]) torch.Size([64])
5 torch.Size([64, 224, 224, 3]) torch.Size([64])
6 torch.Size([64, 224, 224, 3]) torch.Size([64])
7 torch.Size([64, 224, 224, 3]) torch.Size([64])
8 torch.Size([64, 224, 224, 3]) torch.Size([64])
9 torch.Size([64, 224, 224, 3]) torch.Size([64])
10 torch.Size([36, 224, 224, 3]) torch.Size([36])

Process finished with exit code 0

至此,就完成了图像数据集的构建与加载。

三. 对图像进行数据增强:

在自定义的数据集类中,我们设置了一个参数transform但没有用

1. torchvision.transforms:

torchvision.transforms中实现了许多常见的数据增强操作,比如Scale, Crop, Resize, Normalize, ColorJitter等等等等,可以浏览transforms.py查看所有操作。这里直接定义一个返回transforms.Compose的方法(有关transforms.Compose可以参考):

def get_transform_for_train():
    transform_list = []
    transform_list.append(transforms.ToPILImage())
    transform_list.append(transforms.RandomHorizontalFlip(p=0.3))
    transform_list.append(transforms.ColorJitter(0.1, 0.1, 0.1, 0.1))
    transform_list.append(transforms.ToTensor())
    transform_list.append(transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)))
    return transforms.Compose(transform_list)

然后在实例化自定义数据集类的时候,就可以将transform作为参数传入:

if __name__ == '__main__':
    train_dataset = MyDataset(r"C:\Users\admin\Desktop\set100-10\annotated_camera_images", get_transform_for_train())
    dataloader = DataLoader(train_dataset, batch_size=64, shuffle=True)
    for index, batch_data in enumerate(dataloader):
        print(index, batch_data['image'].shape, batch_data['label'].shape)

运行结果为:

0 torch.Size([64, 3, 224, 224]) torch.Size([64])
1 torch.Size([64, 3, 224, 224]) torch.Size([64])
2 torch.Size([64, 3, 224, 224]) torch.Size([64])
3 torch.Size([64, 3, 224, 224]) torch.Size([64])
4 torch.Size([64, 3, 224, 224]) torch.Size([64])
5 torch.Size([64, 3, 224, 224]) torch.Size([64])
6 torch.Size([64, 3, 224, 224]) torch.Size([64])
7 torch.Size([64, 3, 224, 224]) torch.Size([64])
8 torch.Size([64, 3, 224, 224]) torch.Size([64])
9 torch.Size([64, 3, 224, 224]) torch.Size([64])
10 torch.Size([36, 3, 224, 224]) torch.Size([36])

Process finished with exit code 0

后记:

到这里就结束。5月17号更新:

今天学习了torchvision.datasets.ImageFolder,利用这个类可以很方便定义一个通用的数据加载器,它所要求的图像排列结构如下,和文章一开始说的结构是一样的:

'''
    root/dog/xxx.png
    root/dog/xxy.png
    root/dog/xxz.png

    root/cat/123.png
    root/cat/nsdf3.png
    root/cat/asd932_.png
'''

然后只需定义号一个transforms,就可以实现数据集的加载:

if __name__ == '__main__':
    train_dataset = ImageFolder(root=r'C:\Users\admin\Desktop\set100-10\annotated_camera_images\train', transform=get_transform_for_train())
    dataloader = DataLoader(train_dataset, batch_size=64, shuffle=True)
    for index, batch_data in enumerate(dataloader):
        print(index, batch_data[0].shape, batch_data[1].shape)

输出结果为:

0 torch.Size([64, 3, 224, 224]) torch.Size([64])
1 torch.Size([64, 3, 224, 224]) torch.Size([64])
2 torch.Size([64, 3, 224, 224]) torch.Size([64])
3 torch.Size([64, 3, 224, 224]) torch.Size([64])
4 torch.Size([64, 3, 224, 224]) torch.Size([64])
5 torch.Size([64, 3, 224, 224]) torch.Size([64])
6 torch.Size([64, 3, 224, 224]) torch.Size([64])
7 torch.Size([64, 3, 224, 224]) torch.Size([64])
8 torch.Size([64, 3, 224, 224]) torch.Size([64])
9 torch.Size([64, 3, 224, 224]) torch.Size([64])
10 torch.Size([36, 3, 224, 224]) torch.Size([36])

Process finished with exit code 0

确实非常方便。。一般情况下用这个就够了