一 Tensorflow中数据的读取方式
在Tensorflow中,程序读取数据的方式一共有三种:
[1]供给数据读取方式(Feeding):在Tensorflow程序运行的每一步,利用Python代码来供给/提供数据.
[2]从文件读取数据:在Tensorflow图的开始,让一个输入管线从文件中读取相应的数据
[3]预加载数据:在Tensorflow图中定义常量或变量来保存所有的数据,仅适用于数据量比较小的情况
二 供给数据的读取方式:feed_dict参数
Tensorflow的【数据供给机制】,允许你在Tensorflow计算图中,将【数据】注入到任一【张量】中。因此,python运算可以把数据直接设置到Tensorflow的计算图。
通过run()或者eval()函数输入feed_dict参数,可以启动运算过程。如下面的代码所示:
def train(mnist):
#【1】定义tf.placeholder占位符,为Tensorflow的【数据供给机制】做准备
x = tf.placeholder(tf.float32, [None, mnist_inference.INPUT_NODE], name='x-input')
y_ = tf.placeholder(tf.float32, [None, mnist_inference.OUTPUT_NODE], name='y-input')
regularizer = tf.contrib.layers.l2_regularizer(REGULARIZATION_RATE)
y = mnist_inference.inference(x, regularizer)
global_step = tf.Variable(0, trainable=False)
variable_averages = tf.train.ExponentialMovingAverage(MOVING_AVERAGE_DECAY, global_step)
variables_averages_op = variable_averages.apply(tf.trainable_variables())
cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=tf.argmax(y_,1),logits=y)
cross_entropy_mean = tf.reduce_mean(cross_entropy)
loss = cross_entropy_mean + tf.add_n(tf.get_collection('losses'))
learning_rate = tf.train.exponential_decay(LEARNING_RATE_BASE,global_step,
mnist.train.num_examples / BATCH_SIZE,
LEARNING_RATE_DECAY,staircase=True)
train_step = tf.train.GradientDescentOptimizer(learning_rate).minimize(loss,
global_step=global_step)
with tf.control_dependencies([train_step, variables_averages_op]):
train_op = tf.no_op(name='train')
saver = tf.train.Saver()
with tf.Session() as sess:
tf.global_variables_initializer().run()
for i in range(TRAINING_STEPS):
#【2】将从批处理操作中读取到的数据保存爱xs和ys中
xs, ys = mnist.train.next_batch(BATCH_SIZE)
#【3】在Tensorflow中,将计算图中的数据注入到张量x和y_中,根据【Tensorflow的数据供给机制】而得
_, loss_value, step = sess.run([train_op, loss, global_step], feed_dict={x: xs, y_: ys})
if i % 1000 == 0:
print('After {0:d} training step(s), '
'loss on training batch is {1:g} '.format(step, loss_value))
saver.save(sess, os.path.join(MODEL_SAVE_PATH, MODEL_NAME), global_step=global_step)
#========================================================================================================
#模块说明:
# 主程序的入口点
#========================================================================================================
def main(argv=None):
print('[Info]TensorFlow_Version is',tf.__version__) #[1]定义处理MNIST数据集的类,这个类在初始化时会自动下载数据
mnist = input_data.read_data_sets('F:/MnistSet/',one_hot=True)
train(mnist) #[2]调用train函数进行模型的训练
#========================================================================================================
#模块说明:
# Tensorflow提供的一个主程序入口,tf.app.run函数将会调用上面的main函数
#========================================================================================================
if __name__ == '__main__':
tf.app.run()
虽然,我们可以使用【常量:tf.constant】和【变量:tf.Variable】来替换任何一个张量,但最好的做法应该是使用 tf.placeholder占位符节点。Tensorflow设计tf.placeholder节点的唯一的意图就是为了提供【数据的供给方法:Feeding】。tf.placeholder节点被声明的时候是未被初始化的,也不包含数据,如果没有为它提供数据,则Tensorflow进行运算的时候会产生错误,所以,千万不能忘了给tf.placeholder提供数据.
三 从文件读取数据
一个典型的文件读取管线,一般包含下面这些步骤:
1:文件名列表
2:可配置的文件名乱序(shuffling)
3:可配置的最大训练迭代数(Epoch Limit)
4:文件名队列
5:针对输入文件格式的阅读器
6:记录解析器
7:可配置的预处理器
8:样本队列
3.1 文件名列表,乱序,最大迭代训练数
可以使用字符串张量,比如["file0","file1"],[("file%d"%i) for in range(2)]或者tf.train.match_filenames_once函数来产生【文件名列表】。将【文件名列表】交给tf.train.string_input_producer函数,tf.train.string_input_producer函数来生成一个先入先出的队列,文件阅读器在需要时,借助它来读取数据。
tf.train.string_input_producer函数提供的【可配置参数】来【设置文件名乱序】和【最大的迭代数】,QueueRunner会为每次迭代(Epoch)将【所有的文件名】加入到【文件队列中】,如果shuffle=True的话,会对【文件名】进行【乱序处理】。这一过程是比较均匀的,因此,它可以产生【均衡的文件名队列】。
这个QueueRunner的工作线程是独立于文件阅读器的线程,因此,【乱序】和将文件名推入到【文件名队列】这些过程不会阻塞【文件阅读器】运行。
3.2 文件格式
根据你的【文件格式】,选择对应的【文件阅读器】,然后,将【文件名队列】提供给【阅读器】的read的方法。阅读器的read方法会输出一个【key】来表征输入的文件和其中的纪录,同时得到一个【字符串标量】,这个【字符串标量】可以被一个或多个解析器或者【转换操作】将其解码为【张量】并且【构造成样本】。
3.3 CSV文件的创建于读取
CSV文件是最常用的一个文件存储方式。逗号分隔值(Comma-Sepaeated Values,CSV)文件以纯文本形式存储表格数据。纯文本意味着改文件是一个字符序列,不包含必须向二进制数字那样被解读的数据。CSV文件由任意数目的【记录】组成,记录间以某种换行符分割;【每条记录】由【字段】组成,【字段间的分隔符】是其他字符或字符串,最常见的是逗号和制表符。通常,所有记录都有完全相同的字段序列。
3.3.1 CSV文件的创建
对于CSV文件的创建,Python语言有较好的方法对其进行实现,而这里只需要按需求对其格式进行整理即可。
再此块,Tensorflow的CSV文件读取主要用作对【所需要加载文件】的【地址】和【标签】进行【记录】。示例代码如下所示:
import os #[1]os模块是一个Python的系统编程的操作模块,可以处理【文件】和【目录】这些我们日常需要手动做的操作
path = 'jpg'
filename = os.listdir(path)#[2]列出【当前目录】下的【文件】和【文件夹】
strText = ""
with open("train_list.csv","w") as fid:
for a in range(len(filename)):
strText = path+os.sep+filename[a]+","+filename[a].split('_')[0]+"\n"
fid.write(strText)
fid.close()
3.3.2 CSV文件的读取
对应在Tensorflow中使用CSV文件,则需要使用特殊的CSV读取。这通常是为了读取硬盘上图片文件而使用的,方便Tensorflow框架在使用时能够一边读取图片一边对图片数据进行处理。这样做的好处是能够防止一次性读入过多的数据造成框架资源被耗尽。
对于从CSV文件中读取数据,需要使用TextLineReader和
#========================================================================================================
#模块说明:
# [1]tf.train.string_input_producer函数用来生成【一个先入先出的文件队列】,【文件阅读器】在需要的时候借助它
# 来读取数据.
# [2]tf.TextLineReader()实例化一个【文本行阅读器类】的【类对象】,这个类中对应的read函数,每一次读取一行内容,
# 并且使用tf.decode_csv函数来对读取的每一行内容进行解析
#========================================================================================================
filename_queue = tf.train.string_input_producer(["file0.csv", "file1.csv"])
reader = tf.TextLineReader()
key, value = reader.read(filename_queue)
record_defaults = [[1], [1], [1], [1], [1]]
col1, col2, col3, col4, col5 = tf.decode_csv(value, record_defaults=record_defaults)
features = tf.concat(0, [col1, col2, col3, col4])
with tf.Session() as sess:
# Start populating the filename queue.
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(coord=coord)
for i in range(1200):
# Retrieve a single instance:
example, label = sess.run([features, col5])
coord.request_stop()
coord.join(threads)
tf.TextLineReader类中的read函数,每次都会从文件中读取一行内容,decode_csv函数会解析这一行的内容并将其转化为张量列表。如果输入的参数有缺失,record_defaults参数可以根据张量的类型来设置默认值(这一功能非常方便).在调用run函数或者eval去执行read之前,我们必须调用tf.train.start_queue_runners来将文件名填充到队列。否则read操作会被阻塞到文件名队列中有值为止。
3.4 从二进制文件中读取固定长度
从二进制文件中读取固定长度的记录,可以使用tf.FixedLengthRecordReader的tf.decode_raw操作。decode_raw操作可以将【一个字符串】转化为一个uint8的张量.
举例来说,CIFAR-10的文件格式定义是:每条记录的长度都是固定的,一个字节的标签,后面的3072字节的图像数据。uint8的张量的标准操作就可以从中获取图像并且根据需要进行重组。
3.5 标准的Tensorflow数据格式
在Tensorflow中,另一种保存记录的方法,可以允许你将任意的数据转换为Tensorflow所支持的格式,这种方法可以使Tensorflow的数据集更容易与网络应用架构相匹配。这种建议的方法就是使用TFRecords问价,TFRecords文件包含了tf.train.Example协议内存块(protocol buffer)(协议内存块包含了字段Features)。你可以写一段代码获取你的数据,将数据填入到Example协议内存块中,将【协议内存块】序列化为一个【字符串】,并且通过tf.python_io.TRecordReader类写入到TFRecord文件张。
从TFRecordes文件中读取数据,可以使用tf.TFRecordReader的tf.parse_single_example解析器。这个parse_single_example操作可以将Example协议内存块解析为张量。MNIST的例子就使用了convert_to_records所构建的数据。
3.5.1 TFRecord格式介绍
TFRecord文件中的数据都是通过tf.train.Example Protocol Buffer的格式存储的。以下代码给出了tf.train.Example的定义。
message Example
{
Features features = 1;
};
message Features
{
map<string,Feature> feature = 1;
};
message Feature
{
oneof kind
{
BytesList bytes_list = 1;
FloatList float_list = 2;
Int64List int64_list = 3;
}
}
从以上代码中可以看出,tf.train.Example的数据结构是比较简单的。tf.train.Example中包含了一个从【属性名称】到【取值】的字典。其中【属性名称】为【一个字符串】,【属性的取值】可以为【字符串ByteList】、【实数列表FloatList】或者【 整数列表Int64List】。比如将【一张解码前的图像】存为【一个字符串】,【图像所对应的类别编号】存为【整数列表】。
3.5.2 TFRecord样例程序
此块,将给出一个具体的样例程序来读写TFRecord文件。下面的这个程序架构给出如何将MNIST输入数据转化为TFRecord的格式。
#========================================================================================================
#文件说明:
# 【1】此程序可以将MNSIT数据集中所有的训练数据存储到一个TFRecord文件中。当数据量较大时,也可以将数据写入多个TFRecord
# 文件。Tensorflow对从文件列表读取数据提供了很好的支持。
# 【2】将数据存储到TFRecord文件中
#========================================================================================================
import tensorflow as tf
import numpy as np
import tensorflow.examples.tutorials.mnist.input_data as input_data
#========================================================================================================
#函数说明:
# 将数据转换为一个属性,生成整数型的属性
#========================================================================================================
def _int64_feature(value):
return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))
#========================================================================================================
#函数说明:
# 将数据转换为一个属性,生成字符串型的属性
#========================================================================================================
def _bytes_feature(value):
return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
mnist = input_data.read_data_sets('F:/MnistSet/',one_hot=True)
filename = 'F:/MnistSet/output.tfrecords'
images = mnist.train.images
labels = mnist.train.labels #[1]训练数据所对应的正确答案,可以作为一个【属性】保存在TFRecord中
pixels = images.shape[1] #[2]训练数据的图像分辨率,这里可以作为Example中的一个属性
num_examples = mnist.train.num_examples
#[3]创建一个writer来写TFRecord文件
writer = tf.python_io.TFRecordWriter(filename)
for index in range(num_examples):
image_raw= images[index].tostring() #[1]将图像矩阵转化为一个字符串
#[2]将一个样例转化为Example Protocol Buffer,并将所有的信息写入这个数据结构
example = tf.train.Example(features=tf.train.Features(feature={
'pixels':_int64_feature(pixels),
'label' :_int64_feature(np.argmax(labels[index])),
'image_raw':_bytes_feature(image_raw)
}))
writer.write(example.SerializeToString())
writer.close()
#========================================================================================================
#文件说明:
# 从TFRecord文件中读取数据
#========================================================================================================
import tensorflow as tf
import numpy as np
import tensorflow.examples.tutorials.mnist.input_data as input_data
reader = tf.TFRecordReader() #[1]创建一个reader来读取TFRecord文件中的样例
#[2]创建一个队列来维护输入文件列表
filename_queue = tf.train.string_input_producer(['F:/MnistSet/output.tfrecords'])
#[3]从文件中读出一个样例,也可以使用read_up_to函数一次性读取多个样例
_,serialized_example= reader.read(filename_queue)
#[4]解析读入的一个样例,如果需要解析多个样例,可以使用parse_example函数
features = tf.parse_single_example(
serialized_example,
features={
'image_raw':tf.FixedLenFeature([],tf.string),
'pixels':tf.FixedLenFeature([],tf.int64),
'label':tf.FixedLenFeature([],tf.int64)
})
#[5]tf.decode_raw函数可以将字符串解析成图像对应的像素数组
images = tf.decode_raw(features['image_raw'],tf.uint8)
labels = tf.cast(features['label'],tf.int32)
pixels = tf.cast(features['pixels'],tf.int32)
sess = tf.Session()
coord= tf.train.Coordinator() #[6]启动多线程处理输入数据
threads= tf.train.start_queue_runners(sess=sess,coord=coord)
for i in range(10): #[7]每次运行,可以读取TFRecord文件中的一个样例,当所有的样例
# 读完之后,会重新从头读取
image,label,pixel = sess.run([images,labels,pixels])
print(image,label,pixel)