Contents

  • Preface
  • Parsing the XML files
  • Building the Dataset
  • Data augmentation
  • Building the DataLoader


Preface

The XML annotations in the VOC dataset contain a lot of information, some of it fairly complex. To make later use more convenient, we start by processing these XML files.

Parsing the XML files

For each XML file, we extract the boxes, labels and difficulties it contains and store them in a dictionary.

The corresponding implementation:

import xml.etree.ElementTree as ET

# voc_labels holds the class names of the 20 object categories in the VOC dataset
voc_labels = ('aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car', 'cat', 'chair', 'cow', 'diningtable',
              'dog', 'horse', 'motorbike', 'person', 'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor')

# Build the label_map dict that maps class names to class indices, e.g. {'aeroplane': 1, 'bicycle': 2, ...}
label_map = {k: v + 1 for v, k in enumerate(voc_labels)}
# Anything that is not one of the 20 VOC classes is treated as 'background', with class index 0
label_map['background'] = 0

# Invert the mapping: {class index: class name}
rev_label_map = {v: k for k, v in label_map.items()}  # Inverse mapping

# Parse an XML file and return all objects' bounding boxes and class indices,
# together with a flag marking whether each object is a 'difficult' one
def parse_annotation(annotation_path):
    # Parse the XML
    tree = ET.parse(annotation_path)
    root = tree.getroot()

    boxes = list()          # bounding boxes
    labels = list()         # class index of each box
    difficulties = list()   # 'difficult' flag of each box

    # Iterate over all <object> elements in the XML file; each one is a target in the image
    for obj in root.iter('object'):
        # Extract the difficult flag, label and bounding box of each object
        difficult = int(obj.find('difficult').text == '1')
        label = obj.find('name').text.lower().strip()
        if label not in label_map:
            continue
        bbox = obj.find('bndbox')
        # VOC coordinates are 1-based; subtract 1 to make them 0-based
        xmin = int(bbox.find('xmin').text) - 1
        ymin = int(bbox.find('ymin').text) - 1
        xmax = int(bbox.find('xmax').text) - 1
        ymax = int(bbox.find('ymax').text) - 1
        # Store
        boxes.append([xmin, ymin, xmax, ymax])
        labels.append(label_map[label])
        difficulties.append(difficult)

    # Return a dict with the image's annotation information
    return {'boxes': boxes, 'labels': labels, 'difficulties': difficulties}
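
To see what parse_annotation returns, the snippet below runs it on a small hand-written annotation file (the XML content and the file name are made up for illustration, not a real VOC annotation):

# A self-contained check of parse_annotation on an illustrative annotation file
sample = """<annotation>
  <object>
    <name>dog</name>
    <difficult>0</difficult>
    <bndbox><xmin>48</xmin><ymin>240</ymin><xmax>195</xmax><ymax>371</ymax></bndbox>
  </object>
</annotation>"""

with open('sample.xml', 'w') as f:
    f.write(sample)

print(parse_annotation('sample.xml'))
# {'boxes': [[47, 239, 194, 370]], 'labels': [12], 'difficulties': [0]}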

For the VOC dataset as a whole, we still need to generate a JSON file holding the 20-class label map, JSON files listing the training-set and test-set images, and the matching object annotation JSON files for the training and test sets; each image list file must correspond entry by entry with its annotation file. We therefore:

1. Build label_map from the VOC classes and write it to a JSON file.
2. Read the image ids of the training or test set from ImageSets/Main/trainval.txt or ImageSets/Main/test.txt.
3. For each id, find the matching XML file under Annotations and parse it with the parse_annotation function above to get the boxes, labels and difficulties.
4. For each id, find the matching image path under JPEGImages and write the image list to a JSON file.

Implementation:

"""python
    分别读取train和valid的图片和xml信息,创建用于训练和测试的json文件
"""
def create_data_lists(voc07_path, voc12_path, output_folder):
    """
    Create lists of images, the bounding boxes and labels of the objects in these images, and save these to file.
    :param voc07_path: path to the 'VOC2007' folder
    :param voc12_path: path to the 'VOC2012' folder
    :param output_folder: folder where the JSONs must be saved
    """
    
    # Get the absolute paths of the VOC2007 and VOC2012 datasets
    voc07_path = os.path.abspath(voc07_path)
    voc12_path = os.path.abspath(voc12_path)

    train_images = list()
    train_objects = list()
    n_objects = 0

    # Training data
    for path in [voc07_path, voc12_path]:

        # Find IDs of images in training data
        # Get the ids of the train and val images used for training
        with open(os.path.join(path, 'ImageSets/Main/trainval.txt')) as f:
            ids = f.read().splitlines()
		
        # For each image id, parse its XML file to get the annotation information
        for id in ids:
            # Parse annotation's XML file
            objects = parse_annotation(os.path.join(path, 'Annotations', id + '.xml'))
            if len(objects['boxes']) == 0:  # skip images that contain no objects
                continue
            n_objects += len(objects['boxes'])  # count the total number of objects
            train_objects.append(objects)  # store each image's annotations in the list train_objects
            train_images.append(os.path.join(path, 'JPEGImages', id + '.jpg'))  # store each image's path in train_images, used for reading the image later

    assert len(train_objects) == len(train_images)  # the number of images must equal the number of annotation entries

    # Save to file
    # Save the training images' paths, their annotations and the label map to separate JSON files
    with open(os.path.join(output_folder, 'TRAIN_images.json'), 'w') as j:
        json.dump(train_images, j)
    with open(os.path.join(output_folder, 'TRAIN_objects.json'), 'w') as j:
        json.dump(train_objects, j)
    with open(os.path.join(output_folder, 'label_map.json'), 'w') as j:
        json.dump(label_map, j)  # save label map too

    print('\nThere are %d training images containing a total of %d objects. Files have been saved to %s.' % (
        len(train_images), n_objects, os.path.abspath(output_folder)))

    
    # Same as for the training data: save the test images' paths and annotations to JSON files (refer to the comments above)
    # Test data
    test_images = list()
    test_objects = list()
    n_objects = 0

    # Find IDs of images in the test data
    # (note: the VOC2007 val split is used as the test set here; substitute ImageSets/Main/test.txt if you have it)
    with open(os.path.join(voc07_path, 'ImageSets/Main/val.txt')) as f:
        ids = f.read().splitlines()

    for id in ids:
        # Parse annotation's XML file
        objects = parse_annotation(os.path.join(voc07_path, 'Annotations', id + '.xml'))
        if len(objects['boxes']) == 0:
            continue
        test_objects.append(objects)
        n_objects += len(objects['boxes'])
        test_images.append(os.path.join(voc07_path, 'JPEGImages', id + '.jpg'))

    assert len(test_objects) == len(test_images)

    # Save to file
    with open(os.path.join(output_folder, 'TEST_images.json'), 'w') as j:
        json.dump(test_images, j)
    with open(os.path.join(output_folder, 'TEST_objects.json'), 'w') as j:
        json.dump(test_objects, j)

    print('\nThere are %d test images containing a total of %d objects. Files have been saved to %s.' % (
        len(test_images), n_objects, os.path.abspath(output_folder)))
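
A typical invocation might look as follows (the VOCdevkit paths and the output folder are assumptions; point them at your own copies):

if __name__ == '__main__':
    create_data_lists(voc07_path='VOCdevkit/VOC2007',
                      voc12_path='VOCdevkit/VOC2012',
                      output_folder='./json_data')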

Building the Dataset

Every dataset in PyTorch inherits from torch.utils.data.Dataset and must implement the two interfaces __getitem__ and __len__; the core of implementing a dataset is therefore implementing these two methods.
Implementation:

"""python
    PascalVOCDataset具体实现过程
"""
import torch
from torch.utils.data import Dataset
import json
import os
from PIL import Image
from utils import transform


class PascalVOCDataset(Dataset):
    """
    A PyTorch Dataset class to be used in a PyTorch DataLoader to create batches.
    """

    # Initialize the relevant variables
    # Read the images and objects annotation files
    def __init__(self, data_folder, split, keep_difficult=False):
        """
        :param data_folder: folder where data files are stored
        :param split: split, one of 'TRAIN' or 'TEST'
        :param keep_difficult: keep or discard objects that are considered difficult to detect?
        """
        self.split = split.upper()    # upper-case the input so it matches one of {'TRAIN', 'TEST'}

        assert self.split in {'TRAIN', 'TEST'}

        self.data_folder = data_folder
        self.keep_difficult = keep_difficult

        # Read data files
        with open(os.path.join(data_folder, self.split + '_images.json'), 'r') as j:
            self.images = json.load(j)
        with open(os.path.join(data_folder, self.split + '_objects.json'), 'r') as j:
            self.objects = json.load(j)

        assert len(self.images) == len(self.objects)

    # Read one image and its objects
    # Apply the transform (data augmentation) to the image and its objects
    # Return the image, its bounding boxes, the class index of each box, and each box's difficult flag
    def __getitem__(self, i):
        # Read image
        # The image is read with PIL's Image.open(), which returns the PIL format that torchvision's transforms expect
        # Image.open() may return modes other than RGB (e.g. RGBA or grayscale), so .convert('RGB') guarantees three channels
        image = Image.open(self.images[i], mode='r')
        image = image.convert('RGB')

        # Read objects in this image (bounding boxes, labels, difficulties)
        objects = self.objects[i]
        boxes = torch.FloatTensor(objects['boxes'])  # (n_objects, 4)
        labels = torch.LongTensor(objects['labels'])  # (n_objects)
        difficulties = torch.ByteTensor(objects['difficulties'])  # (n_objects)

        # Discard difficult objects, if desired
        # If self.keep_difficult is False, drop every object whose difficult flag is set
        # (indexing with a uint8 mask, as in boxes[1 - difficulties], is deprecated in recent
        # PyTorch versions, so a boolean mask is used instead)
        if not self.keep_difficult:
            keep = difficulties == 0
            boxes = boxes[keep]
            labels = labels[keep]
            difficulties = difficulties[keep]

        # Apply transformations
        # Apply the transform to the image and its annotations
        image, boxes, labels, difficulties = transform(image, boxes, labels, difficulties, split=self.split)

        return image, boxes, labels, difficulties

    # Return the total number of images, used to compute the number of batches
    def __len__(self):
        return len(self.images)

    # Samples are fed to the network one batch at a time, but __getitem__ only reads a single image and its objects.
    # How do we combine the images and object annotations we read one by one into batches?
    # That is exactly what collate_fn does:
    # for a batch of images, it uses torch.stack() to build a 4-D tensor, while the corresponding
    # object annotations are kept in one list per field
    def collate_fn(self, batch):
        """
        Since each image may have a different number of objects, we need a collate function (to be passed to the DataLoader).
        This describes how to combine these tensors of different sizes. We use lists.
        Note: this need not be defined in this Class, can be standalone.
        :param batch: an iterable of N sets from __getitem__()
        :return: a tensor of images, lists of varying-size tensors of bounding boxes, labels, and difficulties
        """

        images = list()
        boxes = list()
        labels = list()
        difficulties = list()

        for b in batch:
            images.append(b[0])
            boxes.append(b[1])
            labels.append(b[2])
            difficulties.append(b[3])

        # (3, H, W) -> (N, 3, H, W)
        images = torch.stack(images, dim=0)

        return images, boxes, labels, difficulties  # a (N, 3, H, W) tensor, plus three lists of N tensors each
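
As a quick sanity check, a single item can be pulled from the dataset like this (the data_folder value is an assumption, and the printed image shape depends on the resize done inside transform):

train_dataset = PascalVOCDataset(data_folder='./json_data', split='train', keep_difficult=True)
image, boxes, labels, difficulties = train_dataset[0]
print(image.shape)  # e.g. torch.Size([3, 300, 300]) if transform resizes to 300x300
print(boxes.shape)  # (n_objects, 4)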

Data augmentation

Data augmentation helps a great deal with both accuracy and generalization.
Many augmentation operations are used here: changing the image's brightness, contrast, saturation and hue, expanding (zooming out from) the targets, random cropping, flipping, resizing, ToTensor and normalization. We won't study each one in detail for now; a rough sketch follows below.
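
For reference, here is a minimal sketch of what a function like the utils.transform imported above could look like. It only implements brightness jitter, a random horizontal flip, resizing with fractional box coordinates, ToTensor and normalization; the dims argument and the 300x300 default are assumptions, and the real version would also handle contrast/saturation/hue jitter, expand and random crop:

import torch
import torchvision.transforms.functional as FT

# ImageNet statistics, the usual choice when the backbone is pretrained on ImageNet
MEAN = [0.485, 0.456, 0.406]
STD = [0.229, 0.224, 0.225]

def transform(image, boxes, labels, difficulties, split, dims=(300, 300)):
    new_image, new_boxes = image, boxes

    if split == 'TRAIN':
        # Photometric distortion (brightness jitter shown as one example)
        new_image = FT.adjust_brightness(new_image, brightness_factor=0.8 + 0.4 * torch.rand(1).item())

        # Horizontal flip with probability 0.5; the boxes must be mirrored as well
        if torch.rand(1).item() < 0.5:
            new_image = FT.hflip(new_image)
            w = new_image.width
            new_boxes = new_boxes.clone()
            new_boxes[:, [0, 2]] = w - 1 - new_boxes[:, [2, 0]]

    # Convert the boxes to fractional coordinates in [0, 1], then resize the image
    old_w, old_h = new_image.width, new_image.height
    new_boxes = new_boxes / torch.FloatTensor([old_w, old_h, old_w, old_h])
    new_image = FT.resize(new_image, list(dims))

    # Convert to a tensor and normalize with the ImageNet statistics
    new_image = FT.to_tensor(new_image)
    new_image = FT.normalize(new_image, mean=MEAN, std=STD)

    return new_image, new_boxes, labels, difficulties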

Building the DataLoader

With the Dataset in place, we can create the DataLoader. PyTorch wraps this up nicely, so the call itself is straightforward.

"""python
    DataLoader
"""
#参数说明:
#在train时一般设置shufle=True打乱数据顺序,增强模型的鲁棒性
#num_worker表示读取数据时的线程数,一般根据自己设备配置确定(如果是windows系统,建议设默认值0,防止出错)
#pin_memory,在计算机内存充足的时候设置为True可以加快内存中的tensor转换到GPU的速度
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True,
                                           collate_fn=train_dataset.collate_fn, num_workers=workers,
                                           pin_memory=True)  # note that we're passing the collate function here
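
Each iteration of the loader then yields one batch in the format produced by collate_fn. A quick sanity-check loop might look like this (a sketch; it assumes the train_loader built above):

for i, (images, boxes, labels, difficulties) in enumerate(train_loader):
    # images: a (N, 3, H, W) tensor, ready to be fed to the network
    # boxes / labels / difficulties: lists of N tensors, one per image,
    # because the number of objects differs from image to image
    print(images.shape, len(boxes), boxes[0].shape)
    if i == 0:
        break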