Contents
- 1. torch.utils.data.Dataset
- 2. torch.utils.data.DataLoader
- 3. transforms
- 3.1 Official PyTorch API
- 3.1.1 transforms.Compose
- 3.1.2 transforms.ToTensor
- 3.2 Custom transforms
PyTorch provides a very convenient data-loading mechanism: combining torch.utils.data.Dataset with torch.utils.data.DataLoader yields a data iterator. During training, this iterator emits one batch of data at a time, and the corresponding preprocessing or data augmentation can be applied to the data as it is emitted.
1. torch.utils.data.Dataset
torch.utils.data.Dataset is the base class for custom datasets. To define your own dataset, subclass it and override the two magic methods __len__() and __getitem__().
- __len__(): returns the size of the dataset. The dataset we build is an object, and unlike sequence types (lists, tuples, strings) its length cannot be obtained directly with len(); the purpose of the magic method __len__() is to make the object's length available just as it is for a sequence. If A is a class and a is an instance of A, then once A defines __len__(), len(a) returns the size of the object.
- __getitem__(): allows individual samples in the dataset to be retrieved by index. Just as any element of a sequence can be fetched by indexing, __getitem__() makes indexing work on our object. In addition, data preprocessing can be performed inside __getitem__(). (A small illustration of both methods follows.)
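As a minimal illustration (my own example, independent of Dataset), a plain Python class that defines both magic methods:
class A:
    def __init__(self, items):
        self.items = items
    def __len__(self):
        return len(self.items)       # makes len(a) work
    def __getitem__(self, index):
        return self.items[index]     # makes a[index] work

a = A([10, 20, 30])
print(len(a))  # 3
print(a[1])    # 20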
Example 1
import torch
from torch.utils.data import Dataset

class TensorDataset(Dataset):
    """
    TensorDataset inherits from Dataset and overrides __init__(), __getitem__() and __len__().
    It wraps a pair of tensors into a tensor dataset:
    samples can be retrieved by index, and len() returns the dataset size.
    """
    def __init__(self, data_tensor, target_tensor):
        self.data_tensor = data_tensor
        self.target_tensor = target_tensor

    def __getitem__(self, index):
        return self.data_tensor[index], self.target_tensor[index]

    def __len__(self):
        return self.data_tensor.size(0)
# Generate some data
data_tensor = torch.randn(4, 3)
target_tensor = torch.rand(4)

# Wrap the data into a Dataset
tensor_dataset = TensorDataset(data_tensor, target_tensor)

# Individual samples can be retrieved by index
print(tensor_dataset[1])
# Output: (tensor([-1.0351, -0.1004, 0.9168]), tensor(0.4977))

# Get the dataset size
print(len(tensor_dataset))
# Output: 4
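Incidentally, torch.utils.data already ships a built-in TensorDataset with essentially this behaviour, so for the tensor-pair case it can be used directly:
from torch.utils.data import TensorDataset as BuiltinTensorDataset

built_in = BuiltinTensorDataset(data_tensor, target_tensor)
print(built_in[1])  # the same (data, target) pair as above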
Example 2
import os
from PIL import Image
from torch.utils.data import Dataset

class PatchDataset(Dataset):
    def __init__(self, data_dir, transforms=None):
        """
        :param data_dir: path to the dataset
        :param transforms: data preprocessing (transforms) to apply
        """
        self.data_info = self.get_img_info(data_dir)
        self.transforms = transforms

    def __getitem__(self, item):
        path_img, label = self.data_info[item]
        image = Image.open(path_img).convert('RGB')
        if self.transforms is not None:
            image = self.transforms(image)
        return image, label

    def __len__(self):
        return len(self.data_info)

    @staticmethod
    def get_img_info(data_dir):
        path_dir = os.path.join(data_dir, 'train_dataset.txt')
        data_info = []
        with open(path_dir) as file:
            lines = file.readlines()
            for line in lines:
                data_info.append(line.strip('\n').split(' '))
        return data_info
Here, the contents of train_dataset.txt take the following form.
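(The original listing of the file is not reproduced here. Judging from get_img_info(), which splits each line on a space into a path and a label, each line presumably looks like this hypothetical example:)
data/images/img_0001.png 0
data/images/img_0002.png 1
data/images/img_0003.png 0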
2. torch.utils.data.DataLoader
Purpose:
- DataLoader wraps a Dataset object (or an object of a custom dataset class) into an iterator;
- this iterator can be iterated over to emit the contents of the Dataset;
- at the same time it handles multiprocessing, shuffling, different sampling strategies, batch collation, and so on.
Several important arguments of __init__():
- dataset: the output of an existing PyTorch data-reading interface (such as torchvision.datasets.ImageFolder) or of a custom one; it must be an object of torch.utils.data.Dataset or of a custom class that inherits from torch.utils.data.Dataset.
- batch_size: set according to your situation.
- shuffle: randomly shuffles the order; usually enabled for training data.
- collate_fn: merges the individual samples read from the dataset into a batch; the default is usually fine unless your dataset returns something unusual (a sketch of a custom collate_fn follows the example below).
- batch_sampler: as the docstring notes, mutually exclusive with batch_size, shuffle, etc.; usually left at the default.
- sampler: as the source code shows, mutually exclusive with shuffle; usually left at the default.
- num_workers: as the docstring notes, must be >= 0; 0 means data is loaded in the main process, while a value greater than 0 loads data with that many worker processes, which can speed up loading.
- pin_memory: the docstring says it clearly: pin_memory (bool, optional): If True, the data loader will copy tensors into CUDA pinned memory before returning them. In other words, it controls an extra data copy.
- timeout: sets a timeout for reading data; if no data has been read within this time, an error is raised.
Code example (continuing from Example 1)
from torch.utils.data import DataLoader

tensor_dataloader = DataLoader(tensor_dataset,   # the wrapped Dataset object
                               batch_size=2,     # batch size of the output
                               shuffle=True,     # shuffle the order
                               num_workers=0)    # load data in the main process (no worker processes)

# Iterate over the loader with a for loop
for data, target in tensor_dataloader:
    print(data, target)
Output:
tensor([[ 0.7745, 0.2186, 0.1231],
[-0.1307, 1.5778, -1.2906]]) tensor([0.3749, 0.4659])
tensor([[-0.1605, 0.9359, 0.1314],
[-1.1694, 1.0986, -0.9927]]) tensor([0.8071, 0.8997])
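To make the collate_fn argument concrete, here is a small sketch of my own (not from the original article): collate_fn receives the list of individual samples that make up one batch and must return the batched result. The example below pads variable-length 1-D tensors to a common length; the dataset and function names are hypothetical.
import torch
from torch.utils.data import Dataset, DataLoader

class VarLenDataset(Dataset):
    """Toy dataset whose samples are 1-D tensors of different lengths."""
    def __init__(self):
        self.seqs = [torch.randn(n) for n in (3, 5, 2, 4)]
        self.labels = torch.arange(4)
    def __getitem__(self, index):
        return self.seqs[index], self.labels[index]
    def __len__(self):
        return len(self.seqs)

def pad_collate(batch):
    """batch is a list of (seq, label) pairs; pad the seqs to equal length and stack everything."""
    seqs, labels = zip(*batch)
    max_len = max(s.size(0) for s in seqs)
    padded = torch.zeros(len(seqs), max_len)
    for i, s in enumerate(seqs):
        padded[i, :s.size(0)] = s          # copy each sample; the remainder stays zero padding
    return padded, torch.stack(labels)

loader = DataLoader(VarLenDataset(), batch_size=2, collate_fn=pad_collate)
for data, target in loader:
    print(data.shape, target)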
3. transforms
3.1 Official PyTorch API
transforms mainly implements preprocessing, data augmentation, conversion to tensors and similar operations on a dataset. It can be imported with the following line:
from torchvision import transforms
transforms is mainly used while building the Dataset class; the overall flow looks like this:
from PIL import Image
from torchvision import transforms
from torch.utils.data import Dataset

class MyDataset(Dataset):
    def __init__(self, data_dir, transforms=None):
        # get_img_info() builds the (path, label) list, e.g. as in Example 2 above
        self.data_info = self.get_img_info(data_dir)
        self.transforms = transforms

    def __getitem__(self, item):
        path_img, label = self.data_info[item]
        image = Image.open(path_img).convert('RGB')
        # Apply the transforms that were passed in to preprocess the data
        if self.transforms is not None:
            image = self.transforms(image)
        return image, label

    def __len__(self):
        return len(self.data_info)

train_transforms = transforms.Compose([transforms.ToTensor(),
                                       transforms.RandomHorizontalFlip(0.5)])
train_dataset = MyDataset(data_dir, train_transforms)
A more complete pair of pipelines, one for training and one for testing, might look like this:
from torchvision import transforms

transforms_train = transforms.Compose([transforms.Resize(40),
                                       transforms.RandomResizedCrop(32, scale=(0.64, 1.0), ratio=(1.0, 1.0)),
                                       transforms.RandomHorizontalFlip(),
                                       transforms.ToTensor(),
                                       transforms.Normalize(mean=[0.4914, 0.4832, 0.4856],
                                                            std=[0.2023, 0.2013, 0.2111])])
transforms_test = transforms.Compose([transforms.ToTensor(),
                                      transforms.Normalize(mean=[0.4914, 0.4832, 0.4856],
                                                           std=[0.2023, 0.2013, 0.2111])])
The rest of this section looks at how transforms works and how to use it.
3.1.1 transforms.Compose
The Compose class chains multiple transform functions together. Its initializer takes a list of transform objects, and the image is then passed through each of them in turn.
class Compose:
"""Composes several transforms together. This transform does not support torchscript.
Please, see the note below.
Args:
transforms (list of ``Transform`` objects): list of transforms to compose.
Example:
>>> transforms.Compose([
>>> transforms.CenterCrop(10),
>>> transforms.ToTensor(),
>>> ])
.. note::
In order to script the transformations, please use ``torch.nn.Sequential`` as below.
>>> transforms = torch.nn.Sequential(
>>> transforms.CenterCrop(10),
>>> transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
>>> )
>>> scripted_transforms = torch.jit.script(transforms)
Make sure to use only scriptable transformations, i.e. that work with ``torch.Tensor``, does not require
`lambda` functions or ``PIL.Image``.
"""
def __init__(self, transforms):
self.transforms = transforms
def __call__(self, img):
for t in self.transforms:
img = t(img)
return img
The class defines the __call__() method, which means an instance of the class can be called as if it were a function, for example:
class SquareNum():
    def __call__(self, x):
        return x ** 2

square_num = SquareNum()
print(square_num(2))  # 4
Because SquareNum implements the magic method __call__(), square_num(2) uses the object name as if it were a function name.
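Putting this together, a short sanity check of Compose (my own example; the blank 32x32 image is just a placeholder):
from PIL import Image
from torchvision import transforms

pipeline = transforms.Compose([transforms.CenterCrop(10), transforms.ToTensor()])
img = Image.new('RGB', (32, 32))   # placeholder image
out = pipeline(img)                # the image is passed through each transform in turn
print(out.shape)                   # torch.Size([3, 10, 10])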
3.1.2 transforms.ToTensor
This class converts a PIL Image or numpy.ndarray into a tensor; during the conversion the dimensions are rearranged and the values are scaled: Converts a PIL Image or numpy.ndarray (H x W x C) in the range [0, 255] to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0].
class ToTensor:
"""Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor. This transform does not support torchscript.
Converts a PIL Image or numpy.ndarray (H x W x C) in the range
[0, 255] to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0]
if the PIL Image belongs to one of the modes (L, LA, P, I, F, RGB, YCbCr, RGBA, CMYK, 1)
or if the numpy.ndarray has dtype = np.uint8
In the other cases, tensors are returned without scaling.
.. note::
Because the input image is scaled to [0.0, 1.0], this transformation should not be used when
transforming target image masks. See the `references`_ for implementing the transforms for image masks.
.. _references: https://github.com/pytorch/vision/tree/master/references/segmentation
"""
def __call__(self, pic):
"""
Args:
pic (PIL Image or numpy.ndarray): Image to be converted to tensor.
Returns:
Tensor: Converted image.
"""
return F.to_tensor(pic)
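A quick check of this behaviour (my own snippet, not from the original article):
import numpy as np
from torchvision import transforms

img = np.random.randint(0, 256, size=(32, 48, 3), dtype=np.uint8)  # H x W x C, values in [0, 255]
tensor = transforms.ToTensor()(img)
print(tensor.shape)  # torch.Size([3, 32, 48]) -> C x H x W
print(tensor.dtype, tensor.min().item(), tensor.max().item())  # torch.float32, values in [0.0, 1.0]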
3.2 Custom transforms
For object detection, when augmenting the original image we also need to adjust the bounding-box coordinates of the targets accordingly; or we may want to build our own augmentation methods. In these cases we have to define our own transforms.
import random
from torchvision.transforms import functional as F

class Compose(object):
    """Chain several transform functions that operate on (image, target) pairs."""
    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, image, target):
        for t in self.transforms:
            image, target = t(image, target)
        return image, target

class ToTensor(object):
    """Convert the PIL image to a tensor."""
    def __call__(self, image, target):
        image = F.to_tensor(image)
        # the target dict (boxes, labels, ...) needs no dimension reordering or scaling,
        # so it is returned unchanged
        return image, target

class RandomHorizontalFlip(object):
    """Randomly flip the image and its bboxes horizontally."""
    def __init__(self, prob=0.5):
        self.prob = prob

    def __call__(self, image, target):
        if random.random() < self.prob:
            height, width = image.shape[-2:]
            image = image.flip(-1)  # flip the image horizontally
            bbox = target["boxes"]
            # bbox: xmin, ymin, xmax, ymax
            bbox[:, [0, 2]] = width - bbox[:, [2, 0]]  # flip the corresponding bbox coordinates
            target["boxes"] = bbox
        return image, target
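A minimal usage sketch of these paired transforms (my own example; it reuses the Compose, ToTensor and RandomHorizontalFlip classes defined above, and the image and box values are made up):
from PIL import Image
import torch

image = Image.new('RGB', (100, 80))                           # placeholder 100x80 image
target = {"boxes": torch.tensor([[10.0, 20.0, 50.0, 60.0]]),  # one box: xmin, ymin, xmax, ymax
          "labels": torch.tensor([1])}

data_transforms = Compose([ToTensor(), RandomHorizontalFlip(prob=0.5)])
image, target = data_transforms(image, target)
print(image.shape, target["boxes"])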
For image segmentation we likewise need to define our own transforms when doing data augmentation, because the mask must be transformed together with the image.
import numpy as np
from PIL import Image
import random
import torch
from torchvision import transforms as T
from torchvision.transforms import functional as F
def pad_if_smaller(img, size, fill=0):
    # Pad the right/bottom of the image so both sides are at least `size`,
    # which guarantees that RandomCrop below can always cut a size x size patch.
    min_size = min(img.size)
    if min_size < size:
        ow, oh = img.size
        padh = size - oh if oh < size else 0
        padw = size - ow if ow < size else 0
        img = F.pad(img, (0, 0, padw, padh), fill=fill)
    return img
class Compose(object):
def __init__(self, transforms):
self.transforms = transforms
def __call__(self, image, target):
for t in self.transforms:
image, target = t(image, target)
return image, target
class RandomResize(object):
def __init__(self, min_size, max_size=None):
self.min_size = min_size
if max_size is None:
max_size = min_size
self.max_size = max_size
def __call__(self, image, target):
size = random.randint(self.min_size, self.max_size)
image = F.resize(image, size)
target = F.resize(target, size, interpolation=Image.NEAREST)
return image, target
class RandomHorizontalFlip(object):
def __init__(self, flip_prob):
self.flip_prob = flip_prob
def __call__(self, image, target):
if random.random() < self.flip_prob:
image = F.hflip(image)
target = F.hflip(target)
return image, target
class RandomCrop(object):
def __init__(self, size):
self.size = size
def __call__(self, image, target):
image = pad_if_smaller(image, self.size)
target = pad_if_smaller(target, self.size, fill=255)
crop_params = T.RandomCrop.get_params(image, (self.size, self.size))
image = F.crop(image, *crop_params)
target = F.crop(target, *crop_params)
return image, target
class CenterCrop(object):
def __init__(self, size):
self.size = size
def __call__(self, image, target):
image = F.center_crop(image, self.size)
target = F.center_crop(target, self.size)
return image, target
class ToTensor(object):
def __call__(self, image, target):
image = F.to_tensor(image)
target = torch.as_tensor(np.array(target), dtype=torch.int64)
return image, target
class Normalize(object):
def __init__(self, mean, std):
self.mean = mean
self.std = std
def __call__(self, image, target):
image = F.normalize(image, mean=self.mean, std=self.std)
return image, target
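A sketch of how these pieces might be assembled into a training pipeline for segmentation (my own example; the sizes and normalization statistics are assumptions, not values from the article):
train_transforms = Compose([
    RandomResize(min_size=256, max_size=512),
    RandomHorizontalFlip(flip_prob=0.5),
    RandomCrop(size=224),
    ToTensor(),
    Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
# image: a PIL RGB image; target: a PIL mask whose pixel values are class indices
# image, target = train_transforms(image, target)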