import tensorflow as tf
import tensorflow.keras as keras
import matplotlib.pyplot as plt
AUTOTUNE=tf.data.experimental.AUTOTUNE#用CPU动态设置并行调用的数量
import pathlib
#示例将从url上下载的数据进行处理
# data_root_orig=keras.utils.get_file(origin='https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz',
#                                          fname='flower_photos', untar=True)#utils.get_file是用于从网上下载数据并且解压

# data_root=pathlib.Path(data_root_orig)#data_root定义为Path类
data_root=pathlib.Path('flower_photos')#如果数据已经下载,就直接用此方法读取本地文件夹

for item in data_root.iterdir():#遍历data_root下的文件
    print(item)  #暑促和当前文件夹下的文件
import random
all_image_paths=list(data_root.glob('*/*'))  #遍历目录中子文件夹下的文件,如果改成*/就是遍历目录中的文件
all_image_paths=[str(path) for path in all_image_paths] #将子目录文件夹下面的文件字符串化并且放入列表中
random.shuffle(all_image_paths)#打乱顺序
image_count=len(all_image_paths)
print(image_count)#输出一共有多少张图片

#确定每张图片的标签
label_names=sorted(item.name for item in\
   data_root.glob('*/') if item.is_dir())#is_dir()检测指定的文件是否是一个目录,如果是则返回TRUE、否则返回FALSE
# 由于遍历返回的是一个生成器,sorted返回的是排序后的副本,并非原来的生成器数据
# is_dir()检测指定的文件是否是一个目录,如果是则返回TRUE、否则返回FALSE
print(label_names) #输出当前样本地标签名字
#为每个标签分配索引
label_to_index=dict((name,index)for index,name in\
        enumerate(label_names))
print(label_to_index) # 该字典已经经各个标签名字打好了对应的类号
# 根据当前图像所在目录的名字对应的字典值来赋予类别号
all_image_labels=[label_to_index[pathlib.Path(path).parent.name] for path in all_image_paths]
#--------------------------------------------------
#读取一张图像查看图像属性
img_path=all_image_paths[0]
print(img_path)
img_raw=tf.io.read_file(img_path)#读取一行图像
img_tensor=tf.image.decode_image(img_raw)#解析该图像
print(img_tensor.dtype) #输出图像的相关信息
print(img_tensor.shape)
img_final = tf.image.resize(img_tensor, [192, 192])#裁剪
img_final = img_final/255.0#归一化
print(img_final.shape) #归一化后输出相关信息
print(img_final.numpy().min())
print(img_final.numpy().max())
#-----------------------------------------------------
# tf.image.decode_image(),返回dtype类型的Tensor,对于BMP,JPEG和PNG图像其shape为[height, width, num_channels],
# 对于GIF图像,其shape为[num_frames, height, width, 3]。
# 知道图像格式的时候可以使用 tf.image.decode_jpeg或其他格式对应的API
def preprocess_image(image):
    image=tf.image.decode_jpeg(image,channels=3) # 解析图片返回相应的tensor
    image=tf.image.resize(image,[192,192]) #图像裁剪
    image/=255.0    # 归一化处理
    return image

def load_and_preprocess_image(path):#读取并预处理图像
    image=tf.io.read_file(path)
    return preprocess_image(image)

path_ds=tf.data.Dataset.from_tensor_slices(all_image_paths) #对所有的图片切片
image_ds=path_ds.map(load_and_preprocess_image,num_parallel_calls=AUTOTUNE)#利用load_and_preprocess_image对切片的数据预处理

plt.figure(figsize=(8,8))  #输出切片处理后的图像
#输出3张图像,因为此时数据类型为TakeDataSet,所以要用此方式来调取数据
# for n,image in enumerate(image_ds.take(3)):
#     plt.subplot(2,2,n+1)
#     plt.imshow(image)
#     plt.grid(False)
#     plt.xticks([])
#     plt.yticks([])
#     plt.xlabel((all_image_paths[n]))
#     plt.show()

label_ds=tf.data.Dataset.from_tensor_slices(tf.cast(all_image_labels,tf.int64))
#显示前10个标签
for label in label_ds.take(10):
    print(label_names[label])

#打包(图片,标签)
image_label_ds=tf.data.Dataset.zip((image_ds,label_ds))
print(image_label_ds)
# 也可以用切片替换tf.data.Dataset.zip
ds=tf.data.Dataset.from_tensor_slices((all_image_paths,all_image_labels))

# def load_and_preprocess_from_path_label(path,label):
#     return load_and_preprocess_image(path),label
# image_label_ds=ds.map(load_and_preprocess_from_path_label)
# print(image_label_ds)

# 将数据打乱,划分为BATCHSIZE大小
BATHC_SIZE=32
#在 .repeat 之后 .shuffle,会在 epoch 之间打乱数据(当有些数据出现两次的时候,其他数据还没有出现过)。
#在 .batch 之后 .shuffle,会打乱 batch 的顺序,但是不会在 batch 之间打乱数据。

ds=image_label_ds.shuffle(buffer_size=image_count)
ds=ds.repeat()
ds=ds.prefetch(buffer_size=AUTOTUNE)# prefetch从数据集中预取数据
print(ds)

#可以通过使用tf.data.Dataset.apply方法
# 和融合过的 tf.data.experimental.shuffle_and_repeat 函数来解决:
ds=image_label_ds.apply(tf.data.experimental.shuffle_and_repeat\
                            (buffer_size=image_count))
ds=ds.batch(BATHC_SIZE)
ds=ds.prefetch(buffer_size=AUTOTUNE)
print(ds)#和前面ds结果一样

#TFRecord文件
image_ds=tf.data.Dataset.from_tensor_slices(all_image_paths).map(tf.io.read_file)
tfrec=tf.data.experimental.TFRecordWriter('images.tfrec')#将图像数据写在images.tfrec中去
tfrec.write(image_ds)
image_ds=tf.data.TFRecordDataset('images.tfrec').map(preprocess_image)
ds=tf.data.Dataset.zip((image_ds,label_ds))
ds=ds.apply(tf.data.experimental.shuffle_and_repeat(buffer_size=image_count))
ds=ds.batch(BATHC_SIZE).prefetch(AUTOTUNE)
print(ds)
#利用tensor序列化加速
paths_ds=tf.data.Dataset.from_tensor_slices(all_image_paths)
image_ds=paths_ds.map(load_and_preprocess_image)
print(image_ds)#图像tensor化
# tensor序列化至TFRecord
ds=image_ds.map(tf.io.serialize_tensor)
print(ds)
tfrec=tf.data.experimental.TFRecordWriter('images.tfrec')
tfrec.write(ds)

ds=tf.data.TFRecordDataset('images.tfrec')
def parse(x):  # 对前面写入的数据进行解析
    result=tf.io.parse_tensor(x,out_type=tf.float32)
    result=tf.reshape(result,[192,192,3])
    return result
ds=ds.map(parse,num_parallel_calls=AUTOTUNE) # 解析数据
print(ds)
ds=tf.data.Dataset.zip((ds,label_ds))
ds=ds.apply(tf.data.experimental.shuffle_and_repeat(buffer_size=image_count))
ds=ds.batch(BATHC_SIZE).prefetch(AUTOTUNE)
print(ds)