首先通过pytorch的官方文档来确认torchvision中所支持的数据集和构造函数
​​​https://pytorch.org/vision/stable/datasets.html​

所支持的数据集如下图所示。

trochvision中数据集的使用_计算机视觉


在这里我们使用的是CIFAR-10的数据集作为测试样本。

trochvision中数据集的使用_深度学习_02


通过阅读官方文档,我们能够得到CIFAR10的几个参数。

第一个参数:root,类型str,这意味着我们数据集的保存路径,如果我们指定了download参数,那么数据集将会自动的下载到我们的指定的root目录下。

train参数如果为True,那么意味着我们的这个数据是一个训练数据集,否则就是一个测试数据集。

transforms参数这意味着我们要对这个图像进行的transfroms变换,这个变换包括了中心裁剪、格式转换等操作,这是一个回调函数参数,通过制定一个函数来执行我们的transforms的动作。

download参数如果为True,那么如果我们的root参数指定的目录下没有数据集文件,就会自动下载数据集文件。

trochvision中数据集的使用_数据集_03

1.导入数据包

#coding=utf-8
from cv2 import transform
import torch
import torchvision
from tensorboard import *
from torch.utils.tensorboard import SummaryWriter

2.下载数据集

train_set = torchvision.datasets.CIFAR10(root="./dataset",train=True,transform=None,download=True)
test_set = torchvision.datasets.CIFAR10(root="./dataset",train=False,transform=None,download=True)

运行后,脚本将会自动的下载我们的数据集到dataset目录下。

trochvision中数据集的使用_数据_04

3.确认返回数据

trochvision中数据集的使用_计算机视觉_05


根据官方文档描述,返回的信息继承了getitem方法,同时,返回的数据是一个tuple类型的数据,包括了image图像信息和target的数据信息。我们通过中括号的方式调用getitem魔术方法。

print(test_set[0])
img,target = test_set[0]
print(type(img))

trochvision中数据集的使用_深度学习_06


我们能够发现最终得到的数据是一个PIL.image格式的数据。

4.tensor转换

根据tensorboard的要求,我们要使用SummaryWriter类需要传入一个tensor类型的数据,因此,我们需要使用dataset.CIFAR10提供的transform回调函数来对我们的PIL.Images进行处理。
因此,我们将代码修改如下:

dataset_transform = torchvision.transforms.Compose([
torchvision.transforms.ToTensor(),
])
train_set = torchvision.datasets.CIFAR10(root="./dataset",train=True,transform=dataset_transform,download=True)
test_set = torchvision.datasets.CIFAR10(root="./dataset",train=False,transform=dataset_transform,download=True)
print(test_set[0])
img,target = test_set[0]
print(type(img))

trochvision中数据集的使用_计算机视觉_07


能够发现我们的图片数据已经被我们改为了tensor类型。

5.使用tensorboard进行可视化

补充如下代码:

writer = SummaryWriter("logs")
for i in range(len(test_set)):
img,target = test_set[i]
writer.add_image("test_set",img,i)
writer.close()

6.得到如下结果

trochvision中数据集的使用_数据集_08