- pytorch使用torch.utils.data对常用的数据加载进行封装,可以实现多线程预读取和批量加载。
- 主要包括两个方面:1)把数据包装成
Dataset
类;2)用DataLoader
加载。 -
TensorDataset
可以直接接受Tensor
类型的输入,并用DataLoader
进行加载;省去自定义的过程。
官方数据集
-
torchvision
中实现了一些常用的数据集,可以通过torchvision.datasets
直接调用。如:MNIST,COCO,Captions,Detection,LSUN,ImageFolder,Imagenet-12,CIFAR,STL10,SVHN,PhotoTour。 -
torchvision.transforms
提供了许多图像操作,可以很方便的进行数据增强。
一个典型的CIFAR10数据加载过程如下:
import torchvision.transforms as transforms
transform = transforms.Compose(
[transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]) # 可以加入更多数据增强处理
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4, shuffle=True, num_workers=2)
testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=4, shuffle=False, num_workers=2)
自定义数据加载
如果不使用官方数据集,想要加载自己的数据集,需要自定义一个dataset类。它需要继承torch.utils.data.Dataset
类,并实现__getitem__()
和__len__()
两个成员方法。
下面是一个自定义的视频数据集的例子:
from torch.utils.data import Dataset
class FrameDataset(Dataset):
# 初始化时候要有一个数据加载,可以是数据路径的列表,或者直接把数据全加载进来。
# 后者实际上没有用到批量加载的功能,需要注意内存占用。
def __init__(self, data_dir, transform):
with open(data_dir, 'r') as fr:
reader = csv.reader(fr)
self.video_files = [video for video, label in reader]
self.transform = transform
print("dataset size: ", len(self.video_files))
def __getitem__(self, index):
video_file = self.video_files[index]
# 读取视频的帧并返回
return imgs, num_img
def __len__(self):
return len(self.video_files)
注意:
- 如果数据类型是图像、视频,我们可以把原始数据保存在一个文件夹中,再用一个列表保存图像或视频的路径。这样数据集初始化时候加载的其实只是路径列表,在训练和测试时才会分批把原始数据读入。
- 如果原始数据是直接以数字形式存储在一个文件中,无法通过索引单个读取,可以在初始化时候把整个矩阵读入,然后每次getitem时返回其中一行。
自定义的一大优点是处理更灵活,例如对于视频或文本数据,getitem函数中返回的帧序列或句子序列往往是长度不固定的,默认情况下DataLoader
在stack
时会出错,这时可以用collate_fn
指定batch数据的连接方式:
def collate_fn(batch):
imgs, num_img = zip(*batch)
return torch.cat(imgs), num_img
然后就可以正常加载数据了:
dataset = FrameDataset(csv_file, transform=tfms)
videoloader = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=False, collate_fn=collate_fn)
numpy类型数据加载
如果数据需要一次性全部读入,而且不需要额外的复杂处理的话可以不用自定义数据集Dataset类。
比如通常情况下,我们的输入可以很容易处理成一个numpy类型。这时可以不用定义Dataset类,直接使用TensorDataset
,只要把读入的数据转化成一个tensor传入即可。
random_split
是一个可以自动划分数据集的函数,实现随机不重复划分的功能。
from torch.utils.data import TensorDataset,DataLoader,random_split
dataset = TensorDataset(torch.from_numpy(data))
n_train = int(len(dataset) * 0.9)
n_test = len(dataset) - n_train
trainset, testset = random_split(dataset, [n_train, n_test])
train_loader = DataLoader(trainset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(testset, batch_size=batch_size, shuffle=False)