接上篇文章:《tensorflow2 保存SavedModel模型》。
概述
- tensorflow版本:2.4.1
- 目标:读取使用tf.saved_model.save(...)函数保存的模型,并使用该模型进行预测;
- api:tf.saved_model.load(...)可读取模型路径,并返回保存的模型,假设命名为model;
- 使用方法:直接使用 model( x_data ) 对输入数据进行预测,注:预测数据格式应该与训练模型时数据一致;
代码
import tensorflow as tf
# 模型保存路径
model_path = "E:/test/java_tf2_model/py/tf241_model"
# 加载模型
model = tf.saved_model.load(model_path)
# 使用模型进行预测
x_test = tf.reshape(tf.constant(
[10, 20, 30], dtype=tf.float32, name="inputs"), (-1, 1))
print(model(x_test))
结果
可以看到,本次使用[10,20,30]进行预测,得到的结果y=[24.999,44.999,65],符合预期结果(y=2x + 5)。
备注
- 加载模型需使用绝对路径;
- 使用加载的模型预测数据,输入数据格式应该与训练数据格式相同,否则无法进行预测;