​keras分类猫狗数据(上)数据预处理​​​
​​​keras分类猫狗数据(中)使用CNN分类模型​​​
​​​keras分类猫狗数据(下)迁移学习​​​
​​​keras分类猫狗数据(番外篇)深度学习CNN连接SVM分类​

1 .使用keras.applications中的vgg16网络模型进行特征提取,并自定义两个全连接层输出分类。

from keras.applications import VGG16
from keras import models,layers,optimizers
from keras.callbacks import TensorBoard

conv_base=VGG16(weights='imagenet',include_top=False,input_shape=(128,128,3))

model = models.Sequential()
model.add(conv_base)
model.add(layers.Flatten())
model.add(layers.Dense(256, activation='relu'))
model.add(layers.Dense(1, activation='sigmoid'))

conv_base.trainable=False

model.summary()

model.compile(optimizer='adam',loss='binary_crossentropy',metrics=['acc'])

import catvsdogs.morph as mp

model.fit_generator(
mp.train_flow,
steps_per_epoch=32,
epochs=50,
validation_data=mp.test_flow,
validation_steps=32,callbacks=[TensorBoard(log_dir='logs/3')])
model.save_weights('outputs/weights_vgg16_use.h5')

keras分类猫狗数据(下)finetune_迭代


keras分类猫狗数据(下)finetune_迭代_02


在30多轮迭代后,测试正确率达到88%。

2 . 微调,使vgg16模型的最后一个卷积层也参与训练,本次使用上文保存的训练权重集​​weights_vgg16_use.h5​​加速训练过程,并使用较小的学习率。

from keras.applications import VGG16
from keras import models,layers,optimizers
from keras.callbacks import TensorBoard

conv_base=VGG16(weights='imagenet',include_top=False,input_shape=(128,128,3))

model = models.Sequential()
model.add(conv_base)
model.add(layers.Flatten())
model.add(layers.Dense(256, activation='relu'))
model.add(layers.Dense(1, activation='sigmoid'))

model.load_weights('outputs/weights_vgg16_use.h5')

conv_base.trainable=True
trainable=False
for layer in conv_base.layers:
if layer.name=='block5_conv1':
trainable=True
layer.trainable=trainable
model.summary()

model.compile(optimizer=optimizers.adam(lr=1e-5),loss='binary_crossentropy',metrics=['acc'])

import catvsdogs.morph as mp

history = model.fit_generator(
mp.train_flow,
steps_per_epoch=32,
epochs=50,
validation_data=mp.test_flow,
validation_steps=32,callbacks=[TensorBoard(log_dir='logs/4')])

keras分类猫狗数据(下)finetune_数据_03


keras分类猫狗数据(下)finetune_数据_04


上图蓝色为本文过程1的,红色为过程2的,正确率到达90%。本文只使用了2000+1000的数据,迭代次数较少,如果想打算更高的识别率,可以简单修改。