参考:
- TensorFlow模型保存和提取方法
1、基本用法
#!/usr/bin/python3
# -*- coding:utf-8 -*-
"""Basic usage: save/restore a subset of variables with tf.train.Saver."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
import numpy as np

logdir = './output/'
with tf.variable_scope('conv'):
    w = tf.get_variable('w', [2, 2], tf.float32, initializer=tf.random_normal_initializer)
    b = tf.get_variable('b', [2], tf.float32, initializer=tf.random_normal_initializer)
sess = tf.InteractiveSession()
# With no argument the Saver would handle every variable; here only w is saved.
saver = tf.train.Saver([w])
tf.global_variables_initializer().run()  # initialize all variables
# Check whether a checkpoint file was saved previously.
ckpt = tf.train.get_checkpoint_state(logdir)
if ckpt and ckpt.model_checkpoint_path:
    saver.restore(sess, ckpt.model_checkpoint_path)
    # BUG FIX: tf.variables_initializer() only *builds* the init op; it must be
    # run for b to actually be (re-)initialized.
    tf.variables_initializer([b]).run()
saver.save(sess, logdir + 'model.ckpt')
print('w', w.eval())
print('-----------')
print('b', b.eval())
sess.close()
2、变量从ckpt中提取,没有的需初始化
- 第一次运行
#!/usr/bin/python3
# -*- coding:utf-8 -*-
"""First run: restore everything the checkpoint contains; fall back to a subset."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
import numpy as np
from tensorflow.contrib.layers.python.layers import batch_norm
import argparse

logdir = './output/'
with tf.variable_scope('conv'):
    w1 = tf.get_variable('w', [2, 2], tf.float32, initializer=tf.random_normal_initializer)
    b1 = tf.get_variable('b', [2], tf.float32, initializer=tf.random_normal_initializer)
sess = tf.InteractiveSession()
tf.global_variables_initializer().run()  # initialize all variables
# Check whether a checkpoint file was saved previously.
ckpt = tf.train.get_checkpoint_state(logdir)
if ckpt and ckpt.model_checkpoint_path:
    try:
        # No argument: the Saver restores every graph variable (here w1, b1).
        saver = tf.train.Saver()
        saver.restore(sess, ckpt.model_checkpoint_path)
    except tf.errors.NotFoundError:
        # The checkpoint lacks some graph variables (it only holds w); restore
        # just the listed ones. (Was a bare `except:`, which hides real errors.)
        saver = tf.train.Saver([w1, b1])
        saver.restore(sess, ckpt.model_checkpoint_path)
saver = tf.train.Saver()  # no argument: save all variables (here w1, b1)
saver.save(sess, logdir + 'model.ckpt')
print('w', w1.eval())
print('-----------')
print('b', b1.eval())
print('-----------')
sess.close()
- 第二次运行
#!/usr/bin/python3
# -*- coding:utf-8 -*-
"""Second run: new variables (w2, b2) appear; restore what the checkpoint has."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
import numpy as np
from tensorflow.contrib.layers.python.layers import batch_norm
import argparse

logdir = './output/'
with tf.variable_scope('conv'):
    w1 = tf.get_variable('w', [2, 2], tf.float32, initializer=tf.random_normal_initializer)
    b1 = tf.get_variable('b', [2], tf.float32, initializer=tf.random_normal_initializer)
with tf.variable_scope('conv2'):
    w2 = tf.get_variable('w', [1, 2], tf.float32, initializer=tf.random_normal_initializer)
    b2 = tf.get_variable('b', [1], tf.float32, initializer=tf.random_normal_initializer)
sess = tf.InteractiveSession()
tf.global_variables_initializer().run()  # initialize all variables
# Check whether a checkpoint file was saved previously.
ckpt = tf.train.get_checkpoint_state(logdir)
if ckpt and ckpt.model_checkpoint_path:
    try:
        # No argument: try to restore every graph variable.
        saver = tf.train.Saver()
        saver.restore(sess, ckpt.model_checkpoint_path)
    except tf.errors.NotFoundError:
        # The previous run saved only w1/b1, so a full restore fails with
        # NotFoundError (a bare `except:` would also mask unrelated bugs).
        # Restore the subset; w2/b2 keep their fresh random initialization.
        saver = tf.train.Saver([w1, b1])
        saver.restore(sess, ckpt.model_checkpoint_path)
saver = tf.train.Saver()  # no argument: save all variables, including w2, b2
saver.save(sess, logdir + 'model.ckpt')
print('w', w1.eval())
print('-----------')
print('b', b1.eval())
print('-----------')
print('w', w2.eval())
print('-----------')
print('b', b2.eval())
sess.close()
- 第三次运行
#!/usr/bin/python3
# -*- coding:utf-8 -*-
"""Third run: deliberately restore only w1/b1 even though the checkpoint has all four."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
import numpy as np
from tensorflow.contrib.layers.python.layers import batch_norm
import argparse

logdir = './output/'
with tf.variable_scope('conv'):
    w1 = tf.get_variable('w', [2, 2], tf.float32, initializer=tf.random_normal_initializer)
    b1 = tf.get_variable('b', [2], tf.float32, initializer=tf.random_normal_initializer)
with tf.variable_scope('conv2'):
    w2 = tf.get_variable('w', [1, 2], tf.float32, initializer=tf.random_normal_initializer)
    b2 = tf.get_variable('b', [1], tf.float32, initializer=tf.random_normal_initializer)
sess = tf.InteractiveSession()
tf.global_variables_initializer().run()  # initialize all variables
# Check whether a checkpoint file was saved previously.
ckpt = tf.train.get_checkpoint_state(logdir)
if ckpt and ckpt.model_checkpoint_path:
    # The previous run saved w1, b1, w2, b2; here only w1/b1 are restored.
    # w2/b2 keep the fresh random values from global_variables_initializer
    # above — this is the transfer-learning pattern: reuse the first layers,
    # re-train the last.
    saver = tf.train.Saver([w1, b1])
    saver.restore(sess, ckpt.model_checkpoint_path)
saver = tf.train.Saver()  # no argument: save all variables, including w2, b2
saver.save(sess, logdir + 'model.ckpt')
print('w', w1.eval())
print('-----------')
print('b', b1.eval())
print('-----------')
print('w', w2.eval())
print('-----------')
print('b', b2.eval())
sess.close()
3、总结
- 保存变量
tf.global_variables_initializer().run() # initialize all variables
saver=tf.train.Saver() # no argument: save all variables by default
saver=tf.train.Saver([w,b]) # save only the listed variables
saver.save(sess,logdir+'model.ckpt')
- 导入变量
tf.global_variables_initializer().run()  # initialize all variables
# Check whether a checkpoint file was saved previously.
ckpt = tf.train.get_checkpoint_state(logdir)
if ckpt and ckpt.model_checkpoint_path:
    try:
        saver = tf.train.Saver()  # no argument: restore all variables
        saver.restore(sess, ckpt.model_checkpoint_path)
    except tf.errors.NotFoundError:
        # Checkpoint lacks some graph variables — restore only a subset.
        # (A bare `except:` would also hide unrelated errors.)
        saver = tf.train.Saver([w1, b1])
        saver.restore(sess, ckpt.model_checkpoint_path)
如果保存的变量有 w1,b1,w2,b2,但只导入 w1,b1,并对 w2,b2 重新初始化、训练等,使用
saver = tf.train.Saver([w1, b1])
如果保存的变量中只有w1,b1,现在新增变量w2,b2 则只能导入w1,b1
saver = tf.train.Saver([w1, b1])
通过这种方法可以实现,只导入模型的前n-1层参数,而对第n层参数重新初始化训练,这样就能很好的实现迁移学习
4、补充 结合variable_scope
#!/usr/bin/python3
# -*- coding:utf-8 -*-
"""Supplement: select variables for a Saver by their scoped names."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
import numpy as np
from tensorflow.contrib.layers.python.layers import batch_norm
import argparse

logdir = './output/'
with tf.variable_scope('conv'):
    w = tf.get_variable('w', [2, 2], tf.float32, initializer=tf.random_normal_initializer)
    b = tf.get_variable('b', [2], tf.float32, initializer=tf.random_normal_initializer)
sess = tf.InteractiveSession()
tf.global_variables_initializer().run()  # initialize all variables
# Check whether a checkpoint file was saved previously.
ckpt = tf.train.get_checkpoint_state(logdir)
if ckpt and ckpt.model_checkpoint_path:
    # BUG FIX: tf.variable_op_scope is a (deprecated) context-manager factory
    # for building ops — it is not a variable lookup, and poking at `.args[0]`
    # only echoes back the object passed in. Look the variables up by their
    # fully-qualified names instead.
    restore_vars = [v for v in tf.global_variables()
                    if v.name in ('conv/w:0', 'conv/b:0')]
    saver = tf.train.Saver(restore_vars)
    saver.restore(sess, ckpt.model_checkpoint_path)
saver = tf.train.Saver()  # no argument: save all variables
saver.save(sess, logdir + 'model.ckpt')
print('w', w.eval())
print('-----------')
print('b', b.eval())
print('-----------')
sess.close()
说明:
with tf.variable_scope('conv'):
w=tf.get_variable('w',[2,2],tf.float32,initializer=tf.random_normal_initializer)
b=tf.get_variable('b',[2],tf.float32,initializer=tf.random_normal_initializer)
with tf.variable_scope('conv2'):
w=tf.get_variable('w',[1,2],tf.float32,initializer=tf.random_normal_initializer)
b=tf.get_variable('b',[1],tf.float32,initializer=tf.random_normal_initializer)
print('w',w.eval()) # prints w from scope 'conv2' — the Python name was rebound
print('-----------')
print('b',b.eval())# prints b from scope 'conv2'
如果要打印的是 'conv' 中的 w,b
# Look the tensors up by their fully-qualified names; tf.variable_op_scope is
# an op-building context manager, not a lookup mechanism.
print('w', tf.get_default_graph().get_tensor_by_name('conv/w:0').eval())
print('-----------')
print('b', tf.get_default_graph().get_tensor_by_name('conv/b:0').eval())
打印 'conv2' 中的 w,b 也可以使用
# Same name-based lookup for the variables under scope 'conv2'.
print('w', tf.get_default_graph().get_tensor_by_name('conv2/w:0').eval())
print('-----------')
print('b', tf.get_default_graph().get_tensor_by_name('conv2/b:0').eval())