学习记录
began
数据集的构建
一般pytorch 的数据加载到模型遵循“三步走”的策略,操作顺序是这样的:
* 创建一个 Dataset 对象。必须实现__len__()、__getitem__()这两个方法,这里面会用到transform对数据集进行扩充。
* 创建一个 DataLoader 对象。它是对DataSet对象进行迭代的,一般不需要实现里面的其他方法了。
* 循环遍历这个 DataLoader 对象。将img, label加载到模型中进行训练。
其中datasets分为两种情况:
* 常见的数据读取方式datasets
* 自定义的datasets
1、第一步,使用dataset获取数据及其对应的真实标签
1.1 常见数据读取方式
其方法调用如下:
dataset = datasets.ImageFolder(
# 图片存储的根目录,即各类别文件夹所在目录的上级目录
root=rootpath,
# 对图片进行预处理的操作(函数),原始图片作为输入,返回一个转换后的图片。transforms.ToTensor()将numpy的ndarray或PIL.Image读的图片转换成形状为(C,H, W)的Tensor格式
transform=transforms.ToTensor(),
# 对图片类别进行预处理的操作,输入为 target,输出对其的转换。如果不传该参数,即对 target 不做任何转换,返回的顺序索引 0,1,2,···
target_transform=None,
# 表示数据集加载方式,通常默认加载方式即可。
loader= < function default_loader >,
# 获取图像文件的路径并检查该文件是否为有效文件的函数(用于检查损坏文件)
is_valid_file=None
)
返回的dataset都有以下三种属性:
dataset.classes:用一个 list 保存类别名称
#['刘亦菲', '周杰伦', '彭于晏', '胡歌', '陈奕迅']
dataset.class_to_idx:类别对应的索引,与不做任何转换返回的 target 对应
#{'刘亦菲': 0, '周杰伦': 1, '彭于晏': 2, '胡歌': 3, '陈奕迅': 4}
dataset.imgs:保存(img-path, classname) tuple的 list
#[('D:/PYprogram/pytorch/FaceData/facetrain\\刘亦菲\\000000.jpg', 0), ('D:/PYprogram/pytorch/FaceData/facetrain\\刘亦菲\\000001.jpg', 0)]
1.2 自定义的dataset
插食管喂饭版:
from torch.utils.data import Dataset
from PIL import Image
import os
class myDataset(Dataset):
def __init__(self, root_dir, label_dir):
"""
:param root_dir: 根目录
:param label_dir: 标签名称
"""
self.root_dir = root_dir
self.label_dir = label_dir
self.path = os.path.join(self.root_dir, self.label_dir)
self.img_path_list = os.listdir(self.path)
def __getitem__(self, idx):
img_name = self.img_path_list[idx] # 只获取了文件名
img_item_path = os.path.join(self.root_dir, self.label_dir, img_name) # 每个图片的位置
img = Image.open(img_item_path)
label = self.label_dir
return img, label
def __len__(self):
return len(self.img_path_list)
root_dir = "D:/PYprogram/pytorch/FaceData/facetrain"
label_dir ="胡歌"
mydataset = myDataset(root_dir=root_dir,label_dir=label_dir)
-----------------------------------------------------------------------------------------
in >>>
mydataset[0]
out>>>
(<PIL.JpegImagePlugin.JpegImageFile image mode=L size=128x128 at 0x17A72CA1C08>, '胡歌')
2、第二步:使用DataLoader来按批次读入数据
简单来说,DataLoader就是数据的加载器,在训练时使用,用来把训练数据分成多个小组
data = DataLoader(
# 使用ImageFolder构建好的dataset
dataset=dataset,
# 样本是按“批”读入的,batch_size就是每次读入的样本数,每次读入16张图片
batch_size=16,
# 有多少个进程用于读取数据
num_workers=0,
# 是否将读入的数据打乱
shuffle=True,
# 对于样本最后一部分没有达到批次数的样本,使其不再参与训练
drop_last=True)
3、循环遍历这个 DataLoader 对象,直到所有数据抛出
遍历该对象有2种方法:
方法一:使用next(iter(data)),运行一次调用一次
images, labels = next(iter(data))
print(images.shape)
方法二:使用enumerate循环
for id, dataloader in enumerate(data):
images, labels = dataloader
print(images.shape)
end