模型在保存时侯以键对值保存,同时在加载时根据现在网络的键值查找模型对应的键值,然后加载。一般报错是因为模型和网络的键值不匹配。
1、最常见的问题是键值多了或者少了 module.
此种情况是模型在DataParallel或者DDP训练后保存的键值有module. ,对应的网络的键值则没有module.
1)可以通过:
model = nn.DataParallel(model)
将模型的键值加上module.
如:加载模型时删除多余的module. 代码如下
2、详解load_state_dict(state_dict, False)的False参数
很多教程说名字不匹配直接添加False参数即可,但是这里需要注意一个大坑。
如果模型的键值和网络的键值完全不匹配,那么模型就没有加载预训练参数,虽然不再报错。
该False参数作用在于 非严格匹配加载模型,可以下面几种情况进行分析。
1)模型包含网络的部分参数
比如说模型是resnet101模型,你现在的网络是resnet50。再假设resnet50的参数名包含在resnet101的参数中,那么直接使用False会为你的网络resnet50加载键值相同的参数。这样就避免了对resnet101的每个键对值进行循环匹配,看是否是resnet50需要的。
2)模型完全不包含网络的参数
情况如1,模型有100个参数,都包含'module.' ,网络也有100个参数,都没有'module.' 。这种情况下如果参数设置为False,会发现没有任何键值能匹配上,因此网络就不会加载任何参数。
3)再介绍一个False使用场景
比如蒸馏网络PISR中,教师网络包含Encoder和Decoder两部分,学生网络由其中的Decoder部分组成,所以在训练学生网络时,如果要加载教师网络保存的预训练模型,设置False会自动识别Decoder部分键值相同,然后加载。
综上,设置False参数后依旧是按照键值查询加载参数的,有多少键值匹配,就加载多少模型的参数。
3、只要参数尺寸相同,就能加载
比如说我有一个10层网络的模型,还有一个3层的网络。我想把其中第9层的参数加载到现在网络的1层。如果参数的尺寸相同,就可以遍历键对值。将参数加载到想要的键值中。
更多参考
模型在保存时侯以键对值保存,同时在加载时根据现在网络的键值查找模型对应的键值,然后加载。一般报错是因为模型和网络的键值不匹配。
1、最常见的问题是键值多了或者少了 module.
此种情况是模型在DataParallel或者DDP训练后保存的键值有module. ,对应的网络的键值则没有module.
1)可以通过:
model = nn.DataParallel(model)
将模型的键值加上module.
如:加载模型时删除多余的module. 代码如下
2、详解load_state_dict(state_dict, False)的False参数
很多教程说名字不匹配直接添加False参数即可,但是这里需要注意一个大坑。
如果模型的键值和网络的键值完全不匹配,那么模型就没有加载预训练参数,虽然不再报错。
该False参数作用在于 非严格匹配加载模型,可以下面几种情况进行分析。
1)模型包含网络的部分参数
比如说模型是resnet101模型,你现在的网络是resnet50。再假设resnet50的参数名包含在resnet101的参数中,那么直接使用False会为你的网络resnet50加载键值相同的参数。这样就避免了对resnet101的每个键对值进行循环匹配,看是否是resnet50需要的。
2)模型完全不包含网络的参数
情况如1,模型有100个参数,都包含'module.' ,网络也有100个参数,都没有'module.' 。这种情况下如果参数设置为False,会发现没有任何键值能匹配上,因此网络就不会加载任何参数。
3)再介绍一个False使用场景
比如蒸馏网络PISR中,教师网络包含Encoder和Decoder两部分,学生网络由其中的Decoder部分组成,所以在训练学生网络时,如果要加载教师网络保存的预训练模型,设置False会自动识别Decoder部分键值相同,然后加载。
综上,设置False参数后依旧是按照键值查询加载参数的,有多少键值匹配,就加载多少模型的参数。
3、只要参数尺寸相同,就能加载
比如说我有一个10层网络的模型,还有一个3层的网络。我想把其中第9层的参数加载到现在网络的1层。如果参数的尺寸相同,就可以遍历键对值。将参数加载到想要的键值中。