Tensorflow学习笔记(二)模型的保存与加载(一 )
- SavedModel模型的保存与加载
- 保存
- 加载
- 查看模型的Signature签名
SavedModel模型的保存与加载
声明: 参考链接这篇博文以及官方文档
保存
关于SavedModel模型的好处与简介大家可以参考百度,本文只用一个很简单的例子来演示SavedModel模型的保存与加载。
在上一篇的末尾我们贴出了一个很简单的代码,它具有一个输入以及一个输出。其实这就算是一个最最简单的tensorflow模型了,它实现了 (y = x + b)当输入一个x 那么输出的结果y就等于输入x加上b。这已经算的上是一个模型了当然它非常的简单,这次我们就把这个简单模型保存起来并实现加载调用把!
先看下这个简单的模型
import tensorflow as tf # 以下所有代码默认导入
# 创建一个变量
one = tf.Variable(2.0)
# 创建一个占位符,在Tensorflow中需要定义placeholder的type,一般为 float32形式
num = tf.placeholder(tf.float32,name='input')
# 创建一个加法步骤,注意这里并没有直接计算
sum = tf.add(num,one,name='output')
# 初始化变量,如果定义Variable就必须初始化
init = tf.global_variables_initializer()
# 创建会话sess
with tf.Session() as sess:
sess.run(init)
print(sess.run(sum,feed_dict={num:5.0}))
运行结果如下
7.0
SavedModel模型的保存其实非常简单仅仅需要几行代码就能完成
如下:
# #保存SavedModel模型
builder = tf.saved_model.builder.SavedModelBuilder('./models')
signature = predict_signature_def(inputs={'input':num}, outputs={'output':sum})
builder.add_meta_graph_and_variables(sess,[tf.saved_model.tag_constants.SERVING],signature_def_map={'predict': signature})
builder.save()
先构建 SavedModelBuilder 类的对象,使用tf.saved_model.builder.SavedModelBuilder方法,该方法的参数是传入用于保存模型的目录名,目录不用预先创建。
然后 生成签名,签名是一组与图有关的输入和输出。使用predict_signature_def方法,传入的参数为输入和输出以及他们的name。
接着传入graph(图)和Variables(变量)给add_meta_graph_and_variables方法。
第一个参数传入的是Session它包含了当前graph(图)和Variables(变量)。
第二个参数是给当前需要保存的MetaGraph 一个标签,标签名可以自定义,在之后载入模型的时候,需要根据这个标签名去查找对应的MetaGraphDef,找不到就会报如 RuntimeError: MetaGraphDef associated with tags ‘foo’ could not be found in SavedModel这样的错。
标签也可以选用系统定义好的参数,
tf.saved_model.tag_constants.SERVING与
tf.saved_model.tag_constants.TRAINING等。
完整代码
import tensorflow as tf # 以下所有代码默认导入
from tensorflow.python.saved_model.signature_def_utils_impl import predict_signature_def
# 保存模型路径
PATH = './models'
# 创建一个变量
one = tf.Variable(2.0)
# 创建一个占位符,在 Tensorflow 中需要定义 placeholder 的 type ,一般为 float32 形式
num = tf.placeholder(tf.float32,name='input')
# 创建一个加法步骤,注意这里并没有直接计算
sum = tf.add(num,one,name='output')
# 初始化变量,如果定义Variable就必须初始化
init = tf.global_variables_initializer()
# 创建会话sess
with tf.Session() as sess:
sess.run(init)
# #保存SavedModel模型
builder = tf.saved_model.builder.SavedModelBuilder(PATH)
signature = predict_signature_def(inputs={'input':num}, outputs={'output':sum})
builder.add_meta_graph_and_variables(sess,[tf.saved_model.tag_constants.SERVING],signature_def_map={'predict': signature})
builder.save()
执行完成后会在当前项目的目录下生成models文件夹,里面包含variables
文件夹以及saved_model.pb
文件。variables
保存所有变量信息,saved_model.pb
用于保存模型结构等信息。
注意:当前目录下不可以存在models
文件夹,否则会报如下错误!!
训练一次后会自动生成models
文件夹,下次训练前记得删除或者换个名字
加载
加载模型的话更为简单只需要调用tf.saved_model.loader.load方法就可以载入模型
import tensorflow as tf # 以下所有代码默认导入
PATH = './models'
with tf.Session() as sess:
tf.saved_model.loader.load(sess, ["serve"], PATH)
in_x = sess.graph.get_tensor_by_name('input:0') #加载输入变量
y = sess.graph.get_tensor_by_name('output:0') #加载输出变量
scores = sess.run(y, feed_dict={in_x: 3.})
print(scores)
使用tf.saved_model.loader.load方法加载模型,第二个参数为TAG标签在save模型时定义,第三个参数为模型路径*
使用sess.graph.get_tensor_by_name方法加载输入输出变量,注意这里的变量name都需要加上“:0”,如"input"变为"input:0"
最后像之前那样sess.run(),feed喂入数据,这里输入了个3.0。
结果
5.0
至此SavedModel模型的保存与加载就已经全部完成!
查看模型的Signature签名
值得注意的是这种方法加载模型是需要知道标签的。如果我们拿到了别人的一个SavedModel模型而且并不知道“标签”那么怎么调用呢?
别慌Tensorflow官方已经为我们准备好了一个脚本,如果你已经安装了Tensorflow那么你可以在./tensorflow/python/tools/目录下找到saved_model_cli.py文件。
我们可以’WIN+R‘输入’cmd‘然后回车打开你的CMD,然后指定路径到你的模型目录下
如
然后输入saved_model_cli show --dir=./ --all
等待一会
在CMD打印出的信息中我们就可以看到模型的输入/输出的名称、数据类型、shape以及方法名称。
希望这篇文章对您有帮助,感谢阅读!