前言
之前已经简单讲述了PyTorch的Tensor、Autograd、torch.nn和torch.optim包,通过这些我们已经可以简单的搭建一个网络模型,但这是不够的,我们还需要大量的数据,众所周知,数据是深度学习的灵魂,深度学习的模型是由数据“喂”出来的,这篇我们来讲述一下数据的加载和预处理。
- 首先,我们要引入torch包
import torch
torch.__version__
一、数据的加载
PyTorch通过torch.utils.data对一般常用的数据加载进行了封装,可以很容易地实现多线程数据预读和批量加载。
1.1 Dataset
Dataset是一个抽象类,为了能够方便的读取,需要将要使用的数据包装为Dataset类。自定义的Dataset类需要继承它并且实现2个成员方法:
- 1.__getitem__():该方法定义用索引(0-len(self))获取一条数据或一个样本
- 2.__len__():该方法返回数据集的总长度
下面我们使用Kaggle上的一个竞赛bluebook for bulldozers自定义一个数据集,为了方便介绍,我们使用里面的数据字典来做说明
- 首先,我们需要引用相关的包
from torch.utils.data import Dataset
import pandas as pd
- 自定义一个数据集
#定义一个数据集
class BulldozerDataset(Dataset):
""" 数据集演示 """
def __init__(self, csv_file):
"""实现初始化方法,在初始化的时候将数据读载入"""
self.df=pd.read_csv(csv_file)
def __len__(self):
'''
返回df的长度
'''
return len(self.df)
def __getitem__(self, idx):
'''
根据 idx 返回一行数据
'''
return self.df.iloc[idx].SalePrice
- 至此,我们的数据集已经定义完成了,我们可以实例化一个对象来访问
ds_demo= BulldozerDataset('median_benchmark.csv')
- 我们可以直接使用如下命令查看数据集数据
# 前面我们已经实现了__len__方法,所以可以直接使用
len(ds_demo)
- 使用索引可以直接访问对应的数据
ds_demo[0]
自定义的数据集已经创建好了,下面我们使用官方提供的数据载入器,读取数据
1.2 DataLoader
DataLoader为我们提供了对Dataset的读取操作,常用参数有:batch_size(每个batch的大小)、shuffle(是否进行shuffle操作)、num_workers(加载数据时使用几个子进程)。下面做一个简单的演示:
dl = torch.utils.data.DataLoader(ds_demo,batch_size = 10,shuffle = True,num_workers = 0)
DataLoader返回的是一个可迭代对象,我们可以使用迭代器分次获取数据
idata=iter(dl)
print(next(idata))
常见的用法是使用for循环对其进行遍历
for i, data in enumerate(dl):
print(i,data)
# 为了节约空间,这里只循环一遍
break
至此,我们已经可以通过dataset定义数据集,并使用DataLorder载入和遍历数据集。
二、torchvision包
torchvision 是PyTorch中专门用来处理图像的库,PyTorch官网的安装教程中最后的pip install torchvision 就是安装这个包。
torchvision已经预先实现了常用图像数据集,包括前面使用过的CIFAR-10,ImageNet、COCO、MNIST、LSUN等数据集,可通过torchvision.datasets方便的调用。
- 这里总结一下torchvision已经预装的数据集:
数据集名称 |
MNIST |
COCO |
CIFAR-10 |
ImageNet |
Captions |
Detection |
LSUN |
ImageFolder |
Imagenet-12 |
STL10 |
SVHN |
PhotoTour |
PyTorch中自带的数据集由2个上层api提供,分别是torchvision和torchtext
- torchvision提供了对图像数据处理的相关数据和api
- 数据位置:torchvision.datasets;例如:torchvision.datasets.MNIST
- torchtext提供了对文本数据处理的相关数据和api
- 数据位置:torchtext.datasets;例如:torchtext.datasets.IMDB
下面我们做一个简单的演示
- 首先,我们要引入torchvision包
import torchvision.datasets as datasets
trainset = datasets.MNIST(root='./data', # 表示 MNIST 数据的加载的目录
train=True, # 表示是否加载数据库的训练集,false的时候加载测试集
download=True, # 表示是否自动下载 MNIST 数据集
transform=None) # 表示是否需要对数据进行预处理,none为不进行预处理
2.1 torchvision.models
torchvision不仅提供了常用的图像数据集,而且还提供了一些训练好的网络模型,可以加载之后直接使用,或者继续进行迁移学习。torchvision.models模块的子模块中包含以下模型:
网络模型 |
AlexNet |
VGG |
ResNet |
SqueezeNet |
DenseNet |
我们直接可以使用训练好的模型,当然这个与datasets相同,都是需要从服务器下载的。
- 首先,我们需要导入torchvision.models
import torchvision.models as models
- 直接使用
resnet18 = models.resnet18(pretrained=True)
2.2 torchvision.tranforms
transforms 模块提供了一般的图像转换操作类,用作数据处理和数据增强
- 首先,我们需要引入torchvision.tranforms,然后做一个简单的演示
from torchvision import transforms as transforms
transform = transforms.Compose([
transforms.RandomCrop(32, padding=4), #先四周填充0,在把图像随机裁剪成32*32
transforms.RandomHorizontalFlip(), #图像一半的概率翻转,一半的概率不翻转
transforms.RandomRotation((-45,45)), #随机旋转
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.229, 0.224, 0.225)), #R,G,B每层的归一化用到的均值和方差
])
肯定有人会问:(0.485, 0.456, 0.406), (0.2023, 0.1994, 0.2010) 这几个数字是什么意思?
官方的这个帖子有详细的说明: https://discuss.pytorch.org/t/normalization-in-the-mnist-example/457/21 这些都是根据ImageNet训练的归一化参数,可以直接使用,我们认为这个是固定值就可以。
到这里,我们已经完成了PyTorch的基本内容介绍。
参考文献
https:///zergtant/pytorch-handbook/blob/master/chapter2