接上篇文章:《​​tensorflow2 保存SavedModel模型​​》。

概述

  1. tensorflow版本:2.4.1
  2. 目标:读取使用tf.saved_model.save(...)函数保存的模型,并使用该模型进行预测;
  3. api:tf.saved_model.load(...)可读取模型路径,并返回保存的模型,假设命名为model;
  4. 使用方法:直接使用 model( x_data ) 对输入数据进行预测,注:预测数据格式应该与训练模型时数据一致;

代码

import tensorflow as tf

# 模型保存路径
model_path = "E:/test/java_tf2_model/py/tf241_model"

# 加载模型
model = tf.saved_model.load(model_path)

# 使用模型进行预测
x_test = tf.reshape(tf.constant(
[10, 20, 30], dtype=tf.float32, name="inputs"), (-1, 1))
print(model(x_test))

结果

tensorflow2 读取SavedModel模型_SavedModel

可以看到,本次使用[10,20,30]进行预测,得到的结果y=[24.999,44.999,65],符合预期结果(y=2x + 5)。


备注

  1. 加载模型需使用绝对路径;
  2. 使用加载的模型预测数据,输入数据格式应该与训练数据格式相同,否则无法进行预测;