Model Training and Testing
I. Building on the DeepLabv3+ model, two files mainly need to be modified:
data_generator.py
train_utils.py
(1) Add a dataset descriptor
_CAMVID_INFORMATION = DatasetDescriptor(
    splits_to_sizes={
        'train': 1035,   # number of samples in the train split
        'val': 31,       # number of samples in the val split
    },
    num_classes=3,       # background, blackboard, screen
    ignore_label=255,    # pixels with this value are excluded from the loss
)
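The split sizes must match the number of examples actually written to the tfrecord files. A minimal sketch (TF1-era API; the path pattern is an assumption based on the dataset directory used in the training command later) for counting them:

import tensorflow as tf

def count_records(file_pattern):
    # Sum the records across every shard matching the pattern.
    return sum(1 for path in tf.gfile.Glob(file_pattern)
                 for _ in tf.python_io.tf_record_iterator(path))

print(count_records('/lwh/models/research/deeplab/datasets/blackboard/tfrecord/train-*'))  # expect 1035
print(count_records('/lwh/models/research/deeplab/datasets/blackboard/tfrecord/val-*'))    # expect 31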
(2) Register the dataset
_DATASETS_INFORMATION = {
    'cityscapes': _CITYSCAPES_INFORMATION,
    'pascal_voc_seg': _PASCAL_VOC_SEG_INFORMATION,
    'ade20k': _ADE20K_INFORMATION,
    'camvid': _CAMVID_INFORMATION,
    # 'mydata': _MYDATA_INFORMATION,
}
(3) Modify train_utils.py
In the corresponding utils/train_utils.py, change the exclude_list setting around line 210; this prevents the logits layer from being loaded when using the pretrained weights:
exclude_list = ['global_step', 'logits']
if not initialize_last_layer:
    exclude_list.extend(last_layers)
To fine-tune other datasets on top of DeepLab, modify the input arguments in deeplab/train.py; a sketch of the relevant flags follows.
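Because num_classes here (3) differs from the Cityscapes checkpoint, the class-count-dependent logits layer must be freshly initialized rather than restored. A minimal sketch of the flag settings, assuming the stock flag names in deeplab/train.py; the values shown are what fine-tuning on a new class count typically requires:

flags.DEFINE_boolean('initialize_last_layer', False,
                     'False: do not restore the logits layer, whose shape '
                     'depends on the number of classes.')
flags.DEFINE_boolean('last_layers_contain_logits_only', True,
                     'True: treat only the logits as the last layers to be '
                     'excluded, matching the exclude_list change above.')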
II. Network Training
(1) Download a pretrained model
Download from: https://github.com/tensorflow/models/blob/master/research/deeplab/g3doc/model_zoo.md
Extract the checkpoint to, for example:
/lwh/models/research/deeplab/deeplabv3_cityscapes_train
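A minimal sketch for fetching and unpacking the Cityscapes-pretrained checkpoint; the exact tarball name is an assumption, so check the model zoo page for the current one:

import tarfile
import urllib.request

# Assumed tarball name; verify against the model zoo page above.
URL = ('http://download.tensorflow.org/models/'
       'deeplabv3_cityscapes_train_2018_02_06.tar.gz')
filename, _ = urllib.request.urlretrieve(URL)
with tarfile.open(filename) as tar:
    # Creates deeplabv3_cityscapes_train/ under the deeplab directory.
    tar.extractall('/lwh/models/research/deeplab/')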
(2) Correct for class imbalance
flags.DEFINE_multi_float(
    'label_weights', [1.0, 3.0, 3.0],
    'A list of label weights, each element represents the weight for the label '
    'of its index, for example, label_weights = [0.1, 0.5] means the weight '
    'for label 0 is 0.1 and the weight for label 1 is 0.5. If set as None, all '
    'the labels have the same weight 1.0.')
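Here labels 1 and 2 (blackboard, screen) are up-weighted to 3.0 relative to the background. One hedged way to pick such weights is inverse class frequency; the helper and the pixel counts below are purely illustrative:

import numpy as np

def inverse_frequency_weights(pixel_counts):
    """pixel_counts[i] = total pixels labeled i across the train split."""
    freq = np.asarray(pixel_counts, dtype=np.float64)
    freq /= freq.sum()
    # Weight 1.0 corresponds to a perfectly balanced class.
    weights = 1.0 / (len(freq) * freq)
    return [round(float(w), 2) for w in weights]

# Made-up counts for a background-heavy dataset:
print(inverse_frequency_weights([9e6, 6e5, 4e5]))  # [0.37, 5.56, 8.33]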
(3) Train
python train.py \
    --training_number_of_steps=30000 \
    --train_split="train" \
    --model_variant="xception_65" \
    --atrous_rates=6 \
    --atrous_rates=12 \
    --atrous_rates=18 \
    --output_stride=16 \
    --decoder_output_stride=4 \
    --train_crop_size=801,801 \
    --train_batch_size=2 \
    --dataset="camvid" \
    --tf_initial_checkpoint='/lwh/models/research/deeplab/deeplabv3_cityscapes_train/model.ckpt' \
    --train_logdir='/lwh/models/research/deeplab/exp/blackboard_train/train' \
    --dataset_dir='/lwh/models/research/deeplab/datasets/blackboard/tfrecord'
Rule of thumb for setting train_crop_size: pick a value of the form output_stride × k + 1 for an integer k (here 801 = 16 × 50 + 1), subject to what fits in GPU memory at your batch size.
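A quick check of that rule (assuming crop_size = output_stride × k + 1, the commonly cited DeepLab convention):

def valid_crop(crop, output_stride=16):
    # The crop is aligned when (crop - 1) is a multiple of the output stride.
    return (crop - 1) % output_stride == 0

print(valid_crop(801))  # True: 801 = 16 * 50 + 1
print(valid_crop(800))  # False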
(4) Export the model
python export_model.py \
    --logtostderr \
    --checkpoint_path="/lwh/models/research/deeplab/exp/blackboard_train/train/model.ckpt-30000" \
    --export_path="/lwh/models/research/deeplab/exp/blackboard_train/train/frozen_inference_graph.pb" \
    --model_variant="xception_65" \
    --atrous_rates=6 \
    --atrous_rates=12 \
    --atrous_rates=18 \
    --output_stride=16 \
    --decoder_output_stride=4 \
    --num_classes=3 \
    --crop_size=1080 \
    --crop_size=1920 \
    --inference_scales=1.0
A few points to note (a quick sanity check of the export follows this list):
--checkpoint_path: the path where your trained model was saved
--export_path: the path where the exported model is written
--num_classes=3: the number of classes in your data, background included
--crop_size=1080: the first value is the input height h the model expects
--crop_size=1920: the second value is the input width w the model expects
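A minimal sanity check (TF1 API, a sketch) that the exported graph exposes the tensors the test script below expects; the .pb path is the one given to --export_path:

import tensorflow as tf

graph_def = tf.GraphDef()
with tf.gfile.GFile('/lwh/models/research/deeplab/exp/blackboard_train/train/'
                    'frozen_inference_graph.pb', 'rb') as f:
    graph_def.ParseFromString(f.read())

node_names = {n.name for n in graph_def.node}
assert 'ImageTensor' in node_names, 'input tensor missing'
assert 'SemanticPredictions' in node_names, 'output tensor missing'
print('export looks OK: %d nodes' % len(node_names))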
III. Model Testing
Straight to the code:
# !--*-- coding:utf-8 --*--
# DeepLab demo

import os
from matplotlib import gridspec
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
import tensorflow as tf


class DeepLabModel(object):
    """Loads the DeepLab model and runs inference."""

    INPUT_TENSOR_NAME = 'ImageTensor:0'
    OUTPUT_TENSOR_NAME = 'SemanticPredictions:0'
    INPUT_SIZE = 1920  # the larger side of the model's expected input
    FROZEN_GRAPH_NAME = 'frozen_inference_graph'

    def __init__(self, model_path):
        """Creates and loads a pretrained DeepLab model from a frozen .pb file."""
        self.graph = tf.Graph()
        graph_def = tf.GraphDef.FromString(open(model_path, 'rb').read())
        if graph_def is None:
            raise RuntimeError('Cannot load the inference graph.')
        with self.graph.as_default():
            tf.import_graph_def(graph_def, name='')
        self.sess = tf.Session(graph=self.graph)

    def run(self, image):
        """Runs inference on a single image.

        Args:
            image: A PIL.Image object, raw input image.

        Returns:
            resized_image: RGB image resized from original input image.
            seg_map: Segmentation map of `resized_image`.
        """
        width, height = image.size
        resize_ratio = 1.0 * self.INPUT_SIZE / max(width, height)
        target_size = (int(resize_ratio * width), int(resize_ratio * height))
        # Hard-coded override to match the exported crop size (w, h) = (1920, 1080).
        target_size = (1920, 1080)
        resized_image = image.convert('RGB').resize(target_size, Image.ANTIALIAS)
        print(resized_image)
        batch_seg_map = self.sess.run(
            self.OUTPUT_TENSOR_NAME,
            feed_dict={self.INPUT_TENSOR_NAME: [np.asarray(resized_image)]})
        seg_map = batch_seg_map[0]
        return resized_image, seg_map


def create_pascal_label_colormap():
    """Creates a label colormap used in the PASCAL VOC segmentation benchmark.

    Returns:
        A colormap for visualizing segmentation results.
    """
    colormap = np.zeros((256, 3), dtype=int)
    ind = np.arange(256, dtype=int)
    for shift in reversed(range(8)):
        for channel in range(3):
            colormap[:, channel] |= ((ind >> channel) & 1) << shift
        ind >>= 3
    return colormap


def label_to_color_image(label):
    """Adds color defined by the dataset colormap to the label.

    Args:
        label: A 2D array with integer type, storing the segmentation label.

    Returns:
        result: A 2D array with floating type. The element of the array is the
            color indexed by the corresponding element in the input label to
            the PASCAL color map.

    Raises:
        ValueError: If label is not of rank 2 or its value is larger than the
            color map's maximum entry.
    """
    if label.ndim != 2:
        raise ValueError('Expect 2-D input label')
    colormap = create_pascal_label_colormap()
    if np.max(label) >= len(colormap):
        raise ValueError('label value too large.')
    return colormap[label]


def vis_segmentation(image, seg_map):
    """Visualizes input image, segmentation map and overlay view."""
    plt.figure(figsize=(15, 5))
    grid_spec = gridspec.GridSpec(1, 4, width_ratios=[6, 6, 6, 1])

    plt.subplot(grid_spec[0])
    plt.imshow(image)
    plt.axis('off')
    plt.title('input image')

    plt.subplot(grid_spec[1])
    seg_image = label_to_color_image(seg_map).astype(np.uint8)
    plt.imshow(seg_image)
    plt.axis('off')
    plt.title('segmentation map')

    plt.subplot(grid_spec[2])
    plt.imshow(image)
    plt.imshow(seg_image, alpha=0.7)
    plt.axis('off')
    plt.title('segmentation overlay')

    unique_labels = np.unique(seg_map)
    ax = plt.subplot(grid_spec[3])
    plt.imshow(FULL_COLOR_MAP[unique_labels].astype(np.uint8),
               interpolation='nearest')
    ax.yaxis.tick_right()
    plt.yticks(range(len(unique_labels)), LABEL_NAMES[unique_labels])
    plt.xticks([], [])
    ax.tick_params(width=0.0)
    plt.grid(False)
    plt.show()


LABEL_NAMES = np.asarray(['background', 'blackboard', 'screen'])
FULL_LABEL_MAP = np.arange(len(LABEL_NAMES)).reshape(len(LABEL_NAMES), 1)
FULL_COLOR_MAP = label_to_color_image(FULL_LABEL_MAP)

download_path = r"D:\python_project\deeplabv3+\blackboard_v2.pb"
MODEL = DeepLabModel(download_path)
print('model loaded successfully!')


def run_visualization(imagefile):
    """Runs DeepLab semantic segmentation and visualizes the result."""
    original_im = Image.open(imagefile)
    print('running deeplab on image %s...' % imagefile)
    resized_im, seg_map = MODEL.run(original_im)
    print(seg_map.shape)
    vis_segmentation(resized_im, seg_map)


images_dir = r'D:\python_project\deeplabv3+\test_img'  # directory holding the test images
images = sorted(os.listdir(images_dir))
for imgfile in images:
    run_visualization(os.path.join(images_dir, imgfile))

print('Done.')
Two points to note:
1. Change images_dir to the directory holding your own test images.
2. Change INPUT_SIZE = 1920 to the larger of your images' height and width (a tiny helper follows).
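A tiny helper for point 2; the sample file name is hypothetical:

from PIL import Image

# Read the dimensions of one of your own images to derive INPUT_SIZE.
width, height = Image.open(r'D:\python_project\deeplabv3+\test_img\sample.jpg').size
INPUT_SIZE = max(width, height)  # e.g. 1920 for 1920x1080 frames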
Test results