tf.nn.in_top_k()函数的参数如下:
in_top_k(predictions, targets, k, name=None)
- 1
predictions:预测的结果,预测矩阵大小为样本数×标注的label类的个数的二维矩阵。
targets:实际的标签,大小为样本数。
k:每个样本的预测结果的前k个最大的数里面是否包含targets预测中的标签,一般都是取1,即取预测最大概率的索引与标签对比。
name:名字。
假设有10个样本,标注为5类,10个样本实际标签均是第一类,代码如下:
import tensorflow as tf
logits = tf.Variable(tf.truncated_normal(shape=[10,5],stddev=1.0))
labels = tf.constant([0,0,0,0,0,0,0,0,0,0])
top_1_op = tf.nn.in_top_k(logits,labels,1)
top_2_op = tf.nn.in_top_k(logits,labels,2)
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
print(logits.eval())
print(labels.eval())
print(top_1_op.eval())
print(top_2_op.eval())
- 1
- 2
- 3
- 4
- 5
- 6
- 7
- 8
- 9
- 10
- 11
- 12
- 13
- 14
运行结果如下:
[[-0.01835343 -1.68495178 -0.67901242 -0.20486258 -0.22725371]
[ 1.84425163 -1.25509632 0.07132829 -1.81082523 -0.44123012]
[-0.4354656 0.1805554 0.81912154 0.04202025 -1.99823892]
[ 0.53393573 0.91522688 -1.88455033 -0.44571343 0.07805539]
[ 0.01253182 0.16593859 0.0918197 0.8079409 0.13442524]
[ 0.08205117 -0.26857412 0.02542082 0.38249066 -0.01555154]
[-1.02280331 0.18952899 0.49389341 0.58559865 0.80859423]
[ 0.35019293 -1.17765355 0.66553122 1.91787696 0.5998978 ]
[ 0.81723028 0.92895705 0.86031818 1.57651412 0.94040418]
[-0.83766556 -1.75260925 0.13499574 -0.06683849 -0.99427927]]
[0 0 0 0 0 0 0 0 0 0]
[ True True False False False False False False False False]
[ True True False True False True False False False False]
- 1
- 2
- 3
- 4
- 5
- 6
- 7
- 8
- 9
- 10
- 11
- 12
- 13
top_1_op为True的地方top_2_op一定为True,top_1_op取样本的最大预测概率的索引与实际标签对比,top_2_op取样本的最大和仅次最大的两个预测概率与实际标签对比,如果实际标签在其中则为True,否则为False。其他k的取值可以类推。