项目结构总结
一般项目都包含以下几个部分:
模型定义
数据处理和加载
训练模型(Train&Validate)
训练过程的可视化
测试(Test/Inference)
主要目录结构:
- checkpoints/: 用于保存训练好的模型,可使程序在异常退出后仍能重新载入模型,恢复训练
- data/:数据相关操作,包括数据预处理、dataset实现等
- models/:模型定义,可以有多个模型,例如上面的AlexNet和ResNet34,一个模型对应一个文件
- utils/:可能用到的工具函数,在本次实验中主要是封装了可视化工具
- config.py:配置文件,所有可配置的变量都集中在此,并提供默认值
- main.py:主文件,训练和测试程序的入口,可通过不同的命令来指定不同的操作和参数
- requirements.txt:程序依赖的第三方库
- README.md:提供程序的必要说明
PyTorch | 项目结构解析www.cnblogs.com
###模型定义
1.必须继承nn.Module这个类,要让PyTorch知道这个类是一个Module
2.在init(self)中设置好需要的"组件"(如conv,pooling,Linear,BatchNorm等),第一行用super(Net,self).__init__()实现父类的初始化
3.最后,在forward(self,x)中定义好的“组件”进行组装,就像搭积木,把网络结构搭建出来,这样一个模型就定义好了。
#一个简单的模型
import torch
import torch.nn as nn
import torch.functional as F
class Net(nn.Module):
def __init__(self):
super(Net,self).__init__()#实现父类的初始化
self.conv1=nn.Conv2d(3,6,5)#定义卷积层组件
self.pool1=nn.MaxPool2d(2,2)#定义池化层组件
self.conv2=nn.Conv2dn(6,16,5)
self.pool2=nn.MaxPool2d(2,2)
self.fc1=nn.Linear(16*5*5,120)#定义线性连接
self.fc2=nn.Linear(120,84)
self.fc3=nn.Linear(84,10)
def forward(self,x):#x模型的输入
x=self.pool1(F.relu(self.conv1(x)))
x=self.pool2(F.relu(self.conv2(x)))
x=x.view(-1,16*5*5)#表示将x进行reshape,为后面做为全连接层的输入
x=F.relu(self.fc1(x))
x=F.relu(self.fc2(x))
x=self.fc3(x)
return x
其它常用方法:
- nn.Sequetial
torch.nn.Sequential其实就是Sequential容器,该容器将一系列操作按先后顺序给包起来,方便重复使用。
如何定义PyTorch模型www.jianshu.com
- torch.functional
需要自行定义参数
【PyTorch学习笔记】17:2D卷积,nn.Conv2d和F.conv2d_LauZyHou的笔记-CSDN博客_f.conv2dblog.csdn.net
- torch.nn
通过functional的函数实现的类,实现是定义了variable等参数
torch.nn 和 torch.functional 的区别blog.csdn.net
###数据处理和加载
- Dataset类
torch.utils.data.Dataset
Dataset
类是Pytorch
中图像数据集中最为重要的一个类,也是Pytorch
中所有数据集加载类中应该继承的父类。其中父类中的两个私有成员函数必须被重载,否则将会触发错误提示:
- def getitem(self, index):
- def len(self):
其中__len__
应该返回数据集的大小,而__getitem__
应该编写支持数据集索引的函数,例如通过dataset[i]
可以得到数据集中的第i+1
个数据。
- DataLoader类
torch.utils.data.DataLoader()
之前所说的Dataset
类是读入数据集数据并且对读入的数据进行了索引。但是光有这个功能是不够用的,在实际的加载数据集的过程中,我们的数据量往往都很大,对此我们还需要一下几个功能:
- 可以分批次读取:batch-size
- 可以对数据进行随机读取,可以对数据进行洗牌操作(shuffling),打乱数据集内数据分布的顺序
- 可以并行加载数据(利用多核处理器加快载入数据的效率)
Pytorch中正确设计并加载数据集方法 - pytorch中文网ptorch.com
- imageFolder类
ImageFolder(root, transform=None, target_transform=None, loader=default_loader)
- root:在root指定的路径下寻找图片
- transform:对PIL Image进行的转换操作,transform的输入是使用loader读取图片的返回对象
- target_transform:对label的转换
- loader:给定路径后如何读取图片,默认读取为RGB格式的PIL Image对象
pytorch读取数据集_qq_36852276的博客-CSDN博客_pytorch读取数据集blog.csdn.net
PyTorch-ImageFolder/自定义类 读取图片数据blog.csdn.net
###训练模型
定义损失函数和优化器
criterion = torch.nn.MSELoss(reduction='sum')
optimizer = torch.optim.SGD(model.parameters(), lr=1e-4)
前向传播,后向传播
for t in range(epoch):
for step, (x, y) in enumerate(train_loader):
# Forward pass: Compute predicted y by passing x to the model
y_pred = model(x)
# Compute and print loss
loss = criterion(y_pred, y) # 计算损失函数
# Zero gradients, perform a backward pass, and update the weights.
optimizer.zero_grad() # 梯度置零,因为反向传播过程中梯度会累加上一次循环的梯度
loss.backward() # loss反向传播
optimizer.step() # 反向传播后参数更新
PyTorch训练模型小结_KAMITA的博客-CSDN博客_pytorch 训练模型blog.csdn.net
###可视化
可视化工具tensorboardX
- 创建一个 SummaryWriter 的示例
from tensorboardX import SummaryWriter
# Creates writer1 object.
# The log will be saved in 'runs/exp'
writer1 = SummaryWriter('runs/exp')
# Creates writer2 object with auto generated file name
# The log directory will be something like 'runs/Aug20-17-20-33'
writer2 = SummaryWriter()
# Creates writer3 object with auto generated file name, the comment will be appended to the filename.
# The log directory will be something like 'runs/Aug20-17-20-33-resnet'
writer3 = SummaryWriter(comment='resnet')
2. 接下来,我们就可以调用 SummaryWriter 实例的各种add_something
方法向日志中写入不同类型的数据了。想要在浏览器中查看可视化这些数据,只要在命令行中开启 tensorboard 即可
tensorboard --logdir=<your_log_dir>
- add_scalar(tag, scalar_value, global_step=None, walltime=None)
tag (string): 数据名称,不同名称的数据使用不同曲线展示
scalar_value (float): 数字常量值
global_step (int, optional): 训练的 step
walltime (float, optional): 记录发生的时间,默认为 time.time()
需要注意,这里的scalar_value一定是 float 类型,如果是 PyTorch scalar tensor,则需要调用.item()方法获取其数值。
- add_histogram(tag, values, global_step=None, bins='tensorflow', walltime=None, max_bins=None)
tag (string): 数据名称
values (torch.Tensor, numpy.array, or string/blobname): 用来构建直方图的数据
global_step (int, optional): 训练的 step
bins (string, optional): 取值有 ‘tensorflow’、‘auto’、‘fd’ 等, 该参数决定了分桶的方式,详见这里。
walltime (float, optional): 记录发生的时间,默认为 time.time()
max_bins (int, optional): 最大分桶数
- add_graph(model, input_to_model=None, verbose=False, **kwargs)
model (torch.nn.Module): 待可视化的网络模型
input_to_model (torch.Tensor or list of torch.Tensor, optional): 待输入神经网络的变量或一组变量
详解PyTorch项目使用TensorboardX进行训练可视化_浅度寺-CSDN博客_tensorboardxblog.csdn.net
visdom
Visdom 介绍|上baijiahao.baidu.com
###测试
sklearn