pytorch Fashion-MNIST分类实现指南

1. 简介

在本指南中,我将教会你如何使用PyTorch对Fashion-MNIST数据集进行分类。Fashion-MNIST是一个包含10个类别的衣物图像数据集,每个类别有6000张大小为28x28的灰度图像。我们的目标是建立一个模型,能够对这些图像进行准确的分类。

2. 整体流程

下面是实现这个任务的整体流程:

erDiagram
    Customer --> Extract_Data : 1. 提取数据
    Extract_Data --> Preprocess_Data : 2. 预处理数据
    Preprocess_Data --> Split_Data : 3. 划分数据集
    Split_Data --> Define_Model : 4. 定义模型
    Define_Model --> Train_Model : 5. 训练模型
    Train_Model --> Evaluate_Model : 6. 评估模型
    Evaluate_Model --> Use_Model : 7. 使用模型

3. 步骤详解

3.1 提取数据

首先,我们需要从Fashion-MNIST数据集中提取数据。PyTorch提供了一个内置的torchvision.datasets模块,其中包含了一些常用的数据集,包括Fashion-MNIST。

import torchvision.datasets as datasets

# 提取Fashion-MNIST数据集
train_data = datasets.FashionMNIST(root='data', train=True, download=True)
test_data = datasets.FashionMNIST(root='data', train=False, download=True)

在这段代码中,train_datatest_data分别是训练集和测试集的数据对象。

3.2 预处理数据

接下来,我们需要对数据进行预处理。首先,我们将图像数据转换为张量,并将像素值归一化到[0, 1]范围内。

import torchvision.transforms as transforms

# 定义数据转换
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

# 对数据进行转换
train_data.transform = transform
test_data.transform = transform

3.3 划分数据集

在训练模型之前,我们需要将数据集划分为训练集和验证集。我们可以使用torch.utils.data.random_split函数来实现这一步骤。

from torch.utils.data import random_split

# 定义训练集和验证集的比例
train_ratio = 0.8

# 计算训练集和验证集的数据量
train_size = int(train_ratio * len(train_data))
val_size = len(train_data) - train_size

# 划分数据集
train_dataset, val_dataset = random_split(train_data, [train_size, val_size])

3.4 定义模型

现在,我们需要定义一个模型来进行分类。在这个示例中,我们将使用一个简单的卷积神经网络作为模型。

import torch.nn as nn

# 定义卷积神经网络模型
class CNNModel(nn.Module):
    def __init__(self):
        super(CNNModel, self).__init__()
        self.conv1 = nn.Conv2d(1, 16, kernel_size=3, stride=1, padding=1)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1)
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(32 * 7 * 7, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        x = self.conv2(x)
        x = self.relu(x)
        x = self.maxpool(x)
        x = self.flatten(x)
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        return x

# 创建模型实例
model = CNNModel()

3.5 训练模型

现在,我们可以开始训练模型了。我们需要定义损失函数和优化器,并对模型进行多次迭代训练。

import torch.optim as optim

# 定义损失函数和优化