pytorch包含多种优化算法用于网络参数的更新,比如常用的SGD、Adam、LBFGS以及RMSProp等。使用中可以发现各种优化算法的使用方式几乎相同,是因为父类optimizer【1】定义了各个子类(即SGD等)的核心行为,下面是optimizer类注释:


class


其中首句“所有优化器的基类” 表明所有的优化器都必须继承optimizer类,下面来分析optimizer类的的各个实例函数。

1、初始化__init__()


def


优化器需要保存学习率等参数的值,所以optimizer类需要用实例属性来存储这些参数,也就是__init__()中的self.param_groups,下面的代码通过一个全连接网络来测试优化器的param_groups包含哪些参数:


net


得到:


[{


2x2的矩阵是net的权重矩阵,1x2为偏置矩阵,其余为优化器的其它参数,所以说param_groups保存了优化器的全部数据,这个下面的state_dict()不同。

2、优化器状态state_dict()


def


查看上一节定义的optimizer的state_dict():


print


可以到优化器的完整参数如下:


[{


3、优化器参数加载load_state_dict()

上一节中的state_dict()负责提取优化器的参数,可以保存到本地用于下次训练恢复使用,对应的必然有load_state_dict()用于优化器参数的加载,其源码如下:


def


为了测试state_dict()和load_state_dict(),可以首先存储一个学习率为100的优化器的参数到本地:


optimizer_old


现在这个优化器的参数已经存储到本地,然后将这个优化器参数重新加载给一个新的学习率为0.01优化器:


optimizer_new


得到new优化器的学习率不是0.01,而是old优化器的学习率100:


[{


4、梯度清空zero_grad()

在网络优化过程中optimizer.zero_grad()函数需要被显式调用,负责清空其关联网络的参数梯度值,其源码如下:


def


这个遍历过程就是获取optimizer的param_groups属性的字典,之中的["params"],之中的所有参数,通过遍历设定每个参数的梯度值为0。

5、单步更新step()


def


优化器的step()函数负责更新参数值,但是其具体实现对于不同的优化算法是不同的,所以optimizer类只是定义了这种行为,但是并没有给出具体实现。

6、总结

优化算法部分的代码并不多,但是不同的优化算法涉及的概念较多,看懂各种算法的实现需要很强的数学功底。optimizer类定义了各种优化算法的公共行为与抽象方法,是典型的面向对象的继承思想。

参考:

【1】https:///pytorch/pytorch/blob/master/torch/optim/optimizer.py