数据集组成
网络训练的第一步就是读取数据,关于输入图片如何读取,如何进行预处理,将会在本篇文章中进行演示。
首先需要了解的是,语义分割中图片和标签是分别保存的。以voc数据集为例,它有20个类别,加上背景总共21个类别。其中,JPEGImages文件夹下存放的是输入图片,它们都是JPG格式。每张图片都是R,G,B三通道,其像素值在0-255之间。
SegmentationClass文件夹下存放的是标签,它们都是PNG格式。每张标签都是单通道的,其像素值0-N之间,其中N为分类的类别数。至于为什么单通道的图片看起来还是彩色的,这其实是通过***伪色码***显示的,本质上还是单通道。特别需要注意的是,标签中的背景,也就是图中黑色的部分,它的像素值是255。
光有图片和标签是不够的,我们还不知道那些图片需要训练,那些图片需要验证。所以在voc数据集中还有一个文件是用来区分那些是训练用的图片,那些是预测用的图片。
它们都是以txt文件储存,点击进去,会发现里面全部都是图片的名称,它们都没有后缀。这些是制作这个数据集的作者为准备好用来训练的图片。
datasets的搭建
在pytorch中,训练模型需要将图片和标签读入对应的类当中,这个类就叫做dataset。我们读取自己的数据集的时候只需要重写这个类就可以了,特别的,我们自定义的这个类必须继承pytorch官方定义的Dataset这个父类。下图为自定义的voc dataset,它有很多类属性,例如self.root就是存放图片的根路径。每个属性的作用都在下面批注了注释,这里就不具体介绍了。
在自定义数据集中,我们肯定要告诉程序,我们的图片、标签存放在哪里,它应该如何读取,所以我们自定义了一个函数,它的名字叫set_files,我们在初始化的时候就执行了它。其中函数的功能如下图所示,通过这个函数,我们会得到图片和标签的根路径,它们分别会存放在self.image_dir和self.label_dir中,我们在这个类里面可以随时调用这个类属性。之前提到过,除了图片和标签,我们还需知道那些图片进行训练,那些图片进行预测。所以我们self.files中读txt文件,当我们为训练模式的时候,读取的是train.txt,当我们是验证模式的时候读取是val.txt。
前面我们只是得到了图片的路径,并将它们赋值给类变量,我们并没有对它们进行读取,所以我们需要一个函数来将它们读取,具体函数见下图。它传入一个index索引,通过这个索引,我们就可以从self.files中拿出我们需要训练的图片的名称,再根据之前的到的图片和标签的根路径,将名称与它们拼接,就可以得到一个完整的图片路径。我们通过Image.open函数打卡这张图片同时将它转换为一个数组,方便我们后续对它进行处理。最后返回数组形式的图片和标签。
两个重要的魔法方法
上面的操作是们自定义的,但是如果要实现Dataset的功能,我们就必须要重写两个方法__getitem__和__len__。其中__getitem__必须有index参数,因为这个参数控制着代码当前读取那一张图片。通过下面的代码可以发现,我们将这个index传给上面的讲到的函数,拿到具体的一张图片,并判断是否对他进行数据增强操作,同时也会统一数据的格式,最后一定要返回处理好的图片和标签。
__len__方法也是必须写的,因为,程序无法知道这个数据集有多少张图片,应该迭代多少次结束。所以我们要重写__len__方法,就是要告诉程序数据集的长度,它写起来也是非常简单。
数据增强
细心的小伙伴肯定发现__getitem__函数里面调用了self._augmentation()函数,这个函数的作用是对图片进行数据增强,由于我们已经得到了数组形式的图片和标签,那么这里对它进行数据增强已经非常简单了,因为已经有现成的库帮我们实现了这些功能,我们只用调用就好了。
完整代码
# Originally written by Kazuto Nakashima
# https://github.com/kazuto1011/deeplab-pytorch
import numpy as np
import os
from torch.utils.data import Dataset
import torch
from PIL import Image
from torchvision import transforms
import cv2
import random
class VOCDataset(Dataset):
def __init__(self, root, split='train',num_classes=21, base_size=None, augment=True,
crop_size=321, scale=True, flip=True, rotate=True, blur=True,):
super(VOCDataset, self).__init__()
self.root = root # 存放数据集的根路径
self.num_classes = num_classes # 数据集的类别总数
self.MEAN = [0.45734706, 0.43338275, 0.40058118] # 数据集的均值和方差
self.STD = [0.23965294, 0.23532275, 0.2398498]
self.crop_size = crop_size #裁剪图片的大小
self.scale = scale #是否进行scale
self.flip = flip #是否进行flip
self.rotate = rotate #是否进行rotate
self.blur = blur # 是否进行blur
self.base_size = base_size # 基础读入图片大小
self.augment = augment #是否进行数据增强
self.split = split # 拿到训练模式
self._set_files() # 调用函数,拿到所有训练 验证的图片名字
self.to_tensor = transforms.ToTensor() # 对图片进行归一化处理
self.normalize = transforms.Normalize(self.MEAN,self.STD)
def _set_files(self):
self.root = os.path.join(self.root, 'VOC2012') # VOC数据集的路径
self.image_dir = os.path.join(self.root, 'JPEGImages') # 图片的存放路径
self.label_dir = os.path.join(self.root, 'SegmentationClass') # 标签的存放路径
file_list = os.path.join(self.root, "ImageSets/Segmentation", self.split + ".txt")
# 训练或验证图片的名称txt文件
self.files = [line.rstrip() for line in tuple(open(file_list, "r"))] # 训练或验证图片的名称 放入列表
# 这里拿到的是对应的图片的名字 放在列表中
def _load_data(self, index):
image_id = self.files[index] # 根据索引取图片
image_path = os.path.join(self.image_dir, image_id + '.jpg') # 图片路径
label_path = os.path.join(self.label_dir, image_id + '.png') # 标签路径
# 将图片转成数组
image = np.asarray(Image.open(image_path), dtype=np.float32)
label = np.asarray(Image.open(label_path), dtype=np.int32)
return image, label
def __getitem__(self, index):
"__getitem__方法在自定义数据集的时候必须重写.index是输入图片的索引值"
"在这个函数里面可以对图片进行预处理,但是要返回处理好的图片"
image, label = self._load_data(index) # 拿到每一张图片和标签
if self.augment: # 判断是否进行数据增强
image, label = self._augmentation(image, label)
# 统一输入图片格式
label = torch.from_numpy(np.array(label, dtype=np.float32)).long()
image = Image.fromarray(np.uint8(image))
return self.normalize(self.to_tensor(image)), label # 归一化 将图片转换为tensor对象
def __len__(self):
"__len__方法在自定义数据集时候必须重写.返回数据集的长度"
return len(self.files)
#数据增强函数
def _augmentation(self, image, label):
h, w, _ = image.shape
if self.base_size:
if self.scale:
longside = random.randint(int(self.base_size * 0.5), int(self.base_size * 2.0))
else:
longside = self.base_size
h, w = (longside, int(1.0 * longside * w / h + 0.5)) if h > w else (
int(1.0 * longside * h / w + 0.5), longside)
image = cv2.resize(image, (w, h), interpolation=cv2.INTER_LINEAR)
label = cv2.resize(label, (w, h), interpolation=cv2.INTER_NEAREST)
h, w, _ = image.shape
# 旋转图片在(-10°和10°之间)
if self.rotate:
angle = random.randint(-10, 10)
center = (w / 2, h / 2)
rot_matrix = cv2.getRotationMatrix2D(center, angle, 1.0)
image = cv2.warpAffine(image, rot_matrix, (w, h),
flags=cv2.INTER_LINEAR)
label = cv2.warpAffine(label, rot_matrix, (w, h),
flags=cv2.INTER_NEAREST)
# 对不符合指定大小的图片进行裁剪
if self.crop_size:
pad_h = max(self.crop_size - h, 0)
pad_w = max(self.crop_size - w, 0)
pad_kwargs = {
"top": 0,
"bottom": pad_h,
"left": 0,
"right": pad_w,
"borderType": cv2.BORDER_CONSTANT, }
if pad_h > 0 or pad_w > 0:
image = cv2.copyMakeBorder(image, value=0, **pad_kwargs)
label = cv2.copyMakeBorder(label, value=0, **pad_kwargs)
# 对不符合大小的图片进行padding
h, w, _ = image.shape
start_h = random.randint(0, h - self.crop_size)
start_w = random.randint(0, w - self.crop_size)
end_h = start_h + self.crop_size
end_w = start_w + self.crop_size
image = image[start_h:end_h, start_w:end_w]
label = label[start_h:end_h, start_w:end_w]
# 随机反转
if self.flip:
if random.random() > 0.5:
image = np.fliplr(image).copy()
label = np.fliplr(label).copy()
# 给图片增加高斯噪音
if self.blur:
sigma = random.random()
ksize = int(3.3 * sigma)
ksize = ksize + 1 if ksize % 2 == 0 else ksize
image = cv2.GaussianBlur(image, (ksize, ksize), sigmaX=sigma, sigmaY=sigma,
borderType=cv2.BORDER_REFLECT_101)
return image, label