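Table of contents
- 1. Import packages
- 2. Load the data
- 3. Visualize dataset samples
- 4. Split the dataset
- 5. Optimize the dataset
- 6. Data augmentation
- 7. Rescale pixel values
- 8. Build the base_model
- 9. Freeze
- 10. Add the classification head
- 11. Add the prediction layer
- 12. Chain the parts together
- 13. Compile the model
- 14. Train the model
- 15. Learning curves
- 16. Second approach: fine-tuning
- 17. Evaluation and prediction
- 18. Summary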
Transfer learning and fine-tuning
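- In this article, we will learn how to classify images of cats and dogs using transfer learning from a pre-trained network.
- A pre-trained model is a saved network that was previously trained on a large dataset. You can use the pre-trained model as is, or use transfer learning to customize it for a given task.
- The intuition behind transfer learning for image classification is that a model trained on a large and general enough dataset effectively serves as a generic model of the visual world. You can then reuse its learned feature maps instead of starting from scratch by training a large model on a large dataset.
We will try two ways to customize a pre-trained model:
1. Feature extraction: use the representations learned by a previous network to extract meaningful features from new samples. You simply add a new classifier, trained from scratch, on top of the pre-trained model, so that you can repurpose the feature maps it learned previously. You do not need to (re)train the entire model. The base convolutional network already contains features that are generically useful for classifying pictures. However, the final classification part of the pre-trained model is specific to the original classification task, and therefore to the set of classes on which the model was trained.
2. Fine-tuning: unfreeze a few of the top layers of the frozen model base and jointly train both the newly added classifier layers and the last layers of the base model. This lets us "fine-tune" the higher-order feature representations in the base model to make them more relevant to the specific task.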
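The feature-extraction idea can be illustrated without any deep learning library. The following is a minimal NumPy sketch, not the Keras code used later in this article: a fixed random projection (W_base, a hypothetical stand-in for the pre-trained network) stays frozen while only a new logistic-regression head (w_head, b_head) is trained on a toy task. All names, the toy data, and the learning rate are assumptions made for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical stand-in for a pre-trained base: a fixed projection followed by ReLU.
W_base = rng.normal(size=(8, 16))
W_base_before = W_base.copy()


def base_features(x):
    """Frozen feature extractor: its weights are never updated below."""
    return np.maximum(x @ W_base, 0.0)


# Toy binary task (an assumption): the label depends on the sign of the first input.
x = rng.normal(size=(64, 8))
y = (x[:, 0] > 0).astype(float)

# New classification head, trained from scratch: one logit per sample.
w_head = np.zeros(16)
b_head = 0.0


def loss_and_grads(w, b):
    feats = base_features(x)
    logits = feats @ w + b
    probs = 1.0 / (1.0 + np.exp(-logits))
    loss = -np.mean(y * np.log(probs + 1e-12) + (1 - y) * np.log(1 - probs + 1e-12))
    grad_logits = (probs - y) / len(y)
    return loss, feats.T @ grad_logits, grad_logits.sum()


initial_loss, _, _ = loss_and_grads(w_head, b_head)
for _ in range(500):
    _, grad_w, grad_b = loss_and_grads(w_head, b_head)
    # Only the head is updated; W_base is left untouched (frozen).
    w_head -= 0.01 * grad_w
    b_head -= 0.01 * grad_b
final_loss, _, _ = loss_and_grads(w_head, b_head)

print(f"loss before: {initial_loss:.3f}, after: {final_loss:.3f}")
print("base unchanged:", np.array_equal(W_base, W_base_before))
```

Fine-tuning would correspond to also applying small gradient updates to part of W_base after the head has been trained.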
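This article follows the general machine learning workflow:
- Examine and understand the data
- Build an input pipeline with Keras (the code here uses image_dataset_from_directory)
- Compose the model
  - Load the pre-trained base model (and its pre-trained weights)
  - Stack the classification layers on top
- Train the model
- Evaluate the model
1. Import packages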
import matplotlib.pyplot as plt
import numpy as np
import os
import tensorflow as tf
from tensorflow.keras.preprocessing import image_dataset_from_directory
2. Load the data
_URL = 'https://storage.googleapis.com/mledu-datasets/cats_and_dogs_filtered.zip'
path_to_zip = tf.keras.utils.get_file('cats_and_dogs.zip', origin=_URL, extract=True)
PATH = os.path.join(os.path.dirname(path_to_zip), 'cats_and_dogs_filtered')
train_dir = os.path.join(PATH, 'train')
validation_dir = os.path.join(PATH, 'validation')
BATCH_SIZE = 32
IMG_SIZE = (160, 160)
train_dataset = image_dataset_from_directory(
    train_dir,
    shuffle=True,
    batch_size=BATCH_SIZE,
    image_size=IMG_SIZE)
validation_dataset = image_dataset_from_directory(
    validation_dir,
    shuffle=True,
    batch_size=BATCH_SIZE,
    image_size=IMG_SIZE)
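image_dataset_from_directory infers the labels from the names of the subdirectories, sorted alphanumerically. The sketch below mimics that inference with the standard library only. The cats and dogs subfolders and the file name are assumptions about the extracted archive layout, and the helper infer_class_names is hypothetical, not part of Keras.

```python
import os
import tempfile


def infer_class_names(directory):
    """Return the sorted subdirectory names, one per class (hypothetical helper)."""
    return sorted(
        entry for entry in os.listdir(directory)
        if os.path.isdir(os.path.join(directory, entry))
    )


with tempfile.TemporaryDirectory() as root:
    # Assumed layout: one subfolder per class, each holding image files.
    for class_name in ("dogs", "cats"):
        os.makedirs(os.path.join(root, class_name))
        open(os.path.join(root, class_name, "example.jpg"), "w").close()
    class_names = infer_class_names(root)
    label_of = {name: index for index, name in enumerate(class_names)}

print(class_names)  # ['cats', 'dogs']
print(label_of)     # {'cats': 0, 'dogs': 1}
```

Under this layout, label 0 means cat and label 1 means dog, regardless of the order in which the folders were created.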
3. Visualize dataset samples
# Show the first nine images and labels from the training set.
class_names = train_dataset.class_names
plt.figure(figsize=(10, 10))
for images, labels in train_dataset.take(1):
    for i in range(9):
        ax = plt.subplot(3, 3, i + 1)
        plt.imshow(images[i].numpy().astype("uint8"))
        plt.title(class_names[labels[i]])
        plt.axis("off")
4. Split the dataset
# The original dataset doesn't contain a test set, so you will create one. Determine how many batches of data are available in the validation set using tf.data.experimental.cardinality, then move 20% of them to a test set.
val_batches = tf.data.experimental.cardinality(validation_dataset)  # number of batches in the validation set
test_dataset = validation_dataset.take(val_batches // 5)  # take 20% of the batches as the test set
validation_dataset = validation_dataset.skip(val_batches // 5)  # keep the rest as the validation set
print('Number of validation batches: %d' % tf.data.experimental.cardinality(validation_dataset))
print('Number of test batches: %d' % tf.data.experimental.cardinality(test_dataset))
Number of validation batches: 26
Number of test batches: 6
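The batch counts printed above follow from simple integer arithmetic. Here is a small sketch, assuming the validation directory holds 1000 images (an assumption, though it is consistent with the 26 + 6 = 32 batches reported) and the batch size of 32 used in this article.

```python
import math

num_validation_images = 1000  # assumption, consistent with 26 + 6 = 32 batches
batch_size = 32               # BATCH_SIZE from the article

# The last batch is only partially filled, hence the ceiling.
val_batches = math.ceil(num_validation_images / batch_size)
test_batches = val_batches // 5                      # what take(val_batches // 5) keeps
remaining_val_batches = val_batches - test_batches   # what skip(val_batches // 5) keeps

print(val_batches, test_batches, remaining_val_batches)  # 32 6 26
```

Because of the integer division, the test set receives 6 of 32 batches, slightly under 20%.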
5. Optimize the dataset
# Use buffered prefetching to load images from disk without I/O becoming blocking; prefetching improves input-pipeline performance.
AUTOTUNE = tf.data.experimental.AUTOTUNE
train_dataset = train_dataset.prefetch(buffer_size=AUTOTUNE)
validation_dataset = validation_dataset.prefetch(buffer_size=AUTOTUNE)
test_dataset = test_dataset.prefetch(buffer_size=AUTOTUNE)
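To make the idea of buffered prefetching concrete, here is a conceptual sketch using only the standard library. It is not how tf.data is implemented: the prefetch generator below is a hypothetical helper in which a background thread loads up to buffer_size items ahead while the consumer is busy with the current one.

```python
import queue
import threading


def prefetch(iterable, buffer_size):
    """Yield items from `iterable`, loading up to `buffer_size` ahead in a background thread."""
    buffer = queue.Queue(maxsize=buffer_size)
    sentinel = object()  # marks the end of the stream

    def producer():
        for item in iterable:
            buffer.put(item)  # blocks while the buffer is full
        buffer.put(sentinel)

    threading.Thread(target=producer, daemon=True).start()
    while True:
        item = buffer.get()  # blocks only if nothing has been prefetched yet
        if item is sentinel:
            return
        yield item


batches = list(prefetch(range(5), buffer_size=2))
print(batches)  # [0, 1, 2, 3, 4]
```

The order of the items is preserved; only the loading overlaps with consumption. With AUTOTUNE, tf.data chooses the buffer size at runtime instead of taking a fixed value.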
6. Data augmentation
data_augmentation = tf.keras.Sequential([
    tf.keras.layers.experimental.preprocessing.RandomFlip('horizontal'),
    tf.keras.layers.experimental.preprocessing.RandomRotation(0.2),
])
# Let's repeatedly apply these layers to the same image and see the result.
for image, _ in train_dataset.take(1):
    plt.figure(figsize=(10, 10))
    first_image = image[0]
    for i in range(9):
        ax = plt.subplot(3, 3, i + 1)
        augmented_image = data_augmentation(tf.expand_dims(first_image, 0))
        plt.imshow(augmented_image[0] / 255)
        plt.axis('off')
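What the two augmentation layers do can be sketched in plain NumPy. This is an illustration of the operations, not the Keras implementation: a horizontal flip reverses the width axis, and the factor 0.2 passed to RandomRotation is a fraction of a full turn, so angles are drawn from roughly -72 to +72 degrees. The tiny 2x3 image is an assumption chosen so the result is easy to inspect.

```python
import numpy as np

# A tiny 2x3 RGB "image" in (height, width, channels) layout.
image = np.arange(2 * 3 * 3).reshape(2, 3, 3)

# Horizontal flip: reverse the width axis (axis 1).
flipped = image[:, ::-1, :]

# RandomRotation(0.2) samples an angle in [-0.2 * 2*pi, 0.2 * 2*pi] radians.
max_rotation_degrees = 0.2 * 360  # about 72 degrees

print(image[0, :, 0])    # first row, red channel: [0 3 6]
print(flipped[0, :, 0])  # same row after the flip: [6 3 0]
```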
7. Rescale pixel values
# In a moment, you will download tf.keras.applications.MobileNetV2 for use as your base model. This model expects pixel values in [-1, 1], but at this point the pixel values in your images are in [0, 255].
# To rescale them, use the preprocessing method included with the model.
preprocess_input = tf.keras.applications.mobilenet_v2.preprocess_input
# Note: alternatively, you could rescale pixel values from [0, 255] to [-1, 1] using a Rescaling layer.
rescale = tf.keras.layers.experimental.preprocessing.Rescaling(1./127.5, offset=-1)
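The Rescaling layer above computes x * (1 / 127.5) - 1 for every pixel. A quick NumPy check of that formula on the two ends and the midpoint of the [0, 255] range (the three sample values are chosen for illustration):

```python
import numpy as np

pixels = np.array([0.0, 127.5, 255.0])

# Same arithmetic as Rescaling(1./127.5, offset=-1).
rescaled = pixels / 127.5 - 1.0

print(rescaled.tolist())  # [-1.0, 0.0, 1.0]
```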
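8. Build the base_model
We will create the base model from the MobileNet V2 model developed at Google.
It is pre-trained on ImageNet, a large dataset of 1.4 million images and 1000 classes.
ImageNet is a research training dataset with a wide variety of categories, such as jackfruit and syringe. This base of knowledge will help us classify cats and dogs from our specific dataset.
First, you need to pick which layer of MobileNet V2 to use for feature extraction. The very last classification layer (on "top", since most diagrams of machine learning models go from bottom to top) is not very useful.
Instead, you will follow the common practice of relying on the very last layer before the flatten operation. This layer is called the "bottleneck layer". Compared with the final/top layer, the features of the bottleneck layer retain more generality.
To begin, instantiate a MobileNet V2 model pre-loaded with weights trained on ImageNet.
By specifying the include_top=False argument, you load a network that does not include the classification layers at the top, which is ideal for feature extraction.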
# Create the base model from the pre-trained model MobileNet V2
IMG_SHAPE = IMG_SIZE + (3,)
base_model = tf.keras.applications.MobileNetV2(
    input_shape=IMG_SHAPE,
    include_top=False,
    weights='imagenet')
#This feature extractor converts each 160x160x3 image into a 5x5x1280 block of features. Let's see what it does to an example batch of images:
image_batch, label_batch = next(iter(train_dataset))
feature_batch = base_model(image_batch)
print(feature_batch.shape)
(32, 5, 5, 1280)
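The 5x5 spatial size in this shape can be derived by hand. MobileNet V2 halves the spatial resolution five times (an overall stride of 32), so a 160-pixel side shrinks to 160 / 32 = 5. A short sketch of that arithmetic, with the batch size and channel count taken from the output above:

```python
input_size = 160          # IMG_SIZE side length used in this article
num_stride2_stages = 5    # 160 -> 80 -> 40 -> 20 -> 10 -> 5

feature_map_size = input_size // (2 ** num_stride2_stages)
feature_shape = (32, feature_map_size, feature_map_size, 1280)

print(feature_shape)  # (32, 5, 5, 1280)
```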
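9. Freeze
- Feature extraction
In this step, you will freeze the convolutional base created in the previous step and use it as a feature extractor. You then add a classifier on top of it and train that top-level classifier.
- Freeze the convolutional base
It is important to freeze the convolutional base before you compile and train the model. Freezing (by setting layer.trainable = False) prevents the weights in a given layer from being updated during training. MobileNet V2 has many layers, so setting the entire model's trainable flag to False freezes all of them.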
base_model.trainable = False
# Let's take a look at the base model architecture
base_model.summary()
Model: "mobilenetv2_1.00_160"
__________________________________________________________________________________________________
Layer (type) Output Shape Param # Connected to
==================================================================================================
input_1 (InputLayer) [(None, 160, 160, 3) 0
__________________________________________________________________________________________________
Conv1 (Conv2D) (None, 80, 80, 32) 864 input_1[0][0]
__________________________________________________________________________________________________
bn_Conv1 (BatchNormalization) (None, 80, 80, 32) 128 Conv1[0][0]
__________________________________________________________________________________________________
Conv1_relu (ReLU) (None, 80, 80, 32) 0 bn_Conv1[0][0]
__________________________________________________________________________________________________
expanded_conv_depthwise (Depthw (None, 80, 80, 32) 288 Conv1_relu[0][0]
__________________________________________________________________________________________________
expanded_conv_depthwise_BN (Bat (None, 80, 80, 32) 128 expanded_conv_depthwise[0][0]
__________________________________________________________________________________________________
expanded_conv_depthwise_relu (R (None, 80, 80, 32) 0 expanded_conv_depthwise_BN[0][0]
__________________________________________________________________________________________________
expanded_conv_project (Conv2D) (None, 80, 80, 16) 512 expanded_conv_depthwise_relu[0][0
__________________________________________________________________________________________________
expanded_conv_project_BN (Batch (None, 80, 80, 16) 64 expanded_conv_project[0][0]
__________________________________________________________________________________________________
block_1_expand (Conv2D) (None, 80, 80, 96) 1536 expanded_conv_project_BN[0][0]
__________________________________________________________________________________________________
block_1_expand_BN (BatchNormali (None, 80, 80, 96) 384 block_1_expand[0][0]
__________________________________________________________________________________________________
block_1_expand_relu (ReLU) (None, 80, 80, 96) 0 block_1_expand_BN[0][0]
__________________________________________________________________________________________________
block_1_pad (ZeroPadding2D) (None, 81, 81, 96) 0 block_1_expand_relu[0][0]
__________________________________________________________________________________________________
block_1_depthwise (DepthwiseCon (None, 40, 40, 96) 864 block_1_pad[0][0]
__________________________________________________________________________________________________
block_1_depthwise_BN (BatchNorm (None, 40, 40, 96) 384 block_1_depthwise[0][0]
__________________________________________________________________________________________________
block_1_depthwise_relu (ReLU) (None, 40, 40, 96) 0 block_1_depthwise_BN[0][0]
__________________________________________________________________________________________________
block_1_project (Conv2D) (None, 40, 40, 24) 2304 block_1_depthwise_relu[0][0]
__________________________________________________________________________________________________
block_1_project_BN (BatchNormal (None, 40, 40, 24) 96 block_1_project[0][0]
__________________________________________________________________________________________________
block_2_expand (Conv2D) (None, 40, 40, 144) 3456 block_1_project_BN[0][0]
__________________________________________________________________________________________________
block_2_expand_BN (BatchNormali (None, 40, 40, 144) 576 block_2_expand[0][0]
__________________________________________________________________________________________________
block_2_expand_relu (ReLU) (None, 40, 40, 144) 0 block_2_expand_BN[0][0]
__________________________________________________________________________________________________
block_2_depthwise (DepthwiseCon (None, 40, 40, 144) 1296 block_2_expand_relu[0][0]
__________________________________________________________________________________________________
block_2_depthwise_BN (BatchNorm (None, 40, 40, 144) 576 block_2_depthwise[0][0]
__________________________________________________________________________________________________
block_2_depthwise_relu (ReLU) (None, 40, 40, 144) 0 block_2_depthwise_BN[0][0]
__________________________________________________________________________________________________
block_2_project (Conv2D) (None, 40, 40, 24) 3456 block_2_depthwise_relu[0][0]
__________________________________________________________________________________________________
block_2_project_BN (BatchNormal (None, 40, 40, 24) 96 block_2_project[0][0]
__________________________________________________________________________________________________
block_2_add (Add) (None, 40, 40, 24) 0 block_1_project_BN[0][0]
block_2_project_BN[0][0]
__________________________________________________________________________________________________
block_3_expand (Conv2D) (None, 40, 40, 144) 3456 block_2_add[0][0]
__________________________________________________________________________________________________
block_3_expand_BN (BatchNormali (None, 40, 40, 144) 576 block_3_expand[0][0]
__________________________________________________________________________________________________
block_3_expand_relu (ReLU) (None, 40, 40, 144) 0 block_3_expand_BN[0][0]
__________________________________________________________________________________________________
block_3_pad (ZeroPadding2D) (None, 41, 41, 144) 0 block_3_expand_relu[0][0]
__________________________________________________________________________________________________
block_3_depthwise (DepthwiseCon (None, 20, 20, 144) 1296 block_3_pad[0][0]
__________________________________________________________________________________________________
block_3_depthwise_BN (BatchNorm (None, 20, 20, 144) 576 block_3_depthwise[0][0]
__________________________________________________________________________________________________
block_3_depthwise_relu (ReLU) (None, 20, 20, 144) 0 block_3_depthwise_BN[0][0]
__________________________________________________________________________________________________
block_3_project (Conv2D) (None, 20, 20, 32) 4608 block_3_depthwise_relu[0][0]
__________________________________________________________________________________________________
block_3_project_BN (BatchNormal (None, 20, 20, 32) 128 block_3_project[0][0]
__________________________________________________________________________________________________
block_4_expand (Conv2D) (None, 20, 20, 192) 6144 block_3_project_BN[0][0]
__________________________________________________________________________________________________
block_4_expand_BN (BatchNormali (None, 20, 20, 192) 768 block_4_expand[0][0]
__________________________________________________________________________________________________
block_4_expand_relu (ReLU) (None, 20, 20, 192) 0 block_4_expand_BN[0][0]
__________________________________________________________________________________________________
block_4_depthwise (DepthwiseCon (None, 20, 20, 192) 1728 block_4_expand_relu[0][0]
__________________________________________________________________________________________________
block_4_depthwise_BN (BatchNorm (None, 20, 20, 192) 768 block_4_depthwise[0][0]
__________________________________________________________________________________________________
block_4_depthwise_relu (ReLU) (None, 20, 20, 192) 0 block_4_depthwise_BN[0][0]
__________________________________________________________________________________________________
block_4_project (Conv2D) (None, 20, 20, 32) 6144 block_4_depthwise_relu[0][0]
__________________________________________________________________________________________________
block_4_project_BN (BatchNormal (None, 20, 20, 32) 128 block_4_project[0][0]
__________________________________________________________________________________________________
block_4_add (Add) (None, 20, 20, 32) 0 block_3_project_BN[0][0]
block_4_project_BN[0][0]
__________________________________________________________________________________________________
block_5_expand (Conv2D) (None, 20, 20, 192) 6144 block_4_add[0][0]
__________________________________________________________________________________________________
block_5_expand_BN (BatchNormali (None, 20, 20, 192) 768 block_5_expand[0][0]
__________________________________________________________________________________________________
block_5_expand_relu (ReLU) (None, 20, 20, 192) 0 block_5_expand_BN[0][0]
__________________________________________________________________________________________________
block_5_depthwise (DepthwiseCon (None, 20, 20, 192) 1728 block_5_expand_relu[0][0]
__________________________________________________________________________________________________
block_5_depthwise_BN (BatchNorm (None, 20, 20, 192) 768 block_5_depthwise[0][0]
__________________________________________________________________________________________________
block_5_depthwise_relu (ReLU) (None, 20, 20, 192) 0 block_5_depthwise_BN[0][0]
__________________________________________________________________________________________________
block_5_project (Conv2D) (None, 20, 20, 32) 6144 block_5_depthwise_relu[0][0]
__________________________________________________________________________________________________
block_5_project_BN (BatchNormal (None, 20, 20, 32) 128 block_5_project[0][0]
__________________________________________________________________________________________________
block_5_add (Add) (None, 20, 20, 32) 0 block_4_add[0][0]
block_5_project_BN[0][0]
__________________________________________________________________________________________________
block_6_expand (Conv2D) (None, 20, 20, 192) 6144 block_5_add[0][0]
__________________________________________________________________________________________________
block_6_expand_BN (BatchNormali (None, 20, 20, 192) 768 block_6_expand[0][0]
__________________________________________________________________________________________________
block_6_expand_relu (ReLU) (None, 20, 20, 192) 0 block_6_expand_BN[0][0]
__________________________________________________________________________________________________
block_6_pad (ZeroPadding2D) (None, 21, 21, 192) 0 block_6_expand_relu[0][0]
__________________________________________________________________________________________________
block_6_depthwise (DepthwiseCon (None, 10, 10, 192) 1728 block_6_pad[0][0]
__________________________________________________________________________________________________
block_6_depthwise_BN (BatchNorm (None, 10, 10, 192) 768 block_6_depthwise[0][0]
__________________________________________________________________________________________________
block_6_depthwise_relu (ReLU) (None, 10, 10, 192) 0 block_6_depthwise_BN[0][0]
__________________________________________________________________________________________________
block_6_project (Conv2D) (None, 10, 10, 64) 12288 block_6_depthwise_relu[0][0]
__________________________________________________________________________________________________
block_6_project_BN (BatchNormal (None, 10, 10, 64) 256 block_6_project[0][0]
__________________________________________________________________________________________________
block_7_expand (Conv2D) (None, 10, 10, 384) 24576 block_6_project_BN[0][0]
__________________________________________________________________________________________________
block_7_expand_BN (BatchNormali (None, 10, 10, 384) 1536 block_7_expand[0][0]
__________________________________________________________________________________________________
block_7_expand_relu (ReLU) (None, 10, 10, 384) 0 block_7_expand_BN[0][0]
__________________________________________________________________________________________________
block_7_depthwise (DepthwiseCon (None, 10, 10, 384) 3456 block_7_expand_relu[0][0]
__________________________________________________________________________________________________
block_7_depthwise_BN (BatchNorm (None, 10, 10, 384) 1536 block_7_depthwise[0][0]
__________________________________________________________________________________________________
block_7_depthwise_relu (ReLU) (None, 10, 10, 384) 0 block_7_depthwise_BN[0][0]
__________________________________________________________________________________________________
block_7_project (Conv2D) (None, 10, 10, 64) 24576 block_7_depthwise_relu[0][0]
__________________________________________________________________________________________________
block_7_project_BN (BatchNormal (None, 10, 10, 64) 256 block_7_project[0][0]
__________________________________________________________________________________________________
block_7_add (Add) (None, 10, 10, 64) 0 block_6_project_BN[0][0]
block_7_project_BN[0][0]
__________________________________________________________________________________________________
block_8_expand (Conv2D) (None, 10, 10, 384) 24576 block_7_add[0][0]
__________________________________________________________________________________________________
block_8_expand_BN (BatchNormali (None, 10, 10, 384) 1536 block_8_expand[0][0]
__________________________________________________________________________________________________
block_8_expand_relu (ReLU) (None, 10, 10, 384) 0 block_8_expand_BN[0][0]
__________________________________________________________________________________________________
block_8_depthwise (DepthwiseCon (None, 10, 10, 384) 3456 block_8_expand_relu[0][0]
__________________________________________________________________________________________________
block_8_depthwise_BN (BatchNorm (None, 10, 10, 384) 1536 block_8_depthwise[0][0]
__________________________________________________________________________________________________
block_8_depthwise_relu (ReLU) (None, 10, 10, 384) 0 block_8_depthwise_BN[0][0]
__________________________________________________________________________________________________
block_8_project (Conv2D) (None, 10, 10, 64) 24576 block_8_depthwise_relu[0][0]
__________________________________________________________________________________________________
block_8_project_BN (BatchNormal (None, 10, 10, 64) 256 block_8_project[0][0]
__________________________________________________________________________________________________
block_8_add (Add) (None, 10, 10, 64) 0 block_7_add[0][0]
block_8_project_BN[0][0]
__________________________________________________________________________________________________
block_9_expand (Conv2D) (None, 10, 10, 384) 24576 block_8_add[0][0]
__________________________________________________________________________________________________
block_9_expand_BN (BatchNormali (None, 10, 10, 384) 1536 block_9_expand[0][0]
__________________________________________________________________________________________________
block_9_expand_relu (ReLU) (None, 10, 10, 384) 0 block_9_expand_BN[0][0]
__________________________________________________________________________________________________
block_9_depthwise (DepthwiseCon (None, 10, 10, 384) 3456 block_9_expand_relu[0][0]
__________________________________________________________________________________________________
block_9_depthwise_BN (BatchNorm (None, 10, 10, 384) 1536 block_9_depthwise[0][0]
__________________________________________________________________________________________________
block_9_depthwise_relu (ReLU) (None, 10, 10, 384) 0 block_9_depthwise_BN[0][0]
__________________________________________________________________________________________________
block_9_project (Conv2D) (None, 10, 10, 64) 24576 block_9_depthwise_relu[0][0]
__________________________________________________________________________________________________
block_9_project_BN (BatchNormal (None, 10, 10, 64) 256 block_9_project[0][0]
__________________________________________________________________________________________________
block_9_add (Add) (None, 10, 10, 64) 0 block_8_add[0][0]
block_9_project_BN[0][0]
__________________________________________________________________________________________________
block_10_expand (Conv2D) (None, 10, 10, 384) 24576 block_9_add[0][0]
__________________________________________________________________________________________________
block_10_expand_BN (BatchNormal (None, 10, 10, 384) 1536 block_10_expand[0][0]
__________________________________________________________________________________________________
block_10_expand_relu (ReLU) (None, 10, 10, 384) 0 block_10_expand_BN[0][0]
__________________________________________________________________________________________________
block_10_depthwise (DepthwiseCo (None, 10, 10, 384) 3456 block_10_expand_relu[0][0]
__________________________________________________________________________________________________
block_10_depthwise_BN (BatchNor (None, 10, 10, 384) 1536 block_10_depthwise[0][0]
__________________________________________________________________________________________________
block_10_depthwise_relu (ReLU) (None, 10, 10, 384) 0 block_10_depthwise_BN[0][0]
__________________________________________________________________________________________________
block_10_project (Conv2D) (None, 10, 10, 96) 36864 block_10_depthwise_relu[0][0]
__________________________________________________________________________________________________
block_10_project_BN (BatchNorma (None, 10, 10, 96) 384 block_10_project[0][0]
__________________________________________________________________________________________________
block_11_expand (Conv2D) (None, 10, 10, 576) 55296 block_10_project_BN[0][0]
__________________________________________________________________________________________________
block_11_expand_BN (BatchNormal (None, 10, 10, 576) 2304 block_11_expand[0][0]
__________________________________________________________________________________________________
block_11_expand_relu (ReLU) (None, 10, 10, 576) 0 block_11_expand_BN[0][0]
__________________________________________________________________________________________________
block_11_depthwise (DepthwiseCo (None, 10, 10, 576) 5184 block_11_expand_relu[0][0]
__________________________________________________________________________________________________
block_11_depthwise_BN (BatchNor (None, 10, 10, 576) 2304 block_11_depthwise[0][0]
__________________________________________________________________________________________________
block_11_depthwise_relu (ReLU) (None, 10, 10, 576) 0 block_11_depthwise_BN[0][0]
__________________________________________________________________________________________________
block_11_project (Conv2D) (None, 10, 10, 96) 55296 block_11_depthwise_relu[0][0]
__________________________________________________________________________________________________
block_11_project_BN (BatchNorma (None, 10, 10, 96) 384 block_11_project[0][0]
__________________________________________________________________________________________________
block_11_add (Add) (None, 10, 10, 96) 0 block_10_project_BN[0][0]
block_11_project_BN[0][0]
__________________________________________________________________________________________________
block_12_expand (Conv2D) (None, 10, 10, 576) 55296 block_11_add[0][0]
__________________________________________________________________________________________________
block_12_expand_BN (BatchNormal (None, 10, 10, 576) 2304 block_12_expand[0][0]
__________________________________________________________________________________________________
block_12_expand_relu (ReLU) (None, 10, 10, 576) 0 block_12_expand_BN[0][0]
__________________________________________________________________________________________________
block_12_depthwise (DepthwiseCo (None, 10, 10, 576) 5184 block_12_expand_relu[0][0]
__________________________________________________________________________________________________
block_12_depthwise_BN (BatchNor (None, 10, 10, 576) 2304 block_12_depthwise[0][0]
__________________________________________________________________________________________________
block_12_depthwise_relu (ReLU) (None, 10, 10, 576) 0 block_12_depthwise_BN[0][0]
__________________________________________________________________________________________________
block_12_project (Conv2D) (None, 10, 10, 96) 55296 block_12_depthwise_relu[0][0]
__________________________________________________________________________________________________
block_12_project_BN (BatchNorma (None, 10, 10, 96) 384 block_12_project[0][0]
__________________________________________________________________________________________________
block_12_add (Add) (None, 10, 10, 96) 0 block_11_add[0][0]
block_12_project_BN[0][0]
__________________________________________________________________________________________________
block_13_expand (Conv2D) (None, 10, 10, 576) 55296 block_12_add[0][0]
__________________________________________________________________________________________________
block_13_expand_BN (BatchNormal (None, 10, 10, 576) 2304 block_13_expand[0][0]
__________________________________________________________________________________________________
block_13_expand_relu (ReLU) (None, 10, 10, 576) 0 block_13_expand_BN[0][0]
__________________________________________________________________________________________________
block_13_pad (ZeroPadding2D) (None, 11, 11, 576) 0 block_13_expand_relu[0][0]
__________________________________________________________________________________________________
block_13_depthwise (DepthwiseCo (None, 5, 5, 576) 5184 block_13_pad[0][0]
__________________________________________________________________________________________________
block_13_depthwise_BN (BatchNor (None, 5, 5, 576) 2304 block_13_depthwise[0][0]
__________________________________________________________________________________________________
block_13_depthwise_relu (ReLU) (None, 5, 5, 576) 0 block_13_depthwise_BN[0][0]
__________________________________________________________________________________________________
block_13_project (Conv2D) (None, 5, 5, 160) 92160 block_13_depthwise_relu[0][0]
__________________________________________________________________________________________________
block_13_project_BN (BatchNorma (None, 5, 5, 160) 640 block_13_project[0][0]
__________________________________________________________________________________________________
block_14_expand (Conv2D) (None, 5, 5, 960) 153600 block_13_project_BN[0][0]
__________________________________________________________________________________________________
block_14_expand_BN (BatchNormal (None, 5, 5, 960) 3840 block_14_expand[0][0]
__________________________________________________________________________________________________
block_14_expand_relu (ReLU) (None, 5, 5, 960) 0 block_14_expand_BN[0][0]
__________________________________________________________________________________________________
block_14_depthwise (DepthwiseCo (None, 5, 5, 960) 8640 block_14_expand_relu[0][0]
__________________________________________________________________________________________________
block_14_depthwise_BN (BatchNor (None, 5, 5, 960) 3840 block_14_depthwise[0][0]
__________________________________________________________________________________________________
block_14_depthwise_relu (ReLU) (None, 5, 5, 960) 0 block_14_depthwise_BN[0][0]
__________________________________________________________________________________________________
block_14_project (Conv2D) (None, 5, 5, 160) 153600 block_14_depthwise_relu[0][0]
__________________________________________________________________________________________________
block_14_project_BN (BatchNorma (None, 5, 5, 160) 640 block_14_project[0][0]
__________________________________________________________________________________________________
block_14_add (Add) (None, 5, 5, 160) 0 block_13_project_BN[0][0]
block_14_project_BN[0][0]
__________________________________________________________________________________________________
block_15_expand (Conv2D) (None, 5, 5, 960) 153600 block_14_add[0][0]
__________________________________________________________________________________________________
block_15_expand_BN (BatchNormal (None, 5, 5, 960) 3840 block_15_expand[0][0]
__________________________________________________________________________________________________
block_15_expand_relu (ReLU) (None, 5, 5, 960) 0 block_15_expand_BN[0][0]
__________________________________________________________________________________________________
block_15_depthwise (DepthwiseCo (None, 5, 5, 960) 8640 block_15_expand_relu[0][0]
__________________________________________________________________________________________________
block_15_depthwise_BN (BatchNor (None, 5, 5, 960) 3840 block_15_depthwise[0][0]
__________________________________________________________________________________________________
block_15_depthwise_relu (ReLU) (None, 5, 5, 960) 0 block_15_depthwise_BN[0][0]
__________________________________________________________________________________________________
block_15_project (Conv2D) (None, 5, 5, 160) 153600 block_15_depthwise_relu[0][0]
__________________________________________________________________________________________________
block_15_project_BN (BatchNorma (None, 5, 5, 160) 640 block_15_project[0][0]
__________________________________________________________________________________________________
block_15_add (Add) (None, 5, 5, 160) 0 block_14_add[0][0]
block_15_project_BN[0][0]
__________________________________________________________________________________________________
block_16_expand (Conv2D) (None, 5, 5, 960) 153600 block_15_add[0][0]
__________________________________________________________________________________________________
block_16_expand_BN (BatchNormal (None, 5, 5, 960) 3840 block_16_expand[0][0]
__________________________________________________________________________________________________
block_16_expand_relu (ReLU) (None, 5, 5, 960) 0 block_16_expand_BN[0][0]
__________________________________________________________________________________________________
block_16_depthwise (DepthwiseCo (None, 5, 5, 960) 8640 block_16_expand_relu[0][0]
__________________________________________________________________________________________________
block_16_depthwise_BN (BatchNor (None, 5, 5, 960) 3840 block_16_depthwise[0][0]
__________________________________________________________________________________________________
block_16_depthwise_relu (ReLU) (None, 5, 5, 960) 0 block_16_depthwise_BN[0][0]
__________________________________________________________________________________________________
block_16_project (Conv2D) (None, 5, 5, 320) 307200 block_16_depthwise_relu[0][0]
__________________________________________________________________________________________________
block_16_project_BN (BatchNorma (None, 5, 5, 320) 1280 block_16_project[0][0]
__________________________________________________________________________________________________
Conv_1 (Conv2D) (None, 5, 5, 1280) 409600 block_16_project_BN[0][0]
__________________________________________________________________________________________________
Conv_1_bn (BatchNormalization) (None, 5, 5, 1280) 5120 Conv_1[0][0]
__________________________________________________________________________________________________
out_relu (ReLU) (None, 5, 5, 1280) 0 Conv_1_bn[0][0]
==================================================================================================
Total params: 2,257,984
Trainable params: 0
Non-trainable params: 2,257,984
__________________________________________________________________________________________________
10. Add the classification head
#To generate predictions from the block of features, average over the 5x5 spatial locations, using a tf.keras.layers.GlobalAveragePooling2D layer to convert the features to a single 1280-element vector per image.
global_average_layer = tf.keras.layers.GlobalAveragePooling2D()
feature_batch_average = global_average_layer(feature_batch)
print(feature_batch_average.shape)
(32, 1280)
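As a rough illustration of what the pooling step computes, the sketch below averages a tiny feature map over its height and width in plain Python. The helper `global_average_pool` and the toy 2 x 2 x 3 input are made up for this example; the real layer works on the 5 x 5 x 1280 tensors shown above.

```python
def global_average_pool(feature_map):
    """Average an H x W x C nested list over its H and W axes.

    Returns a list with one mean per channel.
    """
    height = len(feature_map)
    width = len(feature_map[0])
    channels = len(feature_map[0][0])
    sums = [0.0] * channels
    for row in feature_map:
        for pixel in row:
            for c, value in enumerate(pixel):
                sums[c] += value
    return [s / (height * width) for s in sums]


# A toy 2 x 2 feature map with 3 channels (the tutorial's is 5 x 5 x 1280).
toy_features = [
    [[1.0, 10.0, 0.0], [2.0, 20.0, 0.0]],
    [[3.0, 30.0, 0.0], [4.0, 40.0, 4.0]],
]
pooled = global_average_pool(toy_features)
print(pooled)  # [2.5, 25.0, 1.0]
```

Each channel collapses to a single number, which is why a (32, 5, 5, 1280) batch becomes (32, 1280).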
11. Add the prediction layer
#Apply a tf.keras.layers.Dense layer to convert these features into a single prediction per image.
#You don't need an activation function here because this prediction will be treated as a logit, or a raw prediction value.
#Positive numbers predict class 1, negative numbers predict class 0.
prediction_layer = tf.keras.layers.Dense(1)
prediction_batch = prediction_layer(feature_batch_average)
print(prediction_batch.shape)
(32, 1)
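To make the logit convention concrete, here is a minimal sketch in plain Python. The helpers `sigmoid` and `predict_class` are illustrative names, not part of the tutorial's code. Sending a logit of exactly 0 to class 1 is an assumption, chosen to match a "probability below 0.5 means class 0" rule.

```python
import math


def sigmoid(logit):
    """Map a raw logit to a probability in (0, 1)."""
    return 1.0 / (1.0 + math.exp(-logit))


def predict_class(logit):
    """Return 1 for a non-negative logit, otherwise 0 (ties go to class 1)."""
    return 1 if logit >= 0.0 else 0


for logit in (-2.0, 0.5, 3.0):
    prob = sigmoid(logit)
    # Thresholding the probability at 0.5 agrees with checking the logit's sign.
    assert (prob >= 0.5) == (predict_class(logit) == 1)
    print(f"logit={logit:+.1f}  p(class 1)={prob:.3f}  class={predict_class(logit)}")
```

Because the sigmoid is monotonic and maps 0 to 0.5, the Dense layer does not need an activation for the class decision itself.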
12. Chain the parts together
#Build a model by chaining together the data augmentation, rescaling, base_model and feature extractor layers using the Keras Functional API.
#As previously mentioned, use training=False as our model contains a BatchNormalization layer.
inputs = tf.keras.Input(shape=(160, 160, 3))
x = data_augmentation(inputs)
x = preprocess_input(x)
x = base_model(x, training=False)
x = global_average_layer(x)
x = tf.keras.layers.Dropout(0.2)(x)
outputs = prediction_layer(x)
model = tf.keras.Model(inputs, outputs)
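The Dropout(0.2) layer in this chain is only active during training. The sketch below shows the usual "inverted dropout" behaviour in plain Python, under the assumption of a hypothetical helper `dropout` and a fixed random seed. It is an illustration of the idea, not the Keras implementation.

```python
import random


def dropout(values, rate, training, rng):
    """Inverted dropout: zero each value with probability `rate` while training.

    Kept values are scaled by 1 / (1 - rate) so the expected sum is unchanged.
    At inference time (training=False) the input passes through untouched.
    """
    if not training:
        return list(values)
    keep_scale = 1.0 / (1.0 - rate)
    return [0.0 if rng.random() < rate else v * keep_scale for v in values]


rng = random.Random(0)  # fixed seed so the sketch is repeatable
features = [1.0] * 10

# While training, each entry is either 0.0 or 1.0 / 0.8 = 1.25.
print(dropout(features, rate=0.2, training=True, rng=rng))
# At inference time the features are returned unchanged.
print(dropout(features, rate=0.2, training=False, rng=rng))
```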
13. Compile the model
#Compile the model before training it.
#Because there are two classes and the model outputs a raw logit (a linear output), use a binary cross-entropy loss with from_logits=True.
base_learning_rate = 0.0001
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=base_learning_rate),
              loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),
              metrics=['accuracy'])
model.summary()
Model: "model"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
input_2 (InputLayer) [(None, 160, 160, 3)] 0
_________________________________________________________________
sequential (Sequential) (None, 160, 160, 3) 0
_________________________________________________________________
tf.math.truediv (TFOpLambda) (None, 160, 160, 3) 0
_________________________________________________________________
tf.math.subtract (TFOpLambda (None, 160, 160, 3) 0
_________________________________________________________________
mobilenetv2_1.00_160 (Functi (None, 5, 5, 1280) 2257984
_________________________________________________________________
global_average_pooling2d (Gl (None, 1280) 0
_________________________________________________________________
dropout (Dropout) (None, 1280) 0
_________________________________________________________________
dense (Dense) (None, 1) 1281
=================================================================
Total params: 2,259,265
Trainable params: 1,281
Non-trainable params: 2,257,984
_________________________________________________________________
The roughly 2.26M parameters in MobileNet V2 are frozen, but the Dense layer has 1,281 trainable parameters (about 1.3K). These are split between two tf.Variable objects: the weights and the bias.
len(model.trainable_variables)
2
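The count of 1,281 trainable parameters follows directly from the shape of the Dense layer. As a small worked example (the helper `dense_param_count` is a made-up name for this sketch):

```python
def dense_param_count(input_features, units):
    """Count parameters of a fully connected layer with a bias term."""
    weights = input_features * units  # one weight per (input, unit) pair
    biases = units                    # one bias per unit
    return weights, biases


weights, biases = dense_param_count(input_features=1280, units=1)
print(weights, biases, weights + biases)  # 1280 1 1281
```

The two values correspond to the two trainable variables reported above: a 1280 x 1 weight matrix and a single bias.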
14. Train the model
initial_epochs = 10
loss0, accuracy0 = model.evaluate(validation_dataset)
26/26 [==============================] - 7s 199ms/step - loss: 0.9126 - accuracy: 0.3857
print("initial loss: {:.2f}".format(loss0))
print("initial accuracy: {:.2f}".format(accuracy0))
initial loss: 0.91
initial accuracy: 0.38
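For a feel of the loss scale, the sketch below computes binary cross-entropy for a single example directly from a logit. It uses the numerically stable form documented for tf.nn.sigmoid_cross_entropy_with_logits. The function name `bce_from_logit` is an assumption made for this example. An uninformative logit of 0 costs ln 2, about 0.69, so an initial loss of 0.91 is consistent with a head that has not learned anything useful yet.

```python
import math


def bce_from_logit(logit, label):
    """Binary cross-entropy for one example, computed directly from the logit.

    Uses the numerically stable form
        max(z, 0) - z * y + log(1 + exp(-|z|)).
    """
    return max(logit, 0.0) - logit * label + math.log1p(math.exp(-abs(logit)))


# A logit of 0 means "50/50", which costs ln(2) regardless of the label.
print(round(bce_from_logit(0.0, 1), 4))   # 0.6931
# A confident and correct prediction is cheap ...
print(round(bce_from_logit(4.0, 1), 4))   # 0.0181
# ... while a confident but wrong one is expensive.
print(round(bce_from_logit(4.0, 0), 4))   # 4.0181
```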
history = model.fit(train_dataset,
epochs=initial_epochs,
validation_data=validation_dataset)
Epoch 1/10
63/63 [==============================] - 22s 306ms/step - loss: 0.7643 - accuracy: 0.5285 - val_loss: 0.6197 - val_accuracy: 0.6324
Epoch 2/10
63/63 [==============================] - 19s 301ms/step - loss: 0.5608 - accuracy: 0.6780 - val_loss: 0.4556 - val_accuracy: 0.7710
Epoch 3/10
63/63 [==============================] - 19s 305ms/step - loss: 0.4410 - accuracy: 0.7810 - val_loss: 0.3487 - val_accuracy: 0.8428
Epoch 4/10
63/63 [==============================] - 19s 300ms/step - loss: 0.3661 - accuracy: 0.8270 - val_loss: 0.2757 - val_accuracy: 0.8948
Epoch 5/10
63/63 [==============================] - 19s 297ms/step - loss: 0.3097 - accuracy: 0.8680 - val_loss: 0.2295 - val_accuracy: 0.9134
Epoch 6/10
63/63 [==============================] - 19s 298ms/step - loss: 0.2851 - accuracy: 0.8795 - val_loss: 0.2078 - val_accuracy: 0.9257
Epoch 7/10
63/63 [==============================] - 21s 329ms/step - loss: 0.2658 - accuracy: 0.8865 - val_loss: 0.1758 - val_accuracy: 0.9418
Epoch 8/10
63/63 [==============================] - 19s 304ms/step - loss: 0.2432 - accuracy: 0.8990 - val_loss: 0.1697 - val_accuracy: 0.9369
Epoch 9/10
63/63 [==============================] - 19s 302ms/step - loss: 0.2305 - accuracy: 0.9025 - val_loss: 0.1516 - val_accuracy: 0.9468
Epoch 10/10
63/63 [==============================] - 20s 307ms/step - loss: 0.2104 - accuracy: 0.9170 - val_loss: 0.1451 - val_accuracy: 0.9493
15. Learning curves
#Let's take a look at the learning curves of the training and validation accuracy/loss when using the MobileNet V2 base model as a fixed feature extractor.
acc = history.history['accuracy']
val_acc = history.history['val_accuracy']
loss = history.history['loss']
val_loss = history.history['val_loss']
plt.figure(figsize=(8, 8))
plt.subplot(2, 1, 1)
plt.plot(acc, label='Training Accuracy')
plt.plot(val_acc, label='Validation Accuracy')
plt.legend(loc='lower right')
plt.ylabel('Accuracy')
plt.ylim([min(plt.ylim()),1])
plt.title('Training and Validation Accuracy')
plt.subplot(2, 1, 2)
plt.plot(loss, label='Training Loss')
plt.plot(val_loss, label='Validation Loss')
plt.legend(loc='upper right')
plt.ylabel('Cross Entropy')
plt.ylim([0,1.0])
plt.title('Training and Validation Loss')
plt.xlabel('epoch')
plt.show()
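16. Second approach: fine-tuning
- In the feature extraction experiment, you trained only a few layers on top of the MobileNet V2 base model. The weights of the pretrained network were not updated during training. One way to improve performance further is to train (or "fine-tune") the weights of the top layers of the pretrained model alongside the classifier you added. Training then pushes those weights from generic feature maps toward features specific to the dataset.
- Only attempt this after you have trained the top-level classifier with the pretrained model set to non-trainable. If you add a randomly initialized classifier on top of a pretrained model and try to train all layers jointly, the gradient updates will be too large (because of the classifier's random weights), and the pretrained model will forget what it has learned.
- Also, fine-tune a small number of top layers rather than the whole MobileNet model. In most convolutional networks, the higher up a layer is, the more specialized it is. The first few layers learn very simple, generic features that generalize to almost all types of images. Further up, the features become increasingly specific to the dataset the model was trained on. The goal of fine-tuning is to adapt these specialized features to the new dataset, not to overwrite the generic learning.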
#Un-freeze the top layers of the model
#All you need to do is unfreeze the base_model and set the bottom layers to be un-trainable. Then, you should recompile the model (necessary for these changes to take effect), and resume training.
base_model.trainable = True
# Let's take a look to see how many layers are in the base model
print("Number of layers in the base model: ", len(base_model.layers))
# Fine-tune from this layer onwards
fine_tune_at = 100
# Freeze all the layers before the `fine_tune_at` layer
for layer in base_model.layers[:fine_tune_at]:
    layer.trainable = False
Number of layers in the base model: 154
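The slicing logic can be checked without TensorFlow. In the sketch below, `ToyLayer` is a hypothetical stand-in for a Keras layer that only tracks a `trainable` flag, and `freeze_below` is an illustrative helper. With 154 layers and `fine_tune_at = 100`, the first 100 layers are frozen and the last 54 stay trainable.

```python
from dataclasses import dataclass


@dataclass
class ToyLayer:
    """Hypothetical stand-in for a Keras layer; only tracks `trainable`."""
    name: str
    trainable: bool = True


def freeze_below(layers, fine_tune_at):
    """Mark every layer before index `fine_tune_at` as non-trainable."""
    for layer in layers[:fine_tune_at]:
        layer.trainable = False
    return layers


layers = freeze_below([ToyLayer(f"layer_{i}") for i in range(154)], fine_tune_at=100)
frozen = sum(not layer.trainable for layer in layers)
print(frozen, len(layers) - frozen)  # 100 54
```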
Compile the model
#As you are training a much larger model and want to readapt the pretrained weights, it is important to use a lower learning rate at this stage. Otherwise, your model could overfit very quickly.
model.compile(loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),
              optimizer=tf.keras.optimizers.RMSprop(learning_rate=base_learning_rate/10),
              metrics=['accuracy'])
model.summary()
Model: "model"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
input_2 (InputLayer) [(None, 160, 160, 3)] 0
_________________________________________________________________
sequential (Sequential) (None, 160, 160, 3) 0
_________________________________________________________________
tf.math.truediv (TFOpLambda) (None, 160, 160, 3) 0
_________________________________________________________________
tf.math.subtract (TFOpLambda (None, 160, 160, 3) 0
_________________________________________________________________
mobilenetv2_1.00_160 (Functi (None, 5, 5, 1280) 2257984
_________________________________________________________________
global_average_pooling2d (Gl (None, 1280) 0
_________________________________________________________________
dropout (Dropout) (None, 1280) 0
_________________________________________________________________
dense (Dense) (None, 1) 1281
=================================================================
Total params: 2,259,265
Trainable params: 1,862,721
Non-trainable params: 396,544
_________________________________________________________________
len(model.trainable_variables)
56
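The figures in the two summaries can be cross-checked with a little arithmetic. The numbers below are copied from the printed summaries; nothing is measured here.

```python
base_params = 2_257_984          # MobileNet V2 base, all frozen before fine-tuning
head_params = 1_281              # Dense(1) classifier
non_trainable_after = 396_544    # base parameters still frozen after unfreezing

unfrozen_base = base_params - non_trainable_after
trainable_after = unfrozen_base + head_params
total = trainable_after + non_trainable_after

print(unfrozen_base)                     # 1861440
print(trainable_after)                   # 1862721
print(total)                             # 2259265
print(f"{trainable_after / total:.1%}")  # 82.4%
```

Although only 54 of the 154 layers are unfrozen, they hold most of the parameters, which is one reason a lower learning rate is used at this stage.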
Continue training the model
#If you trained to convergence earlier, this step will improve your accuracy by a few percentage points.
fine_tune_epochs = 10
total_epochs = initial_epochs + fine_tune_epochs
history_fine = model.fit(train_dataset,
epochs=total_epochs,
initial_epoch=history.epoch[-1],
validation_data=validation_dataset)
Epoch 10/20
63/63 [==============================] - 31s 426ms/step - loss: 0.1810 - accuracy: 0.9288 - val_loss: 0.0674 - val_accuracy: 0.9752
Epoch 11/20
63/63 [==============================] - 26s 406ms/step - loss: 0.1221 - accuracy: 0.9494 - val_loss: 0.0592 - val_accuracy: 0.9827
Epoch 12/20
63/63 [==============================] - 26s 402ms/step - loss: 0.1116 - accuracy: 0.9529 - val_loss: 0.0732 - val_accuracy: 0.9666
Epoch 13/20
63/63 [==============================] - 26s 402ms/step - loss: 0.0950 - accuracy: 0.9586 - val_loss: 0.0467 - val_accuracy: 0.9790
Epoch 14/20
63/63 [==============================] - 25s 396ms/step - loss: 0.1075 - accuracy: 0.9556 - val_loss: 0.0487 - val_accuracy: 0.9814
Epoch 15/20
63/63 [==============================] - 25s 396ms/step - loss: 0.0664 - accuracy: 0.9741 - val_loss: 0.0435 - val_accuracy: 0.9827
Epoch 16/20
63/63 [==============================] - 25s 398ms/step - loss: 0.0860 - accuracy: 0.9681 - val_loss: 0.0428 - val_accuracy: 0.9790
Epoch 17/20
63/63 [==============================] - 25s 394ms/step - loss: 0.0709 - accuracy: 0.9740 - val_loss: 0.0662 - val_accuracy: 0.9691
Epoch 18/20
63/63 [==============================] - 25s 394ms/step - loss: 0.0787 - accuracy: 0.9685 - val_loss: 0.0390 - val_accuracy: 0.9827
Epoch 19/20
63/63 [==============================] - 25s 394ms/step - loss: 0.0733 - accuracy: 0.9734 - val_loss: 0.0577 - val_accuracy: 0.9728
Epoch 20/20
63/63 [==============================] - 25s 395ms/step - loss: 0.0642 - accuracy: 0.9739 - val_loss: 0.0403 - val_accuracy: 0.9802
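The log starts at "Epoch 10/20" rather than "Epoch 11/20" because `initial_epoch` is a 0-based index and `history.epoch[-1]` is 9 after ten epochs. The sketch below reproduces the numbering in plain Python; `epochs_run` is a made-up helper, not a Keras function.

```python
def epochs_run(initial_epoch, total_epochs):
    """1-based epoch labels for a run that resumes at a 0-based `initial_epoch`."""
    return [epoch + 1 for epoch in range(initial_epoch, total_epochs)]


initial_epochs = 10
fine_tune_epochs = 10
total_epochs = initial_epochs + fine_tune_epochs

# history.epoch for the first run is [0, 1, ..., 9], so its last entry is 9.
last_index = list(range(initial_epochs))[-1]

labels = epochs_run(last_index, total_epochs)
print(labels[0], labels[-1], len(labels))  # 10 20 11
```

So the fine-tuning phase shown here actually runs eleven epochs. If the intent is exactly ten additional epochs, passing `initial_epoch=initial_epochs` would be one way to get that. This is a reading of the logs, not something the original states.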
acc += history_fine.history['accuracy']
val_acc += history_fine.history['val_accuracy']
loss += history_fine.history['loss']
val_loss += history_fine.history['val_loss']
plt.figure(figsize=(8, 8))
plt.subplot(2, 1, 1)
plt.plot(acc, label='Training Accuracy')
plt.plot(val_acc, label='Validation Accuracy')
plt.ylim([0.8, 1])
plt.plot([initial_epochs-1,initial_epochs-1],
plt.ylim(), label='Start Fine Tuning')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')
plt.subplot(2, 1, 2)
plt.plot(loss, label='Training Loss')
plt.plot(val_loss, label='Validation Loss')
plt.ylim([0, 1.0])
plt.plot([initial_epochs-1,initial_epochs-1],
plt.ylim(), label='Start Fine Tuning')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')
plt.xlabel('epoch')
plt.show()
17 Evaluation and prediction
#Finally, you can verify the performance of the model on new data using the test set.
loss, accuracy = model.evaluate(test_dataset)
print('Test accuracy :', accuracy)
6/6 [==============================] - 1s 188ms/step - loss: 0.0559 - accuracy: 0.9792
Test accuracy : 0.9791666865348816
And now you are all set to use this model to predict if your pet is a cat or dog.
#Retrieve a batch of images from the test set
image_batch, label_batch = test_dataset.as_numpy_iterator().next()
predictions = model.predict_on_batch(image_batch).flatten()
# Apply a sigmoid since our model returns logits
predictions = tf.nn.sigmoid(predictions)
predictions = tf.where(predictions < 0.5, 0, 1)
print('Predictions:\n', predictions.numpy())
print('Labels:\n', label_batch)
plt.figure(figsize=(10, 10))
for i in range(9):
    ax = plt.subplot(3, 3, i + 1)
    plt.imshow(image_batch[i].astype("uint8"))
    plt.title(class_names[predictions[i]])
    plt.axis("off")
Predictions:
[0 1 0 0 1 1 1 1 1 0 0 1 0 1 1 1 0 0 1 0 1 0 1 1 0 1 1 0 1 0 0 0]
Labels:
[0 1 0 0 1 1 1 1 1 0 0 1 0 1 1 1 0 0 1 0 1 0 1 0 0 1 1 0 1 0 0 0]
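Comparing the two printed arrays by hand is error-prone, so here is a small sketch that does it in plain Python. The lists are copied from the output above.

```python
predictions = [0, 1, 0, 0, 1, 1, 1, 1, 1, 0, 0, 1, 0, 1, 1, 1,
               0, 0, 1, 0, 1, 0, 1, 1, 0, 1, 1, 0, 1, 0, 0, 0]
labels = [0, 1, 0, 0, 1, 1, 1, 1, 1, 0, 0, 1, 0, 1, 1, 1,
          0, 0, 1, 0, 1, 0, 1, 0, 0, 1, 1, 0, 1, 0, 0, 0]

# Indices where the predicted class differs from the true label.
mismatches = [i for i, (p, y) in enumerate(zip(predictions, labels)) if p != y]
batch_accuracy = 1 - len(mismatches) / len(labels)

print(mismatches)       # [23]
print(batch_accuracy)   # 0.96875
```

One image out of 32 is misclassified in this batch (index 23, predicted class 1 while the label is 0), which is in line with the test accuracy of about 0.979 reported over all six batches.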
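18 Summary
1. Using a pretrained model for feature extraction: when working with a small dataset, a common practice is to take advantage of features learned by a model trained on a larger dataset in the same domain. This is done by instantiating the pretrained model and adding a fully connected classifier on top. The pretrained model is "frozen", and only the classifier's weights are updated during training. In this case, the convolutional base extracts all the features associated with each image, and you train only a classifier that determines the image class from that set of extracted features.
2. Fine-tuning a pretrained model: to improve performance further, you may want to repurpose the top layers of the pretrained model for the new dataset through fine-tuning. In this case, you tune the weights so that the model learns high-level features specific to the dataset. This technique is usually recommended when the training dataset is large and very similar to the original dataset the pretrained model was trained on.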