打开PyTorch官网,官方文档中第一部分是PyTorch的核心模块,torchaudio是处理PyTorch语音的,torchtext是处理文本的,torchivision是处理图像的。
打开torchvision,tensorboard和transforms均来源于这里,torchvision分了好几个模块,包括Datasets即数据集的API文档,只要在写代码时指定相应数据集的参数,它就能去下载使用对应的数据集。
COCO数据集一般用于目标检查、语义分割;MINIST一般作为教科书中的入门数据集,是手写文字数据集;CIFAR一般用于物体识别。
torchvision.models中提供了最常用的一些神经网络模块,这些模块已经训练好了。其中有分类、语义分割、目标检测、视频分类等数据集。
torchvision中的tensorboard和transforms模块已经讲解过了。
CIFAR 10 Dataset
- root (string) – Root directory of dataset where directory
cifar-10-batches-py
exists or will be saved to if download is set to True. - train (bool, optional) – If True, creates dataset from training set, otherwise creates from test set.
- transform (callable_,_ optional) – A function/transform that takes in an PIL image and returns a transformed version. E.g,
transforms.RandomCrop
- target_transform (callable_,_ optional) – A function/transform that takes in the target and transforms it.
- download (bool, optional) – If true, downloads the dataset from the internet and puts it in root directory. If dataset is already downloaded, it is not downloaded again.
- root表示数据集存在什么样的位置。
- train是bool变量,为true表示创建的是训练集,为false表示是测试集。
- transform表示想对数据集进行什么样的变换。
- target_transform是对target进行transform。
- download为true表示从网上自动下载数据集,为false则不会下载。
首先导入torchvision
import torchvison
调用datasets工具包,对CIFAR10数据集进行下载
train_set = torchvision.datasets.CIFAR10(root="./dataset", train=True, download=True)
test_set = torchvision.datasets.CIFAR10(root="./dataset", train=False, download=True)
我们可以对下载得到的数据集进行分析
print(test_set[0])
输出:
(<PIL.Image.Image image mode=RGB size=32x32 at 0x1C58D023F40>, 3)
第一部分为PIL的图片数据,第二部分的3代表一个target(类别)。
print(test_set.classes)
输出:['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
说明3代表的是猫这个类别。
知道了这些后,我们就可以用两个变量来获取test_set的数据
img, target = test_set[0]
print(img)
print(target)
输出:
<PIL.Image.Image image mode=RGB size=32x32 at 0x1C583EE91C0>
3
CIFAR10包含了60000张32×32分为10个类别的彩色图片,其中50000张是训练图像,10000张是测试图像。
利用Transforms进行类型转换
先实例化transform对象
dataset_transform = torchvision.transform.Compose([torchvision.transforms.ToTensor()])
然后可以在数据集的参数中加入transform参数
train_set = torchvision.datasets.CIFAR10(root="./dataset", train=True, transform=dataset_transform, download=True)
test_set = torchvision.datasets.CIFAR10(root="./dataset", train=False, transform=dataset_transform, download=True)
输出图片
print(test_set[0])
输出:tensor数据类型
和Tensorboard结合
writer = SummaryWriter("p10")
for i in range(10):
img, target = test_set[i]
writer.add_image("test_set", img, i)
writer.close()
打开tensorboard,可以发现生成了step1-10的图片。
其他数据集的参数也类似,可以通过官方文档查看数据集参数,但有些数据集下载的很慢,我们可以先设置download为True,然后打开下载地址用迅雷(或其他)下载速度快的软件下载,下载完成后保存到对应文件夹中,这时再运行代码,它就会调用已经下载好的压缩包进行解压,减少下载所需时间。