- 参考:动手学深度学习
- 注意:由于本文是jupyter文档转换来的,代码不一定可以直接运行,有些注释是jupyter给出的交互结果,而非运行结果!!
文章目录
- 1. 读写 Tensor
- 2. 读写 Module
- 2.1 `state_dict`
- 2.2 保存和加载模型
- 2.2.1 保存和加载 `state_dict`(推荐)
- 2.2.2 保存和加载整个模型
1. 读写 Tensor
- pytorch 提供了
torch.save
函数和 torch.load
函数,分别用于存储和读取 Tensor,其中
-
torch.save
使用 Python 的 pickle 持续化模块将对象进行序列化并保存到本地磁盘,torch.save
可以保存各种对象,包括模型、张量和字典等 -
torch.load
使用 pickle unpickle 工具将 pickle 的对象文件反序列化为内存变量
- 下面创建 Tensor 变量
x
,并将保存为本地文件 x.pt
,然后再读取回来
- 类似地,存储Tensor列表和字典并读回内存
2. 读写 Module
2.1 state_dict
- 保存/加载 Module 的一个思路是保存和加载其所有参数。PyTorch 中 Module 的可学习参数包括权重 weight 和偏置 bias,它们可以通过
.parameters()
或 .named_parameter()
方法访问 - 调用 Module 的
.state_dict()
方法,返回一个从参数名称(“layer名.weight” 或 “layer名.bias”)映射到到参数 Tesnor 的字典对象,其中包含了模型的所有可学习参数
可见,只有具有可学习参数的层(卷积层、线性层等)才有 state_dict 中的条目;另外,优化器(optim)也有一个 state_dict,其中包含关于优化器状态以及所使用的超参数的信息
2.2 保存和加载模型
- PyTorch 中保存和加载训练模型有两种常见的方法:
- 仅保存和加载模型参数(state_dict),这是推荐方式
- 保存和加载整个模型
2.2.1 保存和加载 state_dict
(推荐)
- 通过保存和加载参数来实现模型的存取,相比直接保存整个
Module
对象轻量很多
- 保存时,使用
torch.save
保存模型的 state_dict
字典实例 - 加载时,先使用
torch.load
加载模型的 state_dict
,然后实例化一个所需类型的 Module
,最后调用 Moudule 的 .load_state_dict
方法加载参数
- 给出一个读写多层感知机模型的示例
2.2.2 保存和加载整个模型
- 第 1 节说明了
torch.save
可以保存各种对象,包括模型、张量和字典等 ,所以也可以直接对 Module
实例使用 torch.save
和 torch.load
进行保存加载