TensorFlow实践(13)——保存和复用训练好的模型

  • (一)前 言
  • (二)保存训练好的模型
  • (三)重载保存的模型
  • (四)总结


(一)前 言

当模型训练完成之后,我们可以使用tf.train.Saver()方法将训练好的模型进行保存,以便于之后使用模型进行预测等任务,而不用重复训练。Saver构造方法的主要输入参数:

参数名称

功能说明

默认值

var_list

Saver存储的变量集合

全局变量集合

reshape

是否允许从checkpoint文件中恢复时改变变量形状

True

sharded

是否将checkpoint文件中的变量轮循放置在所有设备上

True

max_to_keep

保留最近的检查点的个数

5

restore_sequentially

是否按顺序恢复所有变量,当模型较大时顺序回复可以降低内存

True

(二)保存训练好的模型

接下来我们通过一个简单的例子演示如何保存训练好的模型:

import tensorflow as tf

# 创建一个TensorFlow变量
data = tf.Variable(tf.truncated_normal([2, 2]), name = 'data')
# 创建一个Saver,默认保存所有变量
saver = tf.train.Saver()

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(data))
    saver.save(sess, 'C:/Users/12394/PycharmProjects/Spyder/test/model.ckpt')
# 输出:
[[-0.03797581 -0.1358492 ]
 [ 0.08680686  0.6192091 ]]

在对应的目录下,我们可以找到如下文件:

python 如何使用tensorflow模型训练 tensorflow训练好的模型_tensorflow


model.ckpt.meta保存了TensorFlow计算图的结构信息,model.ckpt保存每个变量的取值,此处文件名的写入方式会因不同参数的设置而不同,但加载restore时的文件路径名是以checkpoint文件中的“model_checkpoint_path”值决定的。

(三)重载保存的模型

现在我们重新载入已保存的模型文件,要注意一点,恢复模型时的数据流图结构要和保存时的相匹配才可以,不然无法恢复,下面我们来取回保存的数据:

import tensorflow as tf

# 创建一个TensorFlow变量
data = tf.Variable(tf.truncated_normal([2, 2]), name = 'data')
# 创建一个Saver,默认保存所有变量
saver = tf.train.Saver()

with tf.Session() as sess:
    # 注意此处不再需要初始化
    saver.restore(sess, 'C:/Users/12394/PycharmProjects/Spyder/test/model.ckpt')
    print(sess.run(data))
# 输出:
INFO:tensorflow:Restoring parameters from C:/Users/12394/PycharmProjects/Spyder/test/model.ckpt
[[-0.03797581 -0.1358492 ]
 [ 0.08680686  0.6192091 ]]

可以发现取回的数据与我们存储的数据相同,说明模型复用成功。