PyTorch MNIST 获取图片实现流程
本文将教会你如何使用PyTorch获取MNIST图片数据集。在开始之前,确保你已经安装了PyTorch和相关的依赖项。
整体流程
下面是整个任务的流程概述。我们将按照以下步骤进行实现:
erDiagram
方框1 --|> 方框2 : 步骤1
方框2 --|> 方框3 : 步骤2
方框3 --|> 方框4 : 步骤3
方框4 --|> 方框5 : 步骤4
方框5 --|> 方框6 : 步骤5
方框6 --|> 方框7 : 步骤6
方框7 --|> 方框8 : 步骤7
方框8 --|> 方框9 : 步骤8
方框9 --|> 方框10 : 步骤9
步骤1:导入必要的库
首先,我们需要导入所需的库,包括PyTorch和相关的模块。代码如下:
import torch
from torchvision import datasets, transforms
在这里,我们导入了torch和torchvision库。torchvision库提供了一些常用的计算机视觉数据集和转换操作。
步骤2:定义数据转换
在这一步中,我们将定义对MNIST数据集进行的数据转换操作。代码如下:
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.5,), (0.5,))
])
这里,我们使用transforms.Compose
函数将多个转换操作组合到一起。我们使用了transforms.ToTensor()
将图像转换为张量,并使用transforms.Normalize()
对图像进行标准化。
步骤3:加载数据集
接下来,我们将加载MNIST数据集。代码如下:
train_dataset = datasets.MNIST(root='./data', train=True, transform=transform, download=True)
test_dataset = datasets.MNIST(root='./data', train=False, transform=transform, download=True)
这里,我们使用datasets.MNIST
类加载MNIST数据集。我们通过设置train=True
和train=False
来指定训练集和测试集。root
参数指定数据集的路径,transform
参数指定数据集的转换操作,download=True
表示如果数据集不存在,则自动下载。
步骤4:创建数据加载器
数据加载器是PyTorch中用于加载数据的对象。我们将创建训练集和测试集的数据加载器。代码如下:
batch_size = 64
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
这里,我们使用torch.utils.data.DataLoader
类创建数据加载器。我们将训练集和测试集作为参数传递给数据加载器,并指定批次大小和是否打乱数据。
步骤5:可视化数据
现在,我们将通过可视化数据来检查数据是否加载正确。代码如下:
import matplotlib.pyplot as plt
# 获取一个批次的图像和标签
images, labels = next(iter(train_loader))
# 可视化图像
fig = plt.figure(figsize=(10, 10))
for i in range(64):
ax = fig.add_subplot(8, 8, i+1)
ax.imshow(images[i].squeeze(), cmap='gray')
ax.set_title(labels[i].item())
ax.axis('off')
plt.show()
这里,我们使用next(iter(train_loader))
获取一个批次的图像和标签。然后,我们使用matplotlib.pyplot
库将图像可视化。
步骤6:建立模型
在这一步中,我们将建立一个简单的卷积神经网络模型。代码如下:
import torch.nn as nn
import torch.nn.functional as F
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv1 = nn.Conv2d(1, 32, 3, 1)
self.conv2 = nn.Conv2d(32, 64, 3