TensorFlow实践(13)——保存和复用训练好的模型
- (一)前 言
- (二)保存训练好的模型
- (三)重载保存的模型
- (四)总结
(一)前 言
当模型训练完成之后,我们可以使用tf.train.Saver()方法将训练好的模型进行保存,以便于之后使用模型进行预测等任务,而不用重复训练。Saver构造方法的主要输入参数:
参数名称 | 功能说明 | 默认值 |
var_list | Saver存储的变量集合 | 全局变量集合 |
reshape | 是否允许从checkpoint文件中恢复时改变变量形状 | True |
sharded | 是否将checkpoint文件中的变量轮循放置在所有设备上 | True |
max_to_keep | 保留最近的检查点的个数 | 5 |
restore_sequentially | 是否按顺序恢复所有变量,当模型较大时顺序回复可以降低内存 | True |
(二)保存训练好的模型
接下来我们通过一个简单的例子演示如何保存训练好的模型:
import tensorflow as tf
# 创建一个TensorFlow变量
data = tf.Variable(tf.truncated_normal([2, 2]), name = 'data')
# 创建一个Saver,默认保存所有变量
saver = tf.train.Saver()
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
print(sess.run(data))
saver.save(sess, 'C:/Users/12394/PycharmProjects/Spyder/test/model.ckpt')
# 输出:
[[-0.03797581 -0.1358492 ]
[ 0.08680686 0.6192091 ]]
在对应的目录下,我们可以找到如下文件:
model.ckpt.meta保存了TensorFlow计算图的结构信息,model.ckpt保存每个变量的取值,此处文件名的写入方式会因不同参数的设置而不同,但加载restore时的文件路径名是以checkpoint文件中的“model_checkpoint_path”值决定的。
(三)重载保存的模型
现在我们重新载入已保存的模型文件,要注意一点,恢复模型时的数据流图结构要和保存时的相匹配才可以,不然无法恢复,下面我们来取回保存的数据:
import tensorflow as tf
# 创建一个TensorFlow变量
data = tf.Variable(tf.truncated_normal([2, 2]), name = 'data')
# 创建一个Saver,默认保存所有变量
saver = tf.train.Saver()
with tf.Session() as sess:
# 注意此处不再需要初始化
saver.restore(sess, 'C:/Users/12394/PycharmProjects/Spyder/test/model.ckpt')
print(sess.run(data))
# 输出:
INFO:tensorflow:Restoring parameters from C:/Users/12394/PycharmProjects/Spyder/test/model.ckpt
[[-0.03797581 -0.1358492 ]
[ 0.08680686 0.6192091 ]]
可以发现取回的数据与我们存储的数据相同,说明模型复用成功。