- 1 什么是模型
- 2 模型处理包括哪些操作
- 2.1 网络模型库torchvision.models
- 2.2 自定义模型
- 2.3 加载预训练模型
- 2.4 模型保存与加载
- 2.5 多GPU训练网络保存与加载
- 2.6 模型训练和测试的两种模式
1 什么是模型
模型是神经网络训练优化后得到的结果,包含了神经网络骨架及学习得到的参数。
2 模型处理包括哪些操作
网络模型库、自定义模型、预训练模型的加载和模型保存、多GPU训练网络保存与加载、模型训练和测试的两种模式
2.1 网络模型库torchvision.models
torchvision.models库提供了众多经典的网络结构与预训练模型,例如VGG、ResNet和Inception等,利用这些模型可以快速搭建物体检测网络,不需要逐层手动实现。torchvision包与PyTorch相独立,需要通过pip指令进行安装,如下:
pip install torchvision
以VGG模型为例,在torchvision.models中,VGG模型的特征层与分类层分别用vgg.features与vgg.classifier来表示,每个部分是一个nn.Sequential结构,可以方便地使用与修改。
VGG16的特征层包括13个卷积、13个激活函数ReLU、5个池化,一共31层
VGG16的分类层包括3个全连接、2个ReLU、2个Dropout,一共7层
from torchvision import models
vgg = models.vgg16()
print(vgg.features)
print(vgg.classifier)
2.2 自定义模型
参考 神经网络工具箱torch.nn
2.3 加载预训练模型
为什么要进行预训练模型加载
对于计算机视觉的任务,包括物体检测,我们通常很难拿到很大的数据集,在这种情况下重新训练一个新的模型是比较复杂的,并且不容易调整,因此,Fine-tune(微调)是一个常用的选择。
什么是Fine-tune
所谓Fine-tune是指利用别人在一些数据集上训练好的预训练模型,在自己的数据集上训练自己的模型。
加载预训练模型的两种方法:
第一种使用torchvision.models中自带的预训练模型
from torchvision import models
vgg = models.vgg16(pretrained=True)
第二种使用本地预训练模型(或训练过的模型)
利用load_state_dict,遍历预训练模型的关键字,如果出现在了VGG中,则加载预训练参数
import torch
from torchvision import models
vgg = models.vgg16()
state_dict = torch.load("your model path")
# 利用load_state_dict,遍历预训练模型的关键字,如果出现在了VGG中,则加载预训练参数
vgg.load_state_dict({k:v for k,v in state_dict.items() if k in vgg.state_dict()})
或者
vgg = models.vgg16()
vgg_dict = vgg.state_dict()
pretrained_dict = torch.load("your model path")
pretrained_dict = {k: v for k, v in pretrained_dict.items() if np.shape(vgg_dict[k]) == np.shape(v)}
vgg_dict.update(pretrained_dict)
vgg.load_state_dict(vgg_dict)
2.4 模型保存与加载
PyTorch 中保存模型的方式有许多种:
# 保存整个网络
torch.save(model, PATH)
# 保存网络中的参数, 速度快,占空间少
torch.save(model.state_dict(),PATH)
# 选择保存网络中的一部分参数或者额外保存其余的参数
torch.save({'state_dict': model.state_dict(), 'fc_dict':model.fc.state_dict(),
'optimizer': optimizer.state_dict(),'alpha': loss.alpha, 'gamma': loss.gamma},
PATH)
同样的,PyTorch 中读取模型参数的方式也有许多种:
# 读取整个网络
model = torch.load(PATH)
# 读取 Checkpoint 中的网络参数
model.load_state_dict(torch.load(PATH))
# 若 Checkpoint 中的网络参数与当前网络参数有部分不同,有以下两种方式进行加载:
# 1. 利用字典的 update 方法进行加载
Checkpoint = torch.load(Path)
model_dict = model.state_dict()
model_dict.update(Checkpoint)
model.load_state_dict(model_dict)
# 2. 利用 load_state_dict() 的 strict 参数进行部分加载
model.load_state_dict(torch.load(PATH), strict=False)
2.5 多GPU训练网络保存与加载
指定多卡训练的模型就不是原来模型的类型了,而是并行化后的模型:
由于多GPU训练使用了 nn.DataParallel(net, device_ids=gpu_ids) 对网络进行封装,因此在原始网络结构中添加了一层module
可以打印出来看一下
并行化后的模型参数必须加载到并行化的模型中,没并行化的参数要加载到没并行化的模型中,不然会出bug。
模型并行化后,保存没有并行化的模型并加载
model = DefinedNetwork()
torch.save(model.module.state_dict(), 'model_name.pth') # 保存
model.load_state_dict(torch.load(PATH)) # 加载
2.6 模型训练和测试的两种模式
model.train()和model.eval()分别在训练和测试中都要写,它们的作用如下:
(1) model.train()
启用BatchNormalization和 Dropout,将BatchNormalization和Dropout置为True
(2) model.eval()
不启用 BatchNormalization 和 Dropout,将BatchNormalization和Dropout置为False
注意:在训练模块中千万不要忘了写model.train();在评估(或测试)模块千万不要忘了写model.eval()