保存检查点(checkpoint)
艾伯特(http://www.aibbt.com/)国内第一家人工智能门户
为了得到可以用来后续恢复模型以进一步训练或评估的检查点文件(checkpoint file),我们实例化一个tf.train.Saver
。
saver = tf.train.Saver()
在训练循环中,将定期调用saver.save()
方法,向训练文件夹中写入包含了当前所有可训练变量值得检查点文件。
saver.save(sess, FLAGS.train_dir, global_step=step)
这样,我们以后就可以使用saver.restore()
方法,重载模型的参数,继续训练。
saver.restore(sess, FLAGS.train_dir)
评估模型
每隔一千个训练步骤,我们的代码会尝试使用训练数据集与测试数据集,对模型进行评估。do_eval
函数会被调用三次,分别使用训练数据集、验证数据集合测试数据集。
print 'Training Data Eval:'
do_eval(sess,
eval_correct,
images_placeholder,
labels_placeholder,
data_sets.train)
print 'Validation Data Eval:'
do_eval(sess,
eval_correct,
images_placeholder,
labels_placeholder,
data_sets.validation)
print 'Test Data Eval:'
do_eval(sess,
eval_correct,
images_placeholder,
labels_placeholder,
data_sets.test)
注意,更复杂的使用场景通常是,先隔绝
data_sets.test
测试数据集,只有在大量的超参数优化调整(hyperparameter tuning)之后才进行检查。但是,由于MNIST问题比较简单,我们在这里一次性评估所有的数据。
构建评估图表(Eval Graph)
在打开默认图表(Graph)之前,我们应该先调用get_data(train=False)
函数,抓取测试数据集。
test_all_images, test_all_labels = get_data(train=False)
在进入训练循环之前,我们应该先调用mnist.py
文件中的evaluation
函数,传入的logits和标签参数要与loss
函数的一致。这样做事为了先构建Eval操作。
eval_correct = mnist.evaluation(logits, labels_placeholder)
evaluation
函数会生成tf.nn.in_top_k
操作,如果在K个最有可能的预测中可以发现真的标签,那么这个操作就会将模型输出标记为正确。在本文中,我们把K的值设置为1,也就是只有在预测是真的标签时,才判定它是正确的。
eval_correct = tf.nn.in_top_k(logits, labels, 1)
评估图表的输出(Eval Output)
之后,我们可以创建一个循环,往其中添加feed_dict
,并在调用sess.run()
函数时传入eval_correct
操作,目的就是用给定的数据集评估模型。
for step in xrange(steps_per_epoch):
feed_dict = fill_feed_dict(data_set,
images_placeholder,
labels_placeholder)
true_count += sess.run(eval_correct, feed_dict=feed_dict)
true_count
变量会累加所有in_top_k
操作判定为正确的预测之和。接下来,只需要将正确测试的总数,除以例子总数,就可以得出准确率了。
precision = float(true_count) / float(num_examples)
print ' Num examples: %d Num correct: %d Precision @ 1: %0.02f' % (
num_examples, true_count, precision)