之前的文章中讲了如何使用tensorflow源码编译一个c++版的动态库。同时留下了一个问题:能否在C++中读取预先训练好的模型呢?———答案是肯定的。
下面,就来一一介绍tensorflow模型在python中的存储和读取,在c++中的读取方式。为什么不讲如何用C++去存储一个模型呢?因为不建议大家用c++训练模型,其中的原因有三点:

其一,基本上99%的tensorflow神经网络都是用python写的,如果你想照抄一个网络,用python最方便。
其二,python有强大的第三方库,高级的语法特性,这些c++上实现需要花费巨大精力。
其三,tensorflow的C++接口支持的并不好。

一、python存储模型的方法

好了,进入正题,在python中如何存储tensorflow模型。

  1. tf.saved_model.builder(推荐
    tf.saved_model是tensorflow官网推荐的一个保存模型的方法,只要你输入保存模型的路径,就可以使用。基本使用方式如下:
import tensorflow as tf

input=...
export_dir=...
...
build net...
...
 #指定存储路径
builder =tf.saved_model.builder.SavedModelBuilder(export_dir)
with tf.Session() as sess:
    #下段话只能调用一次
    builder.add_meta_graph_and_variables(sess,['custom'])
builder.save()

其中,export_dir必须指定为一个不存在的路径,否则会报错。上面一段代码中,我们建立了一个名叫’custom’的网络,并将其保存在export_dir中,文件结构如下:

|-saved_model.pb-|
|-variables-|
    |-variables.data-00000-of-00001-|
    |-variables.index-|

其实和下一个方法tf.save存出来的文件差不多。pb文件中是网络结构信息,index文件中是参数值。

  1. tf.train.saver
    tf.train.saver是1.3版本之前主要的模型存储方式,在新版本中也兼容,但已经不是最推荐的方式了。它的使用方式也很简单:
import tensorflow as tf

input=...
model_path=...
...
build net
...
saver = tf.train.Saver()
with tf.Session() as sess:
    saver.save(sess,model_path)

tf.save.saver产生的文件结构如下:

|-saved_model.meta-|
|-saved_model.data-00000-of-00001-|
|-saved_model.index-|

meta文件中存储网络结构,index文件中存储参数信息。

  1. tf.saved_model.builder和tf.train.saver方法比较
    tf.saved_model.builder方法:
    优点:
    1.只需要指定一个存储路径。存储、读取都很方便。
    2.可以存多段网络,参数可以复用。比如现在有一个GAN网络模型,用tf.saved_model.builder指定相应tag以后,可以同时存生成网络、鉴别网络和整个网络。之后读取时,只要读需要的那一部分即可,大大加快读取速度。提升内存利用率。
    3.在tensorflow推荐的estimate(一种更高级的机器学习API,以后填坑)流程中,扮演主要的模型存储方法。
    4.便于分布式读取及使用
    缺点:
    1.只能保存一次参数
    2.对于一个目录,只能导出一个模型。(但可以改变目录名)
    3.不灵活。
    4.速度慢。

  1. tf.train.saver方法:
    优点:
    1.灵活。可以指定保存模型的名称、后缀、多长时间保存一次、最多保存多少个模型等等。
    2.应用范围广。如果你使用tf.contrib.Slim库(类似tensorlayer的一种高级库)训练模型,那么只能用此方法保存模型。
    3.速度快。
    缺点:
    1.保存多个模型比较复杂。

二、python读取模型的方法

tensorflow读取模型的方法也很简单。我对应的介绍一下。

  1. tf.saved_model.loader
    如果你使用tf.saved_model.builder存储模型的话,那么可以使用tf.saved_model.loader读取模型。只输入一个模型存储的路径即可。简单的例子:
export_dir = ...
...
build net...
...
with tf.Session(graph=tf.Graph()) as sess:
  tf.saved_model.loader.load(sess, ['custom'], export_dir)
  ...

可以看到,该方式读取模型非常简单,只需要模型路径和网络标签即可,函数内部会自动加载网络模型和恢复参数。

  1. tf.train.saver.restore
    该方法需要先恢复网络结构(如果你有了定义网络的py文件,可以跳过此步,等价的),再读取参数。简单的例子:
model_path=...
 #恢复网络结构
saver = tf.train.import_meta_graph(model_path + '.meta')
with tf.Session() as sess:
    #读取参数
    saver.restore(sess, model_path)
    graph = sess.graph
    input = graph.get_tensor_by_name('input:0')
    ...
    prediction...
    ...

pythond的模型存取方式就介绍到这里,更多有关tf.train.save和tf.saved_model的区别请点这里。


c++读取模型的方式

此章将会辅助一些截图说明。原因是相对于python,tensorflow的c++接口的有点烂。一开始也许你会卡在某一步骤,但是耐心的一步步排查,终将能成功。

  1. LoadSavedModel(对应tf.save_model.builder方式)
    先上代码
#include <string>
 #include <cc/saved_model/loader.h>
 #include <google/protobuf/message.h>

tensorflow::Status LoadGraph(std::string modelDir, std::unique_ptr<tensorflow::Session>* sess) {
    //定义初始环境
    const std::string export_dir = modelDir;
    tensorflow::SessionOptions session_options;
    tensorflow::RunOptions run_options;
    tensorflow::SavedModelBundle bundle;
    tensorflow::Status status;
    constexpr char kSavedModelTagServe[] = "train";
    //存储模型
    status=LoadSavedModel(session_options, run_options, export_dir, { kSavedModelTagServe },&bundle);
    if (!status.ok()) {
        std::cerr << "Error reading graph definition from " + modelDir+ ": " + status.ToString() << std::endl;
        return status;
    }
    *sess =std::move(bundle.session);
    return status;
};

其中,modelDir是模型目录,sess是载入的图模型环境。运行完LoadSavedModel方法后,你得到的status状态应该是空的,bundle中应该已经有内容了:

tensorflow用训练好的模型预测数据_深度学习模型


如果发生错误,status里会有相应的问题描述,可以根据它尝试解决一下问题。

  1. LoadSavedModel(对应tf.save_model.builder方式)
    简单例子如下:
#include <string>
 #include <cc/framework/scope.h>
 #include <core/public/session.h>
 #include <core/protobuf/meta_graph.pb.h>
 using namespace tensorflow;

tensorflow::Status LoadGraph(std::string checkpointPath, std::unique_ptr<tensorflow::Session>* sess) {
    string metaGraphPath= checkpointPath + ".meta";
    if (*sess == nullptr) {
        (*sess).reset(tensorflow::NewSession(tensorflow::SessionOptions()));
    }
    Status status;
    auto scp = ::Scope::NewRootScope();
    // 读网络
    tensorflow::MetaGraphDef graph_def;
    status = ReadBinaryProto(Env::Default(), metaGraphPath, &graph_def);
    if (!status.ok()) {
        std::cerr << "Error reading graph definition from " + metaGraphPath+ ": " + status.ToString() << std::endl;
        return status;
    }
    // 将网络加入sess中
    status = (*sess)->Create(graph_def.graph_def());
    if (!status.ok()) {
        std::cerr << "Error creating graph: " + status.ToString() << std::endl;
    }
    // 读参数
    Tensor checkpointPathTensor(DT_STRING, TensorShape());
    checkpointPathTensor.scalar<std::string>()() = checkpointPath;
    status = (*sess)->Run(
    { { graph_def.saver_def().filename_tensor_name(), checkpointPathTensor }, },
    {},
    { graph_def.saver_def().restore_op_name() },
        nullptr);
    if (!status.ok()) {
        std::cerr << "Error loading checkpoint from " + checkpointPath + ": " + status.ToString() << std::endl;
    }
    return status;
};

其中,checkpointPath是模型路径,sess是载入的图模型环境。运行完后,你得到的status状态应该是空的,graph_def.meta_info_def中有值,如下:

tensorflow用训练好的模型预测数据_python_02


如果错误,status会返回错误方法,可以根据描述修复问题。


好了,tensorflow模型存取的方法介绍到这里。接口中的一些参数我没有仔细讲,需要深入研究的童鞋可以去tensorflow官网看一下参数介绍。如果你掌握了本章所说的方法,那么基本上tensorflow的应用已经不成问题了。以后我会多讲一讲具体的tensorflow预测网络应用。

最后祝您身体健康,再见!