#一、Pytorch模型定义方式

1.1模型定义 三种方式

Module类是torch.nn模块里提供的一个模型构造类(nn.module),是所有神经网络模块的基类,可以继承它来定义;模型定义主要包括两个主要部分:各部分的初始化(_init_);数据流向定义(forward);

基于nn.module,我们可以通过Sequencetial,modulelist和ModuleDict 三种方式定义;

1.1.1Sequential

对应模块为nn.Sequential()

当模型的前向计算为简单串联各个层的计算时,Sequential 类可以通过更简单的方式定义模型。它能接收一个子模块的有序字典(OrderedDict)或者一系列子模块作为参数来逐一添加Module的实例,而模型的前向计算就是将这些实例按添加的顺序逐一计算;

Sequential定义模型时只需将模型的层按顺序排列起来即可,根据层名不同,排列方式有两种:

直接排列

import torch.nn as nn
net = nn.Sequential (
         nn.Linear(784,256),
         nn.ReLu(),
         nn.Linear(256,10)
                    )
print(net)
Sequential (
   (0):Linear(in_features=784,out_features=256,bias=True)
   (1):ReLu()
   (2):Linear(in_features=256,out_features=10,bias=True)
   )

使用OrderedDict:

import collections
import torch.nn as nn
net2 = nn.Sequential ( collections.OrderedDict([
        ('fc1',nn.Linear(784,256),
        ('relu1',nn.ReLu()),
        ( 'fc2', nn.Linear(256,10))
       ]))
print(net2)
Sequential (
   (fc1):Linear(in_features=784,out_features=256,bias=True)
   (relu1):ReLu()
   (fc2):Linear(in_features=256,out_features=10,bias=True)
   )

此定义方式简单易读;但也会使模型定义丧失灵活性;

1.1.2ModuleList(模块为nn.ModuleList)

ModuleList接受一个子模块(或层)的列表作为输入,然后也能类似List那样进行append和extend操作。子模块的权重也会自动添加到网络;

net = nn.ModuleList ([ nn.Linear(784,256) ,  nn.ReLu()])
net.append(  nn.Linear(256,10))
print(net[-1])
print(net)
Linear (in_features=256,out_features=10,bias=True)
ModuleList(
   (0):Linear(in_features=784,out_features=256,bias=True)
   (1):ReLu()
   (2):Linear(in_features=256,out_features=10,bias=True)
   )

nn.ModuleList没有定义一个网络,只是将不同模块储存在一起;还需要经过forward函数指定各个层的先后顺序才算完成模型定义;用for循环即可实现。

1.1.3ModuleDict

其与ModuleList作用类似,只是ModuleList能更方便的为神经网络的层添加名称;

#二、利用模块块快速搭建复杂网络

2.1U-net介绍

U-net是分割模型的杰作,在以医学影像为代表的诸多领域有着广泛应用;组成该模块主要有以下几个部分:

1)每个子块内部的两次卷积;(Double Convolution)

2)左侧模型块之间的下采样连接,即最大池化;(Max pooling);

3)Up sampling;

4)输出层处理;(out convolution)

还包括模块之间的横向连接,输入和U-et底部的连接扥计算;

在实现该模块时 不必把每一层按序排列显式写出,应先定义模型块,再定义模型块直接的连接顺序和计算方法。基础部件对应上述四个模型块;(具体实现代码略)

使用写好的模型块,可以方便地组装U-net模型,通过模块化的方式实现代码复用;

#三、PyTorch修改模型(代码太多略敲)

3.1修改模型层

3.2添加额外输入

3.3添加额外输出

#四、模型的保存与读取

4.1模型存储格式

存储模型主要采用pkl,pt,pth三种格式;

4.2模型存储内容

一个PyTorch模型主要包含两个部分:模型结构和权重。其中模型是继承nn.Module的类,权重的数据结构是一个字典(key 是层名,value是权重向量),存储也分为两种形式:存储整个模型和只存储权重;