TensorFlow模型和权重的保存

因为大肥狼在使用保存的模型和权重进行预测时遇到了一些问题,所以本文将介绍如何在TensorFlow中保存模型和权重,并如何使用保存的模型和权重来进行预测。

1.代码

我们的代码主要是吴恩达作业第二单元第三周-----tensorflow入门这一部分。

这部分比较简单
主要是使用tf.train.Saver()来保存神经网络的网络结构图和相关变量。这里我新建了一个checkpoint文件夹用来保存模型和权重文件。
使用的代码如下

saver = tf.train.Saver()

Tensorflow变量的作用范围是在一个session里面。在保存模型的时候,应该在session里面通过save方法保存。其中sess是会话名称,./checkpoint/model.ckpt是你保存路径下模型名字。

saved_path = saver.save(sess,'./checkpoint/model.ckpt' )

完整代码

def model(X_train, Y_train, X_test, Y_test, learning_rate = 0.0001,
          num_epochs = 1500, minibatch_size = 32, print_cost = True):
    ops.reset_default_graph()                         # to be able to rerun the model without overwriting tf variables
    tf.set_random_seed(1)                             # to keep consistent results
    seed = 3                                          # to keep consistent results
    (n_x, m) = X_train.shape                          # (n_x: input size, m : number of examples in the train set)
    n_y = Y_train.shape[0]                            # n_y : output size
    costs = []                                        # To keep track of the cost
    
    # Create Placeholders of shape (n_x, n_y)
    X, Y = create_placeholders(n_x, n_y)
    
    # Initialize parameters
    parameters = initialize_parameters()
    
    # Forward propagation: Build the forward propagation in the tensorflow graph
    Z3 = forward_propagation(X, parameters)
    tf.add_to_collection('pred_network', Z3)
    # Cost function: Add cost function to tensorflow graph
    cost = compute_cost(Z3, Y)
    
    # Backpropagation: Define the tensorflow optimizer. Use an AdamOptimizer.
    optimizer = tf.train.AdamOptimizer(learning_rate = learning_rate).minimize(cost)
    
    # Initialize all the variables
    init = tf.global_variables_initializer()

    saver = tf.train.Saver()  
    
    # Start the session to compute the tensorflow graph
    with tf.Session() as sess:
        
        # Run the initialization
        sess.run(init)
        
        # Do the training loop
        for epoch in range(num_epochs):

            epoch_cost = 0.                       # Defines a cost related to an epoch
            num_minibatches = int(m / minibatch_size) # number of minibatches of size minibatch_size in the train set
            seed = seed + 1
            minibatches = random_mini_batches(X_train, Y_train, minibatch_size, seed)

            for minibatch in minibatches:

                # Select a minibatch
                (minibatch_X, minibatch_Y) = minibatch
                
                # IMPORTANT: The line that runs the graph on a minibatch.
                # Run the session to execute the "optimizer" and the "cost", the feedict should contain a minibatch for (X,Y).
                _ , minibatch_cost = sess.run([optimizer, cost], feed_dict={X: minibatch_X, Y: minibatch_Y})
                
                epoch_cost += minibatch_cost / num_minibatches

            # Print the cost every epoch
            if print_cost == True and epoch % 100 == 0:
                print ("Cost after epoch %i: %f" % (epoch, epoch_cost))
                
            if print_cost == True and epoch % 5 == 0:
                costs.append(epoch_cost)
        #保存模型
		saved_path = saver.save(sess,'./checkpoint/model.ckpt' ) 
        
        # plot the cost
        plt.plot(np.squeeze(costs))
        plt.ylabel('cost')
        plt.xlabel('iterations (per tens)')
        plt.title("Learning rate =" + str(learning_rate))
        plt.show()

        # lets save the parameters in a variable
        parameters = sess.run(parameters)
        print ("Parameters have been trained!")

        # Calculate the correct predictions
        correct_prediction = tf.equal(tf.argmax(Z3), tf.argmax(Y))

        # Calculate accuracy on the test set
        accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))

        print ("Train Accuracy:", accuracy.eval({X: X_train, Y: Y_train}))
        print ("Test Accuracy:", accuracy.eval({X: X_test, Y: Y_test}))
        
        return parameters

这里的例子比较复杂,包括你也可以设置每100代保存一次权重

保存完后,你的文件目录下的checkpoint 文件夹下会多出四个文件。如下图所示

CBOW模型为什么取权重 模型和权重_ci

model.meta为模型文件,model.data-00000-of-00001为数据文件。后面我们恢复模型需要用到。

使用模型和权重进行预测

接下来,我们讲预测时如何使用模型和权重进行预测。可以参考这个链接 ,写的很简洁。
1.导入保存好的模型结构,即model.meta文件

meta_path = './checkpoint/model.ckpt.meta'  
    saver = tf.train.import_meta_graph(meta_path) # 导入

2.加载的(变量)参数:
使用restore()方法恢复模型的变量参数。

data_path = './checkpoint/model.ckpt'  
    saver.restore(sess,data_path) # 导入变量值

也可以使用

data_path = './checkpoint/'  
    saver.restore(sess,tf.train.latest_checkpoint(data_path)) # 导入变量值

注意这里 data_path 里面的路径: ‘./checkpoint/model.ckpt’

3.使用graph.get_tensor_by_name()根据张量名称获取你想要的张量。因为我在训练的时候设置占位符时,将我的输入张量名字定为x,如下图所示。

def create_placeholders(n_x, n_y):
    X = tf.placeholder(tf.float32, shape=[n_x, None],name="x")
    Y = tf.placeholder(tf.float32, shape=[n_y, None],name="y")
    return X, Y

所以恢复输入张量

x=graph.get_operation_by_name('x').outputs[0]

3.因为前向传播过程中,没有为我们计算的输出y_hat设置张量。前向传播中,我们只计算到z3,所以我们用tf.add_to_collection() 来加载一下你预测要用的参数。这里我需要Z3,我就在训练代码中收集Z3,并给他命名为’pred_network’。

Z3 = forward_propagation(X, parameters)
    tf.add_to_collection('pred_network', Z3)

然后预测时恢复

z3=tf.get_collection("pred_network")[0]

然后在会话中计算z3

z3=sess.run(z3, feed_dict={x:my_image})

完整代码如下

import numpy as np
import tensorflow as tf
import scipy
from PIL import Image
from scipy import ndimage
import matplotlib.pyplot as plt
fname = "images/thumbs_up.jpg"
image = np.array(ndimage.imread(fname, flatten=False))
my_image = scipy.misc.imresize(image, size=(64,64)).reshape((1, 64*64*3)).T
with tf.Session() as sess:  
    meta_path = './checkpoint/model.ckpt.meta'  
    model_path = './checkpoint/model.ckpt'  
    saver = tf.train.import_meta_graph(meta_path) # 导入图  
    saver.restore(sess,model_path) # 导入变量值  
    graph = tf.get_default_graph()
    x=graph.get_operation_by_name('x').outputs[0]
    print(x)
    z3=tf.get_collection("pred_network")[0]
    z3=sess.run(z3, feed_dict={x:my_image})
    y=np.argmax(z3)
    print(str(np.squeeze(y)))

这样就完成预测啦。如果用jupyter notebook在预测的时候报没有数据喂入的错,可能重启一下jupyter就好了。