tf.train.Saver类负责保存和还原神经网络
自动保存为三个文件:模型文件列表checkpoint,计算图结构model.ckpt.meta,每个变量的取值model.ckpt
checkpoint文件保存了一个目录下所有的模型文件列表,这个文件是tf.train.Saver类自动生成且自动维护的。在 checkpoint文件中维护了由一个tf.train.Saver类持久化的所有TensorFlow模型文件的文件名。当某个保存的TensorFlow模型文件被删除时,这个模型所对应的文件名也会从checkpoint文件中删除。checkpoint中内容的格式为CheckpointState Protocol Buffer.
model.ckpt.meta文件保存了TensorFlow计算图的结构,可以理解为神经网络的网络结构
TensorFlow通过元图(MetaGraph)来记录计算图中节点的信息以及运行计算图中节点所需要的元数据。TensorFlow中元图是由MetaGraphDef Protocol Buffer定义的。MetaGraphDef 中的内容构成了TensorFlow持久化时的第一个文件。保存MetaGraphDef 信息的文件默认以.meta为后缀名,文件model.ckpt.meta中存储的就是元图数据
model.ckpt文件保存了TensorFlow程序中每一个变量的取值,这个文件是通过SSTable格式存储的,可以大致理解为就是一个(key,value)列表
tensorflow 中导出/恢复模型Graph数据Saver
Is there an example on how to generate protobuf files holding trained Tensorflow graphs
https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/session_bundle/README.md
https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md
https://github.com/anandanand84/tensorflow_model_server
export
import tensorflow as tf
from tensorflow.python.platform import gfile
# 这是从二进制格式的pb文件加载模型
graph = tf.get_default_graph()
graphdef = graph.as_graph_def()
graphdef.ParseFromString(gfile.FastGFile("/data/TensorFlowAndroidMNIST/app/src/main/expert-graph.pb", "rb").read())
_ = tf.import_graph_def(graphdef, name="")
#这是从meta文件加载模型
_ = tf.train.import_meta_graph("model.ckpt.meta")
summary_write = tf.summary.FileWriter("/data/TensorFlowAndroidMNIST/logdir" , graph)
freeze_graph
tensorflow,使用freeze_graph.py将模型文件和权重数据整合在一起并去除无关的Op
tensorflow将训练好的模型freeze,即将权重固化到图里面,并使用该模型进行预测
Tensorflow 训练模型数据freeze固话保存在Graph中
Steps to reproduce freeze_graph
tf.train.write_graph
tf.train.write_graph()保存模型,它只是保存了模型的结构,并不保存训练完毕的参数值
tf.train.saver()保存模型,将网络中的参数值与模型的结构分开存储
tf.train.Saver函数保存模型文件的时候,是保存所有的参数信息,而有些时候我们并不需要所有的参数信息。我们只需要知道神经网络的输入层经过前向传播计算得到输出层即可,所以在保存的时候,我们也不需要保存所有的参数,以及变量的初始化、模型保存等辅助节点信息与迁移学习类似。之前使用tf.train.Saver函数保存模型文件的时候会产生多个文件,它将变量的取值和计算图结构分成了不同的文件存储。TensorFlow提供了另一种保存模型文件的方法,将计算图保存在一个文件中
写入pb文件
TensorFlow的convert_variables_to_constants函数
如何用Tensorflow训练模型成pb文件和和如何加载已经训练好的模型文件
graph_util.convert_variables_to_constants可以把整个sesion当作常量都保存下来,通过output_node_names参数来指定输出
tf.gfile.FastGFile('model/cxq.pb', mode='wb')指定保存文件的路径以及读写方式
f.write(output_graph_def.SerializeToString())将固化的模型写入到文件
模型保存
import tensorflow as tf
from tensorflow.python.framework import graph_util
from tensorflow.python.platform import gfile
if __name__ == "__main__":
a = tf.Variable(tf.constant(5.,shape=[1]),name="a")
b = tf.Variable(tf.constant(6.,shape=[1]),name="b")
c = a + b
init = tf.initialize_all_variables()
sess = tf.Session()
sess.run(init)
#导出当前计算图的GraphDef部分
graph_def = tf.get_default_graph().as_graph_def()
#保存指定的节点,并将节点值保存为常数
output_graph_def = graph_util.convert_variables_to_constants(sess,graph_def,['add'])
#将计算图写入到模型文件中
model_f = tf.gfile.GFile("model.pb","wb")
model_f.write(output_graph_def.SerializeToString())
convert_variables_to_constants函数,会将计算图中的变量取值以常量的形式保存。在保存模型文件的时候,我们只是导出了GraphDef部分,GraphDef保存了从输入层到输出层的计算过程。在保存的时候,通过convert_variables_to_constants函数来指定保存的节点名称而不是张量的名称,“add:0”是张量的名称而"add"表示的是节点的名称
模型读取
sess = tf.Session()
#将保存的模型文件解析为GraphDef
model_f = gfile.FastGFile("model.pb",'rb')
graph_def = tf.GraphDef()
graph_def.ParseFromString(model_f.read())
c = tf.import_graph_def(graph_def,return_elements=["add:0"])
print(sess.run(c))
#[array([ 11.], dtype=float32)]
在读取模型文件获取变量的值的时候,我们需要指定的是张量的名称而不是节点的名称