tf.add_to_collection() 和tf.get_collection()语句的组合使用旨在更好的管理同一类型(或者意义)的张量。
tf.add_to_collection():把张量发到一起,并用同一个命名空间命名多个张量,将多个张量组合成一个list,没有返回值。
tf.get_collection(name) :将之前通过tf.add_to_collection()语句添加的张量集合,通过name参数提取出来,返回一个list。
代码解析
#_*_coding:utf-8_*_
import tensorflow as tf
value1=tf.get_variable(name = 'value1',shape=[3],initializer=tf.ones_initializer())
value2=tf.get_variable(name = 'value2',shape=[3],initializer=tf.random_uniform_initializer(maxval=-1,minval=1,seed=0))
#第二个类别的tensor
loss1 = tf.get_variable(name = 'loss1',shape = [1],initializer=tf.constant_initializer(0))
loss2 = tf.get_variable(name = 'loss2',shape = [1],initializer=tf.constant_initializer(0))
#利用tf.add_to_collection()管理上述两个类别的tensor(张量)
tf.add_to_collection('value',value1)
tf.add_to_collection('value',value2)
tf.add_to_collection('loss',loss1)
tf.add_to_collection('loss',loss2)
with tf.Session() as sess:
init = tf.global_variables_initializer()
sess.run(init)
#利用tf.get_collection(name)来调取上述存入的两个tensor
value = tf.get_collection(value)
#可以利用eval函数来输出两个tensor的值
print value
print value[0].eval
print value[1].eval
#利用loss这个name来管理这两个变量
#利用tf.add_n这个函数两统计collection中的tensor数量
loss = tf.get_collection(loss)#返回的就是个列表
total_collection_num = tf.add_n(loss)
print loss
print total_collection_num