​模型的基础上,结合花朵识别的具体问题重新训练该模型,以获取自己需要的tensorflow模型。

重新训练Inception v3实质是在原有模型输出层后,新加了一个输出层作为最终的输出层,我们只训练这个新加的输出层。这里使用了迁移学习的概念。

Transfer learning, which means we are starting with a model that has been already trained on another problem. We will then be retraining it on a similar problem. Deep learning from scratch can take days, but transfer learning can be done in short order.

准备

本节主要给出了训练tensorflow模型的一些前提条件。

硬件环境

  • Ubuntu 16.04

安装tensorflow

安装git

$ sudo apt-get update
$ sudo

准备训练样本

$ cd ~
$ mkdir tf_files
$ cd tf_files
$ curl -O http://download.tensorflow.org/example_images/flower_photos.tgz
$ tar xzf flower_photos.tgz
$ ls flower_photos

flower_photos.tgz有218MB。

[可选操作]

$ cd ~/tf_files
$ ls flower_photos/roses | wc -l
$ rm flower_photos/*/[3-9]* # 删除70%的样本数量,减少训练时间。
$ ls flower_photos/roses | wc -l

开始训练

下载retrain脚本

该脚本会自动下载google Inception v3 模型相关文件。

$ cd ~/tf_files
$ curl -O https://raw.githubusercontent.com/tensorflow/tensorflow/r1.1/tensorflow/examples/image_retraining/retrain.py

启动tensorboard

$ cd ~/tf_files
$ tensorboard --logdir training_summaries &

Note:
This command will fail with the following error if you already have a tensorboard running:
ERROR:tensorflow:TensorBoard attempted to bind to port 6006, but it was already in use
You can kill all existing TensorBoard instances with: ​​​$ pkill -f "tensorboard"​

启动训练脚本

$ cd ~/tf_files
$ python retrain.py \
--bottleneck_dir=bottlenecks \
--how_many_training_steps=500 --model_dir=inception \
--summaries_dir=training_summaries/basic \
--output_graph=retrained_graph.pb \
--output_labels=retrained_labels.txt \
--image_dir=flower_photos

如果不添加​​--how_many_training_steps=500​​,默认值为4000。

启动浏览器查看tensorboard

等待​​~/tf_files/bottlenecks​​​中的bottlenecks文件生成结束后,可以启动浏览器,在地址栏中输入​​localhost:6006​​并回车,来查看训练进度。

小结

The retraining script will write out a version of the Inception v3 network with a final layer retrained to your categories to ​​tf_files/retrained_graph.pb​​​ and a text file containing the labels to ​​tf_files/retrained_labels.txt​​​.
该图像识别模型,训练后的图像识别准确率应该在85%到99%。

测试重新训练的模型

$ cd ~/tf_files
$ curl -L https://goo.gl/3lTKZs > label_image.py
$ python label_image.py flower_photos/roses/2414954629_3708a1a04d.jpg

你应该看到类似以下的结果:

daisy (score = 0.99071)
sunflowers (score = 0.00595)
dandelion (score = 0.00252)
roses (score = 0.00049)
tulips (score = 0.00032)

参考

​TensorFlow For Poets​