一、引言
经过一段时间tensorflow的学习,对全连接神经网络代码的框架有了一定的了解。通过dropout实验来解析代码结构。
二、代码释义
1、定义神经网络层
def add_layer_dropput(input, in_size, out_size,keep_prob = None,activation_function=None):
Weights = Variable(random_normal([out_size, in_size]))
Biases = Variable(zeros([out_size, 1]) + 0.1)
Wx_plus_b = matmul(Weights, input) + Biases
Wx_plus_b = nn.dropout(Wx_plus_b,keep_prob= keep_prob)
if activation_function == None:
output = Wx_plus_b
else:
output = activation_function(transpose(Wx_plus_b))
output = transpose(output)
return output
在其中定义一个神经网络层应有的权重Weights,偏置Biaes,输出Wx_plus_b,激活函数输出output。数据的格式是一列为一个样本,格式详见bp反向传播+3层全连接神经网络+softmax交叉熵损失+代码实现详解,相应的权重和偏置都会有所调整。至于output这里为什么会有转置,请转到Tensorflow解决训练误差大于测试误差查看。
3、定义神经网络类
class neural_network_3(object):
def __init__(self,in_size,hidden_size,out_size,activation_function1,activation_function2):
self.xs = placeholder(float32,[in_size,None])
self.ys = placeholder(float32,[out_size,None])
self.keep_prob = placeholder(float32)
#########define the neural_network_layer#########
self.l1 = add_layer_dropput(self.xs,in_size, hidden_size,keep_prob=self.keep_prob,activation_function=activation_function1)
self.prediction = add_layer_dropput(self.l1,hidden_size,out_size,keep_prob=self.keep_prob,activation_function=activation_function2)
#########define loss#########
self.loss = reduce_mean(-reduce_sum(mul(self.ys,log(self.prediction)),reduction_indices=[0]))#cross_entropy
scalar_summary('loss',self.loss)
########define train function#########
self.train_step = train.GradientDescentOptimizer(0.5).minimize(self.loss)
神经网络类,也就是一个完整的神经网络了,他是由多层神经网络层以及损失函数等组成。
self.xs和self.ys的格式是tensorflow的占位符形式。其中的None是为了适应不同的数据量,也就是样本的多少。self.keep_prob是dropout中的"保有率",也就是保持一定概率神经网络输出不变。self.l1以及self.prediction就是调用了add_layer_dropput()形成的实例。可见这个神经网络有三层,分别是输入层self.xs,隐含层self.l1以及输出层self.prediction。self.loss是计算的损失,reduce_mean是对所有样本求平均。scalar_summary的操作是将loss的数据"总结",以供tensorboard查看loss趋势。train.GradientDescentOptimizer(0.5).minimize(self.loss)是tensorflow自带的训练方法。免去了在bp反向传播+3层全连接神经网络+softmax交叉熵损失+代码实现详解中手动求梯度以及更新参数的麻烦,不得不说这个操作很强大!
3、删除文件夹(仅tensorboard可用)
def delet(path):
if os.path.exists(path): # 如果文件存在
shutil.rmtree(path)
else:
print('no such file') # 则返回文件不存在
因为含有我的主程序中对每次运行的文件通过以时间命名的方式保存,以方便比较效果。因此不免会有很多冗余文件,因此可以在主程序中调用这个函数来删除一些不用的文件。
4、主程序
from tensorflow import nn,Session,initialize_all_variables,train,merge_all_summaries
from sklearn.datasets import load_digits
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelBinarizer
from modular.neural_network_class import neural_network_3
from numpy import transpose
from datetime import datetime
from modular.delet import delet
#delet Previous folders
path='/home/xiaoshumiao/PycharmProjects/tensorflow/main/logs/'
delet(path)
#obtain now time
TIMESTAMP = "{0:%Y-%m-%d,%H-%M-%S/}".format(datetime.now())
#load data
digits = load_digits()
X = digits.data
y = digits.target
y = LabelBinarizer().fit_transform(y)#such as onhot
X_train,X_test,y_train,y_test = train_test_split(X,y,test_size=0.3)#split data
#add neural_layer
neural = neural_network_3(64,50,10,nn.tanh,nn.softmax)
#initialize variables
init = initialize_all_variables()
with Session() as sess:
# define the regard_value
merged = merge_all_summaries()
# define folder which will write in
train_writer = train.SummaryWriter('logs/train/'+TIMESTAMP,sess.graph)
test_writer = train.SummaryWriter('logs/test/'+TIMESTAMP,sess.graph)
# run initialize variables
sess.run(init)
for i in range(500):
# train
sess.run(neural.train_step, feed_dict={neural.xs: transpose(X_train), neural.ys: transpose(y_train),neural.keep_prob: 0.5})
# write in folder
if i % 50 == 0:
test_result = sess.run(merged, feed_dict={neural.xs: transpose(X_test), neural.ys: transpose(y_test),neural.keep_prob: 1})
train_result = sess.run(merged, feed_dict={neural.xs: transpose(X_train), neural.ys: transpose(y_train),neural.keep_prob: 1})
test_writer.add_summary(test_result,i)
train_writer.add_summary(train_result,i)
程序注释都有,需要注意的是,在训练时,我们选择的keep_prob为0.5,在实际测试的时候值为1。
最终效果如下:
1、无dropout
2、dropout
可以看到,在dropout作用下,过拟合现象明显减轻甚至消失。
May the Force be with you.