对于tf.argmax,这个函数有点奇怪,axis=0指的是计算矩阵每列的最大值索引,axis=1计算行最大值索引
与numpy 相同
import tensorflow as tf
import numpy as np
a=np.array([[2,4,5,7],[9,3,6,2]])
print('-'*30+'分割线'+'-'*30)
print(a)
print('-'*30+'分割线'+'-'*30)
a1=tf.argmax(a,axis=0)
print('tf.argmax(a,axis=0)=',a1)
print('-'*30+'分割线'+'-'*30)
a1=np.argmax(a,axis=0)
print('np.argmax(a,axis=0)=',a1)
print('-'*30+'分割线'+'-'*30)
a1=tf.argmax(a,axis=1)
print('tf.argmax(a,axis=1)=',a1)
print('-'*30+'分割线'+'-'*30)
a1=np.argmax(a,axis=1)
print('np.argmax(a,axis=1)=',a1)
print('-'*30+'分割线'+'-'*30)
------------------------------分割线------------------------------
[[2 4 5 7]
[9 3 6 2]]
------------------------------分割线------------------------------
tf.argmax(a,axis=0)= tf.Tensor([1 0 1 0], shape=(4,), dtype=int64)
------------------------------分割线------------------------------
np.argmax(a,axis=0)= [1 0 1 0]
------------------------------分割线------------------------------
tf.argmax(a,axis=1)= tf.Tensor([3 0], shape=(2,), dtype=int64)
------------------------------分割线------------------------------
np.argmax(a,axis=1)= [3 0]
------------------------------分割线------------------------------