项目结构总结

一般项目都包含以下几个部分:

模型定义
数据处理和加载
训练模型(Train&Validate)
训练过程的可视化
测试(Test/Inference)

主要目录结构:



- checkpoints/: 用于保存训练好的模型,可使程序在异常退出后仍能重新载入模型,恢复训练
- data/:数据相关操作,包括数据预处理、dataset实现等
- models/:模型定义,可以有多个模型,例如上面的AlexNet和ResNet34,一个模型对应一个文件
- utils/:可能用到的工具函数,在本次实验中主要是封装了可视化工具
- config.py:配置文件,所有可配置的变量都集中在此,并提供默认值
- main.py:主文件,训练和测试程序的入口,可通过不同的命令来指定不同的操作和参数
- requirements.txt:程序依赖的第三方库
- README.md:提供程序的必要说明


PyTorch | 项目结构解析www.cnblogs.com

###模型定义

1.必须继承nn.Module这个类,要让PyTorch知道这个类是一个Module
2.在init(self)中设置好需要的"组件"(如conv,pooling,Linear,BatchNorm等),第一行用super(Net,self).__init__()实现父类的初始化
3.最后,在forward(self,x)中定义好的“组件”进行组装,就像搭积木,把网络结构搭建出来,这样一个模型就定义好了。



#一个简单的模型
import torch
import torch.nn as nn
import torch.functional as F
class Net(nn.Module):
    def __init__(self):
        super(Net,self).__init__()#实现父类的初始化
        self.conv1=nn.Conv2d(3,6,5)#定义卷积层组件
        self.pool1=nn.MaxPool2d(2,2)#定义池化层组件
        self.conv2=nn.Conv2dn(6,16,5)
        self.pool2=nn.MaxPool2d(2,2)
        self.fc1=nn.Linear(16*5*5,120)#定义线性连接
        self.fc2=nn.Linear(120,84)
        self.fc3=nn.Linear(84,10)
    def forward(self,x):#x模型的输入
        x=self.pool1(F.relu(self.conv1(x)))
        x=self.pool2(F.relu(self.conv2(x)))
        x=x.view(-1,16*5*5)#表示将x进行reshape,为后面做为全连接层的输入
        x=F.relu(self.fc1(x))
        x=F.relu(self.fc2(x))
        x=self.fc3(x)
    return x



其它常用方法:

  • nn.Sequetial

torch.nn.Sequential其实就是Sequential容器,该容器将一系列操作按先后顺序给包起来,方便重复使用。


如何定义PyTorch模型www.jianshu.com


android端 pytorch 训练 pytorch训练_android端 pytorch 训练


  • torch.functional

需要自行定义参数


【PyTorch学习笔记】17:2D卷积,nn.Conv2d和F.conv2d_LauZyHou的笔记-CSDN博客_f.conv2dblog.csdn.net

android端 pytorch 训练 pytorch训练_数据_02


  • torch.nn

通过functional的函数实现的类,实现是定义了variable等参数

torch.nn 和 torch.functional 的区别blog.csdn.net


android端 pytorch 训练 pytorch训练_2d_03


###数据处理和加载

  • Dataset类

torch.utils.data.Dataset

Dataset类是Pytorch中图像数据集中最为重要的一个类,也是Pytorch中所有数据集加载类中应该继承的父类。其中父类中的两个私有成员函数必须被重载,否则将会触发错误提示:

  • def getitem(self, index):
  • def len(self):

其中__len__应该返回数据集的大小,而__getitem__应该编写支持数据集索引的函数,例如通过dataset[i]可以得到数据集中的第i+1个数据。

  • DataLoader类

torch.utils.data.DataLoader()

之前所说的Dataset类是读入数据集数据并且对读入的数据进行了索引。但是光有这个功能是不够用的,在实际的加载数据集的过程中,我们的数据量往往都很大,对此我们还需要一下几个功能:

  • 可以分批次读取:batch-size
  • 可以对数据进行随机读取,可以对数据进行洗牌操作(shuffling),打乱数据集内数据分布的顺序
  • 可以并行加载数据(利用多核处理器加快载入数据的效率)

Pytorch中正确设计并加载数据集方法 - pytorch中文网ptorch.com

android端 pytorch 训练 pytorch训练_2d_04


  • imageFolder类
ImageFolder(root, transform=None, target_transform=None, loader=default_loader)
- root:在root指定的路径下寻找图片
- transform:对PIL Image进行的转换操作,transform的输入是使用loader读取图片的返回对象
- target_transform:对label的转换
- loader:给定路径后如何读取图片,默认读取为RGB格式的PIL Image对象

pytorch读取数据集_qq_36852276的博客-CSDN博客_pytorch读取数据集blog.csdn.net

android端 pytorch 训练 pytorch训练_数据_05

PyTorch-ImageFolder/自定义类 读取图片数据blog.csdn.net

android端 pytorch 训练 pytorch训练_pytorch 训练_06


###训练模型

定义损失函数和优化器


criterion = torch.nn.MSELoss(reduction='sum') 
optimizer = torch.optim.SGD(model.parameters(), lr=1e-4)


前向传播,后向传播


for t in range(epoch):
    for step, (x, y) in enumerate(train_loader):
        # Forward pass: Compute predicted y by passing x to the model
        y_pred = model(x)
 
        # Compute and print loss
        loss = criterion(y_pred, y) # 计算损失函数

        # Zero gradients, perform a backward pass, and update the weights.
        optimizer.zero_grad() # 梯度置零,因为反向传播过程中梯度会累加上一次循环的梯度
        loss.backward() # loss反向传播
        optimizer.step() # 反向传播后参数更新

PyTorch训练模型小结_KAMITA的博客-CSDN博客_pytorch 训练模型blog.csdn.net

android端 pytorch 训练 pytorch训练_数据_07


###可视化

可视化工具tensorboardX

  1. 创建一个 SummaryWriter 的示例
from tensorboardX import SummaryWriter

# Creates writer1 object.
# The log will be saved in 'runs/exp'
writer1 = SummaryWriter('runs/exp')

# Creates writer2 object with auto generated file name
# The log directory will be something like 'runs/Aug20-17-20-33'
writer2 = SummaryWriter()

# Creates writer3 object with auto generated file name, the comment will be appended to the filename.
# The log directory will be something like 'runs/Aug20-17-20-33-resnet'
writer3 = SummaryWriter(comment='resnet')


2. 接下来,我们就可以调用 SummaryWriter 实例的各种add_something方法向日志中写入不同类型的数据了。想要在浏览器中查看可视化这些数据,只要在命令行中开启 tensorboard 即可

tensorboard --logdir=<your_log_dir>

  • add_scalar(tag, scalar_value, global_step=None, walltime=None)
tag (string): 数据名称,不同名称的数据使用不同曲线展示
scalar_value (float): 数字常量值
global_step (int, optional): 训练的 step
walltime (float, optional): 记录发生的时间,默认为 time.time()


需要注意,这里的scalar_value一定是 float 类型,如果是 PyTorch scalar tensor,则需要调用.item()方法获取其数值。

  • add_histogram(tag, values, global_step=None, bins='tensorflow', walltime=None, max_bins=None)
tag (string): 数据名称
values (torch.Tensor, numpy.array, or string/blobname): 用来构建直方图的数据
global_step (int, optional): 训练的 step
bins (string, optional): 取值有 ‘tensorflow’、‘auto’、‘fd’ 等, 该参数决定了分桶的方式,详见这里。
walltime (float, optional): 记录发生的时间,默认为 time.time()
max_bins (int, optional): 最大分桶数


  • add_graph(model, input_to_model=None, verbose=False, **kwargs)
model (torch.nn.Module): 待可视化的网络模型
input_to_model (torch.Tensor or list of torch.Tensor, optional): 待输入神经网络的变量或一组变量

详解PyTorch项目使用TensorboardX进行训练可视化_浅度寺-CSDN博客_tensorboardxblog.csdn.net

android端 pytorch 训练 pytorch训练_数据_08


visdom

Visdom 介绍|上baijiahao.baidu.com

android端 pytorch 训练 pytorch训练_数据集_09


###测试

sklearn