pytorch包含多种优化算法用于网络参数的更新,比如常用的SGD、Adam、LBFGS以及RMSProp等。使用中可以发现各种优化算法的使用方式几乎相同,是因为父类optimizer【1】定义了各个子类(即SGD等)的核心行为,下面是optimizer类注释:
class
其中首句“所有优化器的基类” 表明所有的优化器都必须继承optimizer类,下面来分析optimizer类的的各个实例函数。
1、初始化__init__()
def
优化器需要保存学习率等参数的值,所以optimizer类需要用实例属性来存储这些参数,也就是__init__()中的self.param_groups,下面的代码通过一个全连接网络来测试优化器的param_groups包含哪些参数:
net
得到:
[{
2x2的矩阵是net的权重矩阵,1x2为偏置矩阵,其余为优化器的其它参数,所以说param_groups保存了优化器的全部数据,这个下面的state_dict()不同。
2、优化器状态state_dict()
def
查看上一节定义的optimizer的state_dict():
print
可以到优化器的完整参数如下:
[{
3、优化器参数加载load_state_dict()
上一节中的state_dict()负责提取优化器的参数,可以保存到本地用于下次训练恢复使用,对应的必然有load_state_dict()用于优化器参数的加载,其源码如下:
def
为了测试state_dict()和load_state_dict(),可以首先存储一个学习率为100的优化器的参数到本地:
optimizer_old
现在这个优化器的参数已经存储到本地,然后将这个优化器参数重新加载给一个新的学习率为0.01优化器:
optimizer_new
得到new优化器的学习率不是0.01,而是old优化器的学习率100:
[{
4、梯度清空zero_grad()
在网络优化过程中optimizer.zero_grad()函数需要被显式调用,负责清空其关联网络的参数梯度值,其源码如下:
def
这个遍历过程就是获取optimizer的param_groups属性的字典,之中的["params"],之中的所有参数,通过遍历设定每个参数的梯度值为0。
5、单步更新step()
def
优化器的step()函数负责更新参数值,但是其具体实现对于不同的优化算法是不同的,所以optimizer类只是定义了这种行为,但是并没有给出具体实现。
6、总结
优化算法部分的代码并不多,但是不同的优化算法涉及的概念较多,看懂各种算法的实现需要很强的数学功底。optimizer类定义了各种优化算法的公共行为与抽象方法,是典型的面向对象的继承思想。
参考:
【1】https:///pytorch/pytorch/blob/master/torch/optim/optimizer.py