Article structure
I found a nice codebase on GitHub: https://github.com/DrSleep/tensorflow-deeplab-resnet. This article mainly walks through the program's two main files:
Foreword:
1. Network structure:
2. train.py:
3. image_reader.py
The program uses ResNet-101 as its base model.
Foreword:
model.py and network.py build the deep-learning network itself. Their style is very close to the Faster-RCNN_TF codebase and is quite simple, so I won't cover them further. This article focuses on train.py and image_reader.py; inference.py, utils.py and fine_tune.py are simple enough that I skip them as well.
1. Network structure:
I slightly modified Network.py so that it prints each layer's output, as listed below.
Because a single GTX 1080 has limited VRAM, I set the batch size to 4 with 384×384 input images; note that the first convolution uses stride 2, so conv1 already halves the spatial resolution:
conv1 (4, 192, 192, 64)
bn_conv1 (4, 192, 192, 64)
pool1 (4, 192, 192, 64)
res2a_branch1 (4, 96, 96, 256)
bn2a_branch1 (4, 96, 96, 256)
res2a_branch2a (4, 96, 96, 64)
bn2a_branch2a (4, 96, 96, 64)
res2a_branch2b (4, 96, 96, 64)
bn2a_branch2b (4, 96, 96, 64)
res2a_branch2c (4, 96, 96, 256)
bn2a_branch2c (4, 96, 96, 256)
res2a_relu (4, 96, 96, 256)
res2b_branch2a (4, 96, 96, 64)
bn2b_branch2a (4, 96, 96, 64)
res2b_branch2b (4, 96, 96, 64)
bn2b_branch2b (4, 96, 96, 64)
res2b_branch2c (4, 96, 96, 256)
bn2b_branch2c (4, 96, 96, 256)
res2b_relu (4, 96, 96, 256)
res2c_branch2a (4, 96, 96, 64)
bn2c_branch2a (4, 96, 96, 64)
res2c_branch2b (4, 96, 96, 64)
bn2c_branch2b (4, 96, 96, 64)
res2c_branch2c (4, 96, 96, 256)
bn2c_branch2c (4, 96, 96, 256)
res2c_relu (4, 96, 96, 256)
res3a_branch1 (4, 48, 48, 512)
bn3a_branch1 (4, 48, 48, 512)
res3a_branch2a (4, 48, 48, 128)
bn3a_branch2a (4, 48, 48, 128)
res3a_branch2b (4, 48, 48, 128)
bn3a_branch2b (4, 48, 48, 128)
res3a_branch2c (4, 48, 48, 512)
bn3a_branch2c (4, 48, 48, 512)
res3a_relu (4, 48, 48, 512)
res3b1_branch2a (4, 48, 48, 128)
bn3b1_branch2a (4, 48, 48, 128)
res3b1_branch2b (4, 48, 48, 128)
bn3b1_branch2b (4, 48, 48, 128)
res3b1_branch2c (4, 48, 48, 512)
bn3b1_branch2c (4, 48, 48, 512)
res3b1_relu (4, 48, 48, 512)
res3b2_branch2a (4, 48, 48, 128)
bn3b2_branch2a (4, 48, 48, 128)
res3b2_branch2b (4, 48, 48, 128)
bn3b2_branch2b (4, 48, 48, 128)
res3b2_branch2c (4, 48, 48, 512)
bn3b2_branch2c (4, 48, 48, 512)
res3b2_relu (4, 48, 48, 512)
res3b3_branch2a (4, 48, 48, 128)
bn3b3_branch2a (4, 48, 48, 128)
res3b3_branch2b (4, 48, 48, 128)
bn3b3_branch2b (4, 48, 48, 128)
res3b3_branch2c (4, 48, 48, 512)
bn3b3_branch2c (4, 48, 48, 512)
res3b3_relu (4, 48, 48, 512)
res4a_branch1 (4, 48, 48, 1024)
bn4a_branch1 (4, 48, 48, 1024)
res4a_branch2a (4, 48, 48, 256)
bn4a_branch2a (4, 48, 48, 256)
res4a_branch2b (4, 48, 48, 256)
bn4a_branch2b (4, 48, 48, 256)
res4a_branch2c (4, 48, 48, 1024)
bn4a_branch2c (4, 48, 48, 1024)
res4a_relu (4, 48, 48, 1024)
res4b1_branch2a (4, 48, 48, 256)
bn4b1_branch2a (4, 48, 48, 256)
res4b1_branch2b (4, 48, 48, 256)
bn4b1_branch2b (4, 48, 48, 256)
res4b1_branch2c (4, 48, 48, 1024)
bn4b1_branch2c (4, 48, 48, 1024)
res4b1_relu (4, 48, 48, 1024)
res4b2_branch2a (4, 48, 48, 256)
bn4b2_branch2a (4, 48, 48, 256)
res4b2_branch2b (4, 48, 48, 256)
bn4b2_branch2b (4, 48, 48, 256)
res4b2_branch2c (4, 48, 48, 1024)
bn4b2_branch2c (4, 48, 48, 1024)
res4b2_relu (4, 48, 48, 1024)
res4b3_branch2a (4, 48, 48, 256)
bn4b3_branch2a (4, 48, 48, 256)
res4b3_branch2b (4, 48, 48, 256)
bn4b3_branch2b (4, 48, 48, 256)
res4b3_branch2c (4, 48, 48, 1024)
bn4b3_branch2c (4, 48, 48, 1024)
res4b3_relu (4, 48, 48, 1024)
res4b4_branch2a (4, 48, 48, 256)
bn4b4_branch2a (4, 48, 48, 256)
res4b4_branch2b (4, 48, 48, 256)
bn4b4_branch2b (4, 48, 48, 256)
res4b4_branch2c (4, 48, 48, 1024)
bn4b4_branch2c (4, 48, 48, 1024)
res4b4_relu (4, 48, 48, 1024)
res4b5_branch2a (4, 48, 48, 256)
bn4b5_branch2a (4, 48, 48, 256)
res4b5_branch2b (4, 48, 48, 256)
bn4b5_branch2b (4, 48, 48, 256)
res4b5_branch2c (4, 48, 48, 1024)
bn4b5_branch2c (4, 48, 48, 1024)
res4b5_relu (4, 48, 48, 1024)
res4b6_branch2a (4, 48, 48, 256)
bn4b6_branch2a (4, 48, 48, 256)
res4b6_branch2b (4, 48, 48, 256)
bn4b6_branch2b (4, 48, 48, 256)
res4b6_branch2c (4, 48, 48, 1024)
bn4b6_branch2c (4, 48, 48, 1024)
res4b6_relu (4, 48, 48, 1024)
res4b7_branch2a (4, 48, 48, 256)
bn4b7_branch2a (4, 48, 48, 256)
res4b7_branch2b (4, 48, 48, 256)
bn4b7_branch2b (4, 48, 48, 256)
res4b7_branch2c (4, 48, 48, 1024)
bn4b7_branch2c (4, 48, 48, 1024)
res4b7_relu (4, 48, 48, 1024)
res4b8_branch2a (4, 48, 48, 256)
bn4b8_branch2a (4, 48, 48, 256)
res4b8_branch2b (4, 48, 48, 256)
bn4b8_branch2b (4, 48, 48, 256)
res4b8_branch2c (4, 48, 48, 1024)
bn4b8_branch2c (4, 48, 48, 1024)
res4b8_relu (4, 48, 48, 1024)
res4b9_branch2a (4, 48, 48, 256)
bn4b9_branch2a (4, 48, 48, 256)
res4b9_branch2b (4, 48, 48, 256)
bn4b9_branch2b (4, 48, 48, 256)
res4b9_branch2c (4, 48, 48, 1024)
bn4b9_branch2c (4, 48, 48, 1024)
res4b9_relu (4, 48, 48, 1024)
res4b10_branch2a (4, 48, 48, 256)
bn4b10_branch2a (4, 48, 48, 256)
res4b10_branch2b (4, 48, 48, 256)
bn4b10_branch2b (4, 48, 48, 256)
res4b10_branch2c (4, 48, 48, 1024)
bn4b10_branch2c (4, 48, 48, 1024)
res4b10_relu (4, 48, 48, 1024)
res4b11_branch2a (4, 48, 48, 256)
bn4b11_branch2a (4, 48, 48, 256)
res4b11_branch2b (4, 48, 48, 256)
bn4b11_branch2b (4, 48, 48, 256)
res4b11_branch2c (4, 48, 48, 1024)
bn4b11_branch2c (4, 48, 48, 1024)
res4b11_relu (4, 48, 48, 1024)
res4b12_branch2a (4, 48, 48, 256)
bn4b12_branch2a (4, 48, 48, 256)
res4b12_branch2b (4, 48, 48, 256)
bn4b12_branch2b (4, 48, 48, 256)
res4b12_branch2c (4, 48, 48, 1024)
bn4b12_branch2c (4, 48, 48, 1024)
res4b12_relu (4, 48, 48, 1024)
res4b13_branch2a (4, 48, 48, 256)
bn4b13_branch2a (4, 48, 48, 256)
res4b13_branch2b (4, 48, 48, 256)
bn4b13_branch2b (4, 48, 48, 256)
res4b13_branch2c (4, 48, 48, 1024)
bn4b13_branch2c (4, 48, 48, 1024)
res4b13_relu (4, 48, 48, 1024)
res4b14_branch2a (4, 48, 48, 256)
bn4b14_branch2a (4, 48, 48, 256)
res4b14_branch2b (4, 48, 48, 256)
bn4b14_branch2b (4, 48, 48, 256)
res4b14_branch2c (4, 48, 48, 1024)
bn4b14_branch2c (4, 48, 48, 1024)
res4b14_relu (4, 48, 48, 1024)
res4b15_branch2a (4, 48, 48, 256)
bn4b15_branch2a (4, 48, 48, 256)
res4b15_branch2b (4, 48, 48, 256)
bn4b15_branch2b (4, 48, 48, 256)
res4b15_branch2c (4, 48, 48, 1024)
bn4b15_branch2c (4, 48, 48, 1024)
res4b15_relu (4, 48, 48, 1024)
res4b16_branch2a (4, 48, 48, 256)
bn4b16_branch2a (4, 48, 48, 256)
res4b16_branch2b (4, 48, 48, 256)
bn4b16_branch2b (4, 48, 48, 256)
res4b16_branch2c (4, 48, 48, 1024)
bn4b16_branch2c (4, 48, 48, 1024)
res4b16_relu (4, 48, 48, 1024)
res4b17_branch2a (4, 48, 48, 256)
bn4b17_branch2a (4, 48, 48, 256)
res4b17_branch2b (4, 48, 48, 256)
bn4b17_branch2b (4, 48, 48, 256)
res4b17_branch2c (4, 48, 48, 1024)
bn4b17_branch2c (4, 48, 48, 1024)
res4b17_relu (4, 48, 48, 1024)
res4b18_branch2a (4, 48, 48, 256)
bn4b18_branch2a (4, 48, 48, 256)
res4b18_branch2b (4, 48, 48, 256)
bn4b18_branch2b (4, 48, 48, 256)
res4b18_branch2c (4, 48, 48, 1024)
bn4b18_branch2c (4, 48, 48, 1024)
res4b18_relu (4, 48, 48, 1024)
res4b19_branch2a (4, 48, 48, 256)
bn4b19_branch2a (4, 48, 48, 256)
res4b19_branch2b (4, 48, 48, 256)
bn4b19_branch2b (4, 48, 48, 256)
res4b19_branch2c (4, 48, 48, 1024)
bn4b19_branch2c (4, 48, 48, 1024)
res4b19_relu (4, 48, 48, 1024)
res4b20_branch2a (4, 48, 48, 256)
bn4b20_branch2a (4, 48, 48, 256)
res4b20_branch2b (4, 48, 48, 256)
bn4b20_branch2b (4, 48, 48, 256)
res4b20_branch2c (4, 48, 48, 1024)
bn4b20_branch2c (4, 48, 48, 1024)
res4b20_relu (4, 48, 48, 1024)
res4b21_branch2a (4, 48, 48, 256)
bn4b21_branch2a (4, 48, 48, 256)
res4b21_branch2b (4, 48, 48, 256)
bn4b21_branch2b (4, 48, 48, 256)
res4b21_branch2c (4, 48, 48, 1024)
bn4b21_branch2c (4, 48, 48, 1024)
res4b21_relu (4, 48, 48, 1024)
res4b22_branch2a (4, 48, 48, 256)
bn4b22_branch2a (4, 48, 48, 256)
res4b22_branch2b (4, 48, 48, 256)
bn4b22_branch2b (4, 48, 48, 256)
res4b22_branch2c (4, 48, 48, 1024)
bn4b22_branch2c (4, 48, 48, 1024)
res4b22_relu (4, 48, 48, 1024)
res5a_branch1 (4, 48, 48, 2048)
bn5a_branch1 (4, 48, 48, 2048)
res5a_branch2a (4, 48, 48, 512)
bn5a_branch2a (4, 48, 48, 512)
res5a_branch2b (4, 48, 48, 512)
bn5a_branch2b (4, 48, 48, 512)
res5a_branch2c (4, 48, 48, 2048)
bn5a_branch2c (4, 48, 48, 2048)
res5a_relu (4, 48, 48, 2048)
res5b_branch2a (4, 48, 48, 512)
bn5b_branch2a (4, 48, 48, 512)
res5b_branch2b (4, 48, 48, 512)
bn5b_branch2b (4, 48, 48, 512)
res5b_branch2c (4, 48, 48, 2048)
bn5b_branch2c (4, 48, 48, 2048)
res5b_relu (4, 48, 48, 2048)
res5c_branch2a (4, 48, 48, 512)
bn5c_branch2a (4, 48, 48, 512)
res5c_branch2b (4, 48, 48, 512)
bn5c_branch2b (4, 48, 48, 512)
res5c_branch2c (4, 48, 48, 2048)
bn5c_branch2c (4, 48, 48, 2048)
res5c_relu (4, 48, 48, 2048)
fc1_voc12_c0 (4, 48, 48, 2)
fc1_voc12_c1 (4, 48, 48, 2)
fc1_voc12_c2 (4, 48, 48, 2)
fc1_voc12_c3 (4, 48, 48, 2)
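For reference, here is a minimal sketch of the kind of modification involved. It assumes the layer dictionary net.layers that network.py maintains (train.py accesses it as net.layers['fc1_voc12']); the dict's iteration order is not guaranteed, so treat this as illustrative rather than the exact code:

# Hypothetical shape-printing loop over the model's layer dict.
net = DeepLabResNetModel({'data': image_batch}, is_training=False, num_classes=2)
for name, tensor in net.layers.items():
    print(name, tensor.get_shape())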
2. train.py:
Below is a detailed walkthrough of the implementation:
from __future__ import print_function
import argparse
from datetime import datetime
import os
import sys
import time
import tensorflow as tf
import numpy as np
from deeplab_resnet import DeepLabResNetModel, ImageReader, decode_labels, inv_preprocess, prepare_label
IMG_MEAN = np.array((104.00698793,116.66876762,122.67891434), dtype=np.float32)
BATCH_SIZE = 10
DATA_DIRECTORY = '/home/VOCdevkit'
DATA_LIST_PATH = './dataset/train.txt'
IGNORE_LABEL = 255
INPUT_SIZE = '321,321'
LEARNING_RATE = 2.5e-4
MOMENTUM = 0.9
NUM_CLASSES = 21
NUM_STEPS = 20001
POWER = 0.9
RANDOM_SEED = 1234
RESTORE_FROM = './deeplab_resnet.ckpt'
SAVE_NUM_IMAGES = 2
SAVE_PRED_EVERY = 1000
SNAPSHOT_DIR = './snapshots/'
WEIGHT_DECAY = 0.0005
def get_arguments():
"""Parse all the arguments provided from the CLI.
Returns:
A list of parsed arguments.
"""
parser = argparse.ArgumentParser(description="DeepLab-ResNet Network")
parser.add_argument("--batch-size", type=int, default=BATCH_SIZE,
help="Number of images sent to the network in one step.")
parser.add_argument("--data-dir", type=str, default=DATA_DIRECTORY,
help="Path to the directory containing the PASCAL VOC dataset.")
parser.add_argument("--data-list", type=str, default=DATA_LIST_PATH,
help="Path to the file listing the images in the dataset.")
parser.add_argument("--ignore-label", type=int, default=IGNORE_LABEL,
help="The index of the label to ignore during the training.")
parser.add_argument("--input-size", type=str, default=INPUT_SIZE,
help="Comma-separated string with height and width of images.")
parser.add_argument("--is-training", action="store_true",
help="Whether to updates the running means and variances during the training.")
parser.add_argument("--learning-rate", type=float, default=LEARNING_RATE,
help="Base learning rate for training with polynomial decay.")
parser.add_argument("--momentum", type=float, default=MOMENTUM,
help="Momentum component of the optimiser.")
parser.add_argument("--not-restore-last", action="store_true",
help="Whether to not restore last (FC) layers.")
parser.add_argument("--num-classes", type=int, default=NUM_CLASSES,
help="Number of classes to predict (including background).")
parser.add_argument("--num-steps", type=int, default=NUM_STEPS,
help="Number of training steps.")
parser.add_argument("--power", type=float, default=POWER,
help="Decay parameter to compute the learning rate.")
parser.add_argument("--random-mirror", action="store_true",
help="Whether to randomly mirror the inputs during the training.")
parser.add_argument("--random-scale", action="store_true",
help="Whether to randomly scale the inputs during the training.")
parser.add_argument("--random-seed", type=int, default=RANDOM_SEED,
help="Random seed to have reproducible results.")
parser.add_argument("--restore-from", type=str, default=RESTORE_FROM,
help="Where restore model parameters from.")
parser.add_argument("--save-num-images", type=int, default=SAVE_NUM_IMAGES,
help="How many images to save.")
parser.add_argument("--save-pred-every", type=int, default=SAVE_PRED_EVERY,
help="Save summaries and checkpoint every often.")
parser.add_argument("--snapshot-dir", type=str, default=SNAPSHOT_DIR,
help="Where to save snapshots of the model.")
parser.add_argument("--weight-decay", type=float, default=WEIGHT_DECAY,
help="Regularisation parameter for L2-loss.")
return parser.parse_args()
# Checkpoint-saving helper.
def save(saver, sess, logdir, step):
'''Save weights.
Args:
saver: TensorFlow Saver object.
sess: TensorFlow session.
logdir: path to the snapshots directory.
step: current training step.
'''
    # Checkpoint file name and directory.
model_name = 'model.ckpt'
checkpoint_path = os.path.join(logdir, model_name)
if not os.path.exists(logdir):
os.makedirs(logdir)
saver.save(sess, checkpoint_path, global_step=step)
print('The checkpoint has been created.')
# Reload helper: restore network parameters from a .ckpt file and resume training.
def load(saver, sess, ckpt_path):
'''Load trained weights.
Args:
saver: TensorFlow Saver object.
sess: TensorFlow session.
ckpt_path: path to checkpoint file with parameters.
'''
saver.restore(sess, ckpt_path)
print("Restored model parameters from {}".format(ckpt_path))
def main():
"""Create the model and start the training."""
    # Parse the arguments passed in on the command line.
    # If you don't launch this script from the CLI, you can instead edit the default values above.
args = get_arguments()
h, w = map(int, args.input_size.split(','))
input_size = (h, w)
tf.set_random_seed(args.random_seed)
    # Feed data to the network through a TF queue, which is initialised below.
    # First create the thread coordinator.
# Create queue coordinator.
coord = tf.train.Coordinator()
# Load reader.
with tf.name_scope("create_inputs"):
reader = ImageReader(
args.data_dir,
args.data_list,
input_size,
args.random_scale,
args.random_mirror,
args.ignore_label,
IMG_MEAN,
coord)
    # The queue yields image_batch and label_batch.
image_batch, label_batch = reader.dequeue(args.batch_size)
# Create network.
net = DeepLabResNetModel({'data': image_batch}, is_training=args.is_training, num_classes=args.num_classes)
# For a small batch size, it is better to keep
# the statistics of the BN layers (running means and variances)
# frozen, and to not update the values provided by the pre-trained model.
# If is_training=True, the statistics will be updated during the training.
# Note that is_training=False still updates BN parameters gamma (scale) and beta (offset)
# if they are presented in var_list of the optimiser definition.
# Predictions.
raw_output = net.layers['fc1_voc12']
    # Decide which parameters get trained and which stay fixed.
# Which variables to load. Running means and variances are not trainable,
# thus all_variables() should be restored.
restore_var = [v for v in tf.global_variables() if 'fc' not in v.name or not args.not_restore_last]
all_trainable = [v for v in tf.trainable_variables() if 'beta' not in v.name and 'gamma' not in v.name]
fc_trainable = [v for v in all_trainable if 'fc' in v.name]
    # Among the trainable parameters: conv weights train at lr, FC weights at 10*lr, FC biases at 20*lr.
conv_trainable = [v for v in all_trainable if 'fc' not in v.name] # lr * 1.0
fc_w_trainable = [v for v in fc_trainable if 'weights' in v.name] # lr * 10.0
fc_b_trainable = [v for v in fc_trainable if 'biases' in v.name] # lr * 20.0
assert(len(all_trainable) == len(fc_trainable) + len(conv_trainable))
assert(len(fc_trainable) == len(fc_w_trainable) + len(fc_b_trainable))
# Predictions: ignoring all predictions with labels greater or equal than n_classes
    # Reshape the network output to [-1, args.num_classes].
raw_prediction = tf.reshape(raw_output, [-1, args.num_classes])
    # Since we use sparse_softmax_cross_entropy_with_logits,
    # the labels are reshaped to [batch_size, h, w], i.e. the channel dimension is dropped.
    # With softmax_cross_entropy_with_logits you would set one_hot=True instead, so the labels become [batch_size, h, w, num_classes].
label_proc = prepare_label(label_batch, tf.stack(raw_output.get_shape()[1:3]), num_classes=args.num_classes, one_hot=False) # [batch_size, h, w]
    # Flatten to the same shape as raw_prediction.
raw_gt = tf.reshape(label_proc, [-1,])
    # Drop label values outside [0, num_classes - 1] (e.g. the ignore label 255).
indices = tf.squeeze(tf.where(tf.less_equal(raw_gt, args.num_classes - 1)), 1)
gt = tf.cast(tf.gather(raw_gt, indices), tf.int32)
prediction = tf.gather(raw_prediction, indices)
# Pixel-wise softmax loss.
loss = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=prediction, labels=gt)
l2_losses = [args.weight_decay * tf.nn.l2_loss(v) for v in tf.trainable_variables() if 'weights' in v.name]
reduced_loss = tf.reduce_mean(loss) + tf.add_n(l2_losses)
# Processed predictions: for visualisation.
raw_output_up = tf.image.resize_bilinear(raw_output, tf.shape(image_batch)[1:3,])
raw_output_up = tf.argmax(raw_output_up, dimension=3)
pred = tf.expand_dims(raw_output_up, dim=3)
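    # The three lines above upsample the logits back to the input resolution,
    # take a per-pixel argmax over classes, and restore the channel dimension
    # so the prediction can be visualised alongside images and labels.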
# Image summary.
images_summary = tf.py_func(inv_preprocess, [image_batch, args.save_num_images, IMG_MEAN], tf.uint8)
labels_summary = tf.py_func(decode_labels, [label_batch, args.save_num_images, args.num_classes], tf.uint8)
preds_summary = tf.py_func(decode_labels, [pred, args.save_num_images, args.num_classes], tf.uint8)
total_summary = tf.summary.image('images',
tf.concat(axis=2, values=[images_summary, labels_summary, preds_summary]),
max_outputs=args.save_num_images) # Concatenate row-wise.
summary_writer = tf.summary.FileWriter(args.snapshot_dir,
graph=tf.get_default_graph())
# Define loss and optimisation parameters.
base_lr = tf.constant(args.learning_rate)
step_ph = tf.placeholder(dtype=tf.float32, shape=())
    # Learning-rate decay.
    # I learned a lot from this part: previously I would just hand tf.train.MomentumOptimizer a fixed initial rate,
    # whereas here different parameter groups get different learning rates, plus a decay schedule wired into the graph.
learning_rate = tf.scalar_mul(base_lr, tf.pow((1 - step_ph / args.num_steps), args.power))
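    # Polynomial decay: lr(step) = base_lr * (1 - step / num_steps) ** power.
    # With the defaults (base_lr=2.5e-4, num_steps=20001, power=0.9) this gives
    # roughly lr(0) = 2.5e-4, lr(10000) ~= 1.34e-4, lr(20000) ~= 3.4e-8.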
opt_conv = tf.train.MomentumOptimizer(learning_rate, args.momentum)
opt_fc_w = tf.train.MomentumOptimizer(learning_rate * 10.0, args.momentum)
opt_fc_b = tf.train.MomentumOptimizer(learning_rate * 20.0, args.momentum)
grads = tf.gradients(reduced_loss, conv_trainable + fc_w_trainable + fc_b_trainable)
grads_conv = grads[:len(conv_trainable)]
grads_fc_w = grads[len(conv_trainable) : (len(conv_trainable) + len(fc_w_trainable))]
grads_fc_b = grads[(len(conv_trainable) + len(fc_w_trainable)):]
train_op_conv = opt_conv.apply_gradients(zip(grads_conv, conv_trainable))
train_op_fc_w = opt_fc_w.apply_gradients(zip(grads_fc_w, fc_w_trainable))
train_op_fc_b = opt_fc_b.apply_gradients(zip(grads_fc_b, fc_b_trainable))
train_op = tf.group(train_op_conv, train_op_fc_w, train_op_fc_b)
# Set up tf session and initialize variables.
config = tf.ConfigProto()
config.gpu_options.allow_growth = True
sess = tf.Session(config=config)
init = tf.global_variables_initializer()
sess.run(init)
# Saver for storing checkpoints of the model.
saver = tf.train.Saver(var_list=tf.global_variables(), max_to_keep=10)
# Load variables if the checkpoint is provided.
if args.restore_from is not None:
loader = tf.train.Saver(var_list=restore_var)
load(loader, sess, args.restore_from)
# Start queue threads.
threads = tf.train.start_queue_runners(coord=coord, sess=sess)
# Iterate over training steps.
for step in range(args.num_steps):
start_time = time.time()
feed_dict = { step_ph : step }
if step % args.save_pred_every == 0:
loss_value, images, labels, preds, summary, _ = sess.run([reduced_loss, image_batch, label_batch, pred, total_summary, train_op], feed_dict=feed_dict)
summary_writer.add_summary(summary, step)
save(saver, sess, args.snapshot_dir, step)
else:
loss_value, _ = sess.run([reduced_loss, train_op], feed_dict=feed_dict)
duration = time.time() - start_time
print('step {:d} \t loss = {:.3f}, ({:.3f} sec/step)'.format(step, loss_value, duration))
coord.request_stop()
coord.join(threads)
# Entry point.
if __name__ == '__main__':
main()
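To reproduce the 384×384, batch-4 setting described in the previous section, train.py would be invoked roughly as follows (flags as defined in get_arguments(); the dataset paths are placeholders):

python train.py --batch-size 4 --input-size 384,384 \
    --data-dir /home/VOCdevkit --data-list ./dataset/train.txt \
    --random-mirror --random-scale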
3. image_reader.py
Its main job is to set up a TensorFlow queue that feeds data to the network.
import os
import numpy as np
import tensorflow as tf
def image_scaling(img, label):
"""
Randomly scales the images between 0.5 to 1.5 times the original size.
Args:
img: Training image to scale.
label: Segmentation mask to scale.
"""
scale = tf.random_uniform([1], minval=0.5, maxval=1.5, dtype=tf.float32, seed=None)
h_new = tf.to_int32(tf.multiply(tf.to_float(tf.shape(img)[0]), scale))
w_new = tf.to_int32(tf.multiply(tf.to_float(tf.shape(img)[1]), scale))
new_shape = tf.squeeze(tf.stack([h_new, w_new]), squeeze_dims=[1])
img = tf.image.resize_images(img, new_shape)
label = tf.image.resize_nearest_neighbor(tf.expand_dims(label, 0), new_shape)
label = tf.squeeze(label, squeeze_dims=[0])
return img, label
def image_mirroring(img, label):
"""
Randomly mirrors the images.
Args:
img: Training image to mirror.
label: Segmentation mask to mirror.
"""
distort_left_right_random = tf.random_uniform([1], 0, 1.0, dtype=tf.float32)[0]
mirror = tf.less(tf.stack([1.0, distort_left_right_random, 1.0]), 0.5)
mirror = tf.boolean_mask([0, 1, 2], mirror)
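    # tf.less yields [False, draw < 0.5, False]; boolean_mask therefore selects
    # axis 1 (the width axis) with probability 0.5, so tf.reverse below flips
    # the image horizontally half the time and leaves it untouched otherwise.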
img = tf.reverse(img, mirror)
label = tf.reverse(label, mirror)
return img, label
def random_crop_and_pad_image_and_labels(image, label, crop_h, crop_w, ignore_label=255):
"""
Randomly crop and pads the input images.
Args:
image: Training image to crop/ pad.
label: Segmentation mask to crop/ pad.
crop_h: Height of cropped segment.
crop_w: Width of cropped segment.
ignore_label: Label to ignore during the training.
"""
label = tf.cast(label, dtype=tf.float32)
label = label - ignore_label # Needs to be subtracted and later added due to 0 padding.
combined = tf.concat(axis=2, values=[image, label])
image_shape = tf.shape(image)
combined_pad = tf.image.pad_to_bounding_box(combined, 0, 0, tf.maximum(crop_h, image_shape[0]), tf.maximum(crop_w, image_shape[1]))
last_image_dim = tf.shape(image)[-1]
last_label_dim = tf.shape(label)[-1]
combined_crop = tf.random_crop(combined_pad, [crop_h,crop_w,4])
img_crop = combined_crop[:, :, :last_image_dim]
label_crop = combined_crop[:, :, last_image_dim:]
label_crop = label_crop + ignore_label
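    # Adding ignore_label back turns the zero-padded border into ignore_label,
    # so padded pixels are later filtered out by the label gather in train.py.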
label_crop = tf.cast(label_crop, dtype=tf.uint8)
# Set static shape so that tensorflow knows shape at compile time.
img_crop.set_shape((crop_h, crop_w, 3))
label_crop.set_shape((crop_h,crop_w, 1))
return img_crop, label_crop
def read_labeled_image_list(data_dir, data_list):
"""Reads txt file containing paths to images and ground truth masks.
Args:
data_dir: path to the directory with images and masks.
data_list: path to the file with lines of the form '/path/to/image /path/to/mask'.
Returns:
Two lists with all file names for images and masks, respectively.
"""
f = open(data_list, 'r')
images = []
masks = []
for line in f:
try:
image, mask = line.strip("\n").split(' ')
except ValueError: # Adhoc for test.
image = mask = line.strip("\n")
images.append(data_dir + image)
masks.append(data_dir + mask)
return images, masks
def read_images_from_disk(input_queue, input_size, random_scale, random_mirror, ignore_label, img_mean): # optional pre-processing arguments
"""Read one image and its corresponding mask with optional pre-processing.
Args:
input_queue: tf queue with paths to the image and its mask.
input_size: a tuple with (height, width) values.
If not given, return images of original size.
random_scale: whether to randomly scale the images prior
to random crop.
random_mirror: whether to randomly mirror the images prior
to random crop.
ignore_label: index of label to ignore during the training.
img_mean: vector of mean colour values.
Returns:
Two tensors: the decoded image and its mask.
"""
img_contents = tf.read_file(input_queue[0])
label_contents = tf.read_file(input_queue[1])
img = tf.image.decode_jpeg(img_contents, channels=3)
img_r, img_g, img_b = tf.split(axis=2, num_or_size_splits=3, value=img)
img = tf.cast(tf.concat(axis=2, values=[img_b, img_g, img_r]), dtype=tf.float32)
# Extract mean.
img -= img_mean
label = tf.image.decode_png(label_contents, channels=1)
if input_size is not None:
h, w = input_size
# Randomly scale the images and labels.
if random_scale:
img, label = image_scaling(img, label)
# Randomly mirror the images and labels.
if random_mirror:
img, label = image_mirroring(img, label)
# Randomly crops the images and labels.
img, label = random_crop_and_pad_image_and_labels(img, label, h, w, ignore_label)
return img, label
class ImageReader(object):
'''Generic ImageReader which reads images and corresponding segmentation
masks from the disk, and enqueues them into a TensorFlow queue.
'''
def __init__(self, data_dir, data_list, input_size,
random_scale, random_mirror, ignore_label, img_mean, coord):
'''Initialise an ImageReader.
Args:
data_dir: path to the directory with images and masks.
data_list: path to the file with lines of the form '/path/to/image /path/to/mask'.
input_size: a tuple with (height, width) values, to which all the images will be resized.
random_scale: whether to randomly scale the images prior to random crop.
random_mirror: whether to randomly mirror the images prior to random crop.
ignore_label: index of label to ignore during the training.
img_mean: vector of mean colour values.
coord: TensorFlow queue coordinator.
'''
self.data_dir = data_dir
self.data_list = data_list
self.input_size = input_size
self.coord = coord
        # self.image_list and self.label_list are Python lists holding the paths of every image and label.
self.image_list, self.label_list = read_labeled_image_list(self.data_dir, self.data_list)
        # Convert self.image_list and self.label_list to tensors so they can flow through the graph.
self.images = tf.convert_to_tensor(self.image_list, dtype=tf.string)
self.labels = tf.convert_to_tensor(self.label_list, dtype=tf.string)
        # Note that the arguments must be passed in as a list.
        # Build a queue that yields one randomly chosen (image, label) path pair at a time.
self.queue = tf.train.slice_input_producer([self.images, self.labels],
shuffle=input_size is not None) # not shuffling if it is val
        # Read the image and label from the dequeued paths.
self.image, self.label = read_images_from_disk(self.queue, self.input_size, random_scale, random_mirror, ignore_label, img_mean)
def dequeue(self, num_elements):
'''Pack images and labels into a batch.
Args:
num_elements: the batch size.
Returns:
Two tensors of size (batch_size, h, w, {3, 1}) for images and masks.'''
image_batch, label_batch = tf.train.batch([self.image, self.label],
num_elements)
return image_batch, label_batch
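Finally, a minimal sketch of how ImageReader is driven, mirroring the setup in train.py (IMG_MEAN, the paths, and the augmentation flags here are assumptions for illustration):

import numpy as np
import tensorflow as tf
from deeplab_resnet import ImageReader

IMG_MEAN = np.array((104.00698793, 116.66876762, 122.67891434), dtype=np.float32)
coord = tf.train.Coordinator()
with tf.name_scope("create_inputs"):
    reader = ImageReader('/home/VOCdevkit', './dataset/train.txt', (321, 321),
                         random_scale=True, random_mirror=True,
                         ignore_label=255, img_mean=IMG_MEAN, coord=coord)
image_batch, label_batch = reader.dequeue(4)

with tf.Session() as sess:
    threads = tf.train.start_queue_runners(coord=coord, sess=sess)
    imgs, labels = sess.run([image_batch, label_batch])
    print(imgs.shape, labels.shape)  # (4, 321, 321, 3) (4, 321, 321, 1)
    coord.request_stop()
    coord.join(threads)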