学习记录
began

数据集的构建

一般pytorch 的数据加载到模型遵循“三步走”的策略,操作顺序是这样的:

* 创建一个 Dataset 对象。必须实现__len__()、__getitem__()这两个方法,这里面会用到transform对数据集进行扩充。
* 创建一个 DataLoader 对象。它是对DataSet对象进行迭代的,一般不需要实现里面的其他方法了。
* 循环遍历这个 DataLoader 对象。将img, label加载到模型中进行训练。

其中datasets分为两种情况:

* 常见的数据读取方式datasets
* 自定义的datasets

1、第一步,使用dataset获取数据及其对应的真实标签

1.1 常见数据读取方式

其方法调用如下:

dataset = datasets.ImageFolder(
    # 图片存储的根目录,即各类别文件夹所在目录的上级目录
    root=rootpath,
    # 对图片进行预处理的操作(函数),原始图片作为输入,返回一个转换后的图片。transforms.ToTensor()将numpy的ndarray或PIL.Image读的图片转换成形状为(C,H, W)的Tensor格式
    transform=transforms.ToTensor(),
    # 对图片类别进行预处理的操作,输入为 target,输出对其的转换。如果不传该参数,即对 target 不做任何转换,返回的顺序索引 0,1,2,···
    target_transform=None,
    # 表示数据集加载方式,通常默认加载方式即可。
    loader= < function default_loader >,
    # 获取图像文件的路径并检查该文件是否为有效文件的函数(用于检查损坏文件)
    is_valid_file=None
)

返回的dataset都有以下三种属性:

dataset.classes:用一个 list 保存类别名称 
#['刘亦菲', '周杰伦', '彭于晏', '胡歌', '陈奕迅']
dataset.class_to_idx:类别对应的索引,与不做任何转换返回的 target 对应 
#{'刘亦菲': 0, '周杰伦': 1, '彭于晏': 2, '胡歌': 3, '陈奕迅': 4}
dataset.imgs:保存(img-path, classname) tuple的 list 
#[('D:/PYprogram/pytorch/FaceData/facetrain\\刘亦菲\\000000.jpg', 0), ('D:/PYprogram/pytorch/FaceData/facetrain\\刘亦菲\\000001.jpg', 0)]

1.2 自定义的dataset

插食管喂饭版:

from torch.utils.data import Dataset
from PIL import Image
import os


class myDataset(Dataset):
    def __init__(self, root_dir, label_dir):
        """
        :param root_dir: 根目录
        :param label_dir: 标签名称
        """
        self.root_dir = root_dir
        self.label_dir = label_dir
        self.path = os.path.join(self.root_dir, self.label_dir)
        self.img_path_list = os.listdir(self.path)

    def __getitem__(self, idx):
        img_name = self.img_path_list[idx]  # 只获取了文件名
        img_item_path = os.path.join(self.root_dir, self.label_dir, img_name)  # 每个图片的位置
        img = Image.open(img_item_path)
        label = self.label_dir
        return img, label

    def __len__(self):
        return len(self.img_path_list)


root_dir = "D:/PYprogram/pytorch/FaceData/facetrain"
label_dir ="胡歌"
mydataset = myDataset(root_dir=root_dir,label_dir=label_dir)
-----------------------------------------------------------------------------------------
in >>>
mydataset[0]
out>>>
(<PIL.JpegImagePlugin.JpegImageFile image mode=L size=128x128 at 0x17A72CA1C08>, '胡歌')

2、第二步:使用DataLoader来按批次读入数据

简单来说,DataLoader就是数据的加载器,在训练时使用,用来把训练数据分成多个小组

data = DataLoader(
    # 使用ImageFolder构建好的dataset
    dataset=dataset,
    # 样本是按“批”读入的,batch_size就是每次读入的样本数,每次读入16张图片
    batch_size=16,
    # 有多少个进程用于读取数据
    num_workers=0,
    # 是否将读入的数据打乱
    shuffle=True,
    # 对于样本最后一部分没有达到批次数的样本,使其不再参与训练
    drop_last=True)

3、循环遍历这个 DataLoader 对象,直到所有数据抛出

遍历该对象有2种方法:

方法一:使用next(iter(data)),运行一次调用一次

images, labels = next(iter(data))
print(images.shape)

方法二:使用enumerate循环

for id, dataloader in enumerate(data):
    images, labels = dataloader
    print(images.shape)

end