TensorFlow应用——tf.set_random_seed 的用法
一、会话级种子:seed
当在代码中使用了随机数,但是希望代码在不同时间或者不同的机器上运行能够得到相同的随机数,以至于能够得到相同的运行结果,那么就需要设置随机函数的seed 参数,对应的变量可以跨会话(session)生成相同的随机数。
陈瑞一首《白狐》,讲述一段跨越千年的人妖恋,听一遍哭一遍!
例1:
#作者:文方俊
#日期: 2021年05月27日
#功能:会话级种子seed实现
#import tensorflow as tf # TensorFlow版本号<2.0
import tensorflow.compat.v1 as tf # TensorFlow版本号>=2.0
tf.reset_default_graph()#函数用于清除默认图形堆栈并重置全局默认图形
a= tf.random_normal([1],mean=0, stddev=1, seed=1)
b= tf.random_normal([1],mean=0,stddev=1)
print('Session1')
tf.print(a) # TensorFlow版本号>=2.0
tf.print(a)
tf.print(b)
tf.print(b)
print('Session2')
tf.print(a)
tf.print(a)
tf.print(b)
tf.print(b)
结果:
Session1
[1.84709585]
[1.84709585]
[-0.179822043]
[-0.179822043]
Session2
[1.84709585]
[1.84709585]
[-0.179822043]
[-0.179822043]
可以看出在TensorFlow 2.0以后的版本中,a设置了seed=1之后,在不同的会话(session)中a产生的随机数是一致的,而b在不同的会话(session)中产生的随机数也是一致的。
二、图级种子:tf.set_random_seed
如果不想一个一个的设置随机种子seed,那么可以使用全局设置tf.set_random_seed()函数,使用之后后面设置的随机数都不需要设置seed,而可以跨会话生成相同的随机数。
例2:
#作者:文方俊
#日期:2021年05月27日
#功能:图级种子seed实现
import tensorflow.compat.v1 as tf # TensorFlow版本号>=2.0
tf.reset_default_graph()#函数用于清除默认图形堆栈并重置全局默认图形
tf.set_random_seed(1)#设置全局随机种子
a= tf.random_normal([1],mean=0, stddev=1)
b= tf.random_normal([1],mean=0,stddev=1)
print('Session1')
tf.print(a) # TensorFlow版本号>=2.0
tf.print(a) # TensorFlow版本号>=2.0
tf.print(b) # TensorFlow版本号>=2.0
tf.print(b) # TensorFlow版本号>=2.0
print('Session2')
tf.print(a) # TensorFlow版本号>=2.0
tf.print(a) # TensorFlow版本号>=2.0
tf.print(b) # TensorFlow版本号>=2.0
tf.print(b) # TensorFlow版本号>=2.0
结果:
Session1
[-1.10122025]
[-1.10122025]
[0.403087884]
[0.403087884]
Session2
[-1.10122025]
[-1.10122025]
[0.403087884]
[0.403087884]
上面例子我们会发现,在TensorFlow>=2.0的版本中,如果设置了图级种子,同一个变量在同一个或不同的会话当中,产生的随机数都一致。
三、会话级种子中,定义两个变量的随机生成函数一样,种子一样,结果不一样
例3:
#作者:文方俊
#日期:2021年05月27日
#功能:会话级种子,种子一样,结果不一样
import tensorflow.compat.v1 as tf # TensorFlow版本号>=2.0
tf.reset_default_graph()#函数用于清除默认图形堆栈并重置全局默认图形
a= tf.random_normal([1],mean=0, stddev=1,seed=2)
b= tf.random_normal([1],mean=0,stddev=1,seed=2)
print('a')
tf.print(a) # TensorFlow版本号>=2.0
tf.print(a) # TensorFlow版本号>=2.0
print('b')
tf.print(b) # TensorFlow版本号>=2.0
tf.print(b) # TensorFlow版本号>=2.0
结果:
a
[-0.164025202]
[-0.164025202]
b
[-1.21130085]
[-1.21130085]
由此可知,在TensorFlow版本号>=2.0的版本中,不同的变量在会话级种子设置成相同的情况下,随机生成函数生成的结果不一样。
四、会话级种子中,设置为变量variable,得到同一个会话(session)可复用的结果
例4:
#作者:文方俊
#日期:2021年05月27日
#功能:会话级种子,种子一样,结果不一样
import tensorflow.compat.v1 as tf # TensorFlow版本号>=2.0
tf.reset_default_graph()#函数用于清除默认图形堆栈并重置全局默认图形
a= tf.Variable(tf.random_normal([1],mean=0, stddev=1,seed=2))
init_op=tf.global_variables_initializer()
print('a')
tf.print(a) # TensorFlow版本号>=2.0
tf.print(a) # TensorFlow版本号>=2.0
结果:
a
[0.453302264]
[0.453302264]
关注“AI早知道”公众号,
更多精彩,
下回待续,
。。。
遇见爱or遇见自己or遇见幸福,
再出发,
遇见更精彩的自己
。。。