ResNet-Based Target Classification on the MSTAR Dataset


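Contents

  • ResNet-Based Target Classification on the MSTAR Dataset
  • Preface
  • 1. Introduction to the MSTAR Dataset
  • 2. SAR Target Classification Network
  • 3. ResNet Code and Training
  • 4. Conclusion
  • Appendix (Code)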


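Preface

  • The previous two articles covered the principles of CNNs and the history of CNN architectures. My graduation project uses ResNet as the model and MSTAR as the SAR image dataset, and the code is written for the GPU build of TensorFlow 2.2. This article walks through the implementation.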

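1. Introduction to the MSTAR Dataset

  • This experiment uses the publicly released US dataset Moving and Stationary Target Acquisition and Recognition (MSTAR). The dataset can be found through Baidu; if you really cannot find it, send me a private message. SAR images of ten classes of vehicle targets under Standard Operating Conditions (SOC) are used as the experimental data. The vehicle targets are: 2S1, BMP2, BRDM2, BTR60, BTR70, D7, T62, T72, ZIL131, and ZSU23/4. The optical and SAR images of each class are shown in the figure below: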

[Figure: optical and SAR images of the ten target classes]

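  • Every class contains targets at azimuth angles from 0° to 360°. SAR images taken at a 17° depression angle form the training set and those taken at a 15° depression angle form the test set. Splitting by depression angle tests how well the feature extraction method generalizes. The target types and image counts for training and testing are listed in the table.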

Class      Training set (17°)   Test set (15°)
2S1        299                  274
BMP2       233                  196
BRDM2      298                  274
BTR60      256                  195
BTR70      233                  196
D7         299                  274
T62        299                  273
T72        232                  196
ZIL131     299                  274
ZSU23/4    299                  274
Total      2747                 2426
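
  • As a quick sanity check on the table above, the per-class counts can be summed in a few lines of Python. This is only an illustrative sketch: the dictionary name MSTAR_SOC_COUNTS is a hypothetical helper, and the numbers are copied from the table.

```python
# Per-class image counts copied from the table: (train at 17 deg, test at 15 deg).
# MSTAR_SOC_COUNTS is a hypothetical name used only for this sketch.
MSTAR_SOC_COUNTS = {
    "2S1": (299, 274),
    "BMP2": (233, 196),
    "BRDM2": (298, 274),
    "BTR60": (256, 195),
    "BTR70": (233, 196),
    "D7": (299, 274),
    "T62": (299, 273),
    "T72": (232, 196),
    "ZIL131": (299, 274),
    "ZSU23/4": (299, 274),
}

train_total = sum(train for train, _ in MSTAR_SOC_COUNTS.values())
test_total = sum(test for _, test in MSTAR_SOC_COUNTS.values())

# Share of each class in the training set, to see how balanced the classes are.
train_share = {
    name: train / train_total for name, (train, _) in MSTAR_SOC_COUNTS.items()
}

print("train total:", train_total)
print("test total:", test_total)
for name, share in train_share.items():
    print(f"{name}: {share:.1%} of the training set")
```

  • The totals agree with the last row of the table, and the classes are roughly balanced: each class has between 232 and 299 training images.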

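2. SAR Target Classification Network

  • The residual network (Residual Network, ResNet) addresses the degradation problem of deep convolutional neural networks, which makes it possible to train deeper networks that also converge faster. Deep networks are also prone to vanishing or exploding gradients during backpropagation: vanishing gradients keep the parameters of the lower layers from being updated effectively, while exploding gradients grow exponentially and make training unstable. Both effects are more pronounced in deeper networks, and ResNet largely resolves them by introducing skip connections (Shortcut Connections). Extracting SAR target features requires a deeper network, so these problems cannot be avoided, and this section therefore uses ResNet as the base architecture.
  • The basic residual unit of ResNet is shown in the figure below. An identity mapping (Identity Mapping) is added alongside the ordinary convolutional layers. It acts as a shortcut that passes the unit's input X directly to the next layer, so the mapping finally learned is H(X) = F(X) + X. During backpropagation the gradient can also flow straight back to the previous layer through this shortcut, which alleviates the vanishing-gradient problem to some extent.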

[Figure: basic residual unit of ResNet]
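
  • To make the gradient argument concrete, here is a small numerical sketch. It is not the network from this article: each "layer" is a scalar toy function F(x) = w * tanh(x) with a deliberately small weight w = 0.1, an assumption chosen only to make the effect visible. The sketch compares the gradient through a stack of plain layers with the gradient through the same stack when each layer has a skip connection, H(x) = F(x) + x.

```python
import math


def residual_branch(x, w=0.1):
    """Toy residual branch F(x) = w * tanh(x); w is an illustrative assumption."""
    return w * math.tanh(x)


def residual_branch_grad(x, w=0.1):
    """Analytic derivative F'(x) = w * (1 - tanh(x)^2)."""
    return w * (1.0 - math.tanh(x) ** 2)


def stacked_gradient(x, depth, use_skip):
    """Derivative of the output of `depth` stacked layers with respect to x."""
    grad = 1.0
    for _ in range(depth):
        local = residual_branch_grad(x)  # derivative at this layer's input
        if use_skip:
            # H(x) = F(x) + x, so H'(x) = F'(x) + 1
            grad *= local + 1.0
            x = residual_branch(x) + x
        else:
            # Plain layer: the output is just F(x)
            grad *= local
            x = residual_branch(x)
    return grad


plain = stacked_gradient(0.5, depth=20, use_skip=False)
skip = stacked_gradient(0.5, depth=20, use_skip=True)
print(f"plain stack gradient: {plain:.3e}")
print(f"skip stack gradient:  {skip:.3e}")
```

  • Every factor in the plain stack is at most 0.1, so after 20 layers the gradient is below 1e-20. With the skip connection every factor is F'(x) + 1, which is at least 1, so the gradient cannot shrink towards zero in this toy setting.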

  • The deep convolutional neural network designed in this section for SAR ground military targets is shown in the figure below:

[Figure: structure of the deep convolutional network for SAR ground military targets]

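  • Each residual unit in the figure consists of three convolutional layers and one skip connection. Each convolution is a convolutional layer followed by a BN layer and a ReLU activation (in the code below, the ReLU of the third convolution is applied after the addition). The skip connection is a 1×1 convolution that changes the number of channels of the input features so that they can be added to the output of the main path and passed to the next layer. The BN layers normalize the output of each layer. Their main effect is to speed up convergence, and they also help curb overfitting and improve generalization.
  • SAR images are generally preprocessed before being fed to the network, for example resized to a uniform size and enhanced, which makes feature extraction by the CNN easier. The preprocessing methods are described in the experiment section. The preprocessed SAR image is fed into the network above and propagated forward through the convolution, pooling, and fusion operations. The output of the final fully connected layer passes through a Softmax classifier to give a vector whose entries are the probabilities of each SAR image class. The categorical cross-entropy loss is then computed, backpropagation computes the gradients, and the convolution-kernel parameters are updated to minimize the loss. This cycle is repeated for a number of iterations until the loss stabilizes, at which point the SAR target classification network is trained.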
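
  • The last step of the forward pass, Softmax followed by categorical cross-entropy, can be sketched in plain Python. This is a minimal illustration and not the TensorFlow implementation used later; the logits below are made-up values for ten classes.

```python
import math


def softmax(logits):
    """Convert raw network outputs (logits) into class probabilities."""
    shift = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(z - shift) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def categorical_cross_entropy(probs, true_class):
    """Cross-entropy loss for a one-hot target: -log(p[true_class])."""
    return -math.log(probs[true_class])


# Hypothetical logits for the ten MSTAR classes; class 3 has the largest score.
logits = [0.2, -1.0, 0.5, 3.0, 0.0, -0.5, 1.0, 0.1, -2.0, 0.3]
probs = softmax(logits)

loss_correct = categorical_cross_entropy(probs, true_class=3)
loss_wrong = categorical_cross_entropy(probs, true_class=8)
uniform_loss = categorical_cross_entropy(softmax([0.0] * 10), true_class=0)

print("predicted class:", probs.index(max(probs)))
print(f"loss if the true class is 3: {loss_correct:.3f}")
print(f"loss if the true class is 8: {loss_wrong:.3f}")
print(f"loss for a uniform prediction: {uniform_loss:.3f}")
```

  • The loss is small when the true class receives a high probability and large otherwise. A network that predicts all ten classes with equal probability has a loss of ln 10, about 2.303, which is a useful reference value for the first epochs of training.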

3. ResNet Code and Training

  • The code for the basic residual unit of ResNet is as follows:
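import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers, Sequential


class BasicBlock_3(layers.Layer):
    """Three-layer residual unit: 1x1 -> 3x3 -> 1x1 convolutions plus a shortcut."""

    def __init__(self, filter_num, stride=1):
        super(BasicBlock_3, self).__init__()

        # First 1x1 convolution; it carries the stride of the unit
        self.conv1 = layers.Conv2D(filter_num, (1, 1), strides=stride, padding='same')
        self.bn1 = layers.BatchNormalization()
        self.relu1 = layers.Activation('relu')

        # 3x3 convolution
        self.conv2 = layers.Conv2D(filter_num, (3, 3), strides=1, padding='same')
        self.bn2 = layers.BatchNormalization()
        self.relu2 = layers.Activation('relu')

        # Last 1x1 convolution expands the channels to 4 * filter_num
        self.conv3 = layers.Conv2D(4 * filter_num, (1, 1), strides=1, padding='same')
        self.bn3 = layers.BatchNormalization()

        # Shortcut: 1x1 convolution so channels and stride match the main path
        self.downsample = Sequential()
        self.downsample.add(layers.Conv2D(4 * filter_num, (1, 1), strides=stride))
        self.downsample.add(layers.BatchNormalization())

    def call(self, inputs, training=None):
        # inputs: [b, h, w, c]
        out = self.conv1(inputs)
        out = self.bn1(out, training=training)
        out = self.relu1(out)

        out = self.conv2(out)
        out = self.bn2(out, training=training)
        out = self.relu2(out)

        out = self.conv3(out)
        out = self.bn3(out, training=training)

        identity = self.downsample(inputs, training=training)

        # H(X) = F(X) + X, followed by the final ReLU
        output = layers.add([out, identity])
        output = tf.nn.relu(output)

        return output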
  • The ResNet50 code is as follows:
class ResNet_50(keras.Model):

    def __init__(self, layer_dims, num_class=10):  # layer_dims: [3, 4, 6, 3]
        super(ResNet_50, self).__init__()

        # Stem: 5x5 convolution and max pooling, each with stride 2
        self.stem = Sequential([layers.Conv2D(64, (5, 5), strides=(2, 2), padding='same'),
                                layers.BatchNormalization(),
                                layers.Activation('relu'),
                                layers.MaxPool2D(pool_size=(2, 2), strides=(2, 2), padding='same')
                                ])

        # Four stages of residual units
        self.layer1 = self.build_resblock(64, layer_dims[0])
        self.layer2 = self.build_resblock(128, layer_dims[1], stride=2)
        self.layer3 = self.build_resblock(256, layer_dims[2], stride=2)
        self.layer4 = self.build_resblock(512, layer_dims[3], stride=2)

        self.avgpool = layers.GlobalAveragePooling2D()
        # Outputs logits; Softmax is applied inside the loss (from_logits=True)
        self.fc = layers.Dense(num_class)

    def call(self, inputs, training=None):
        x = self.stem(inputs, training=training)
        x = self.layer1(x, training=training)
        x = self.layer2(x, training=training)
        x = self.layer3(x, training=training)
        x = self.layer4(x, training=training)
        # [b, h, w, c] -> [b, c]
        x = self.avgpool(x)
        # [b, c] -> [b, num_class]
        x = self.fc(x)
        return x

    def build_resblock(self, filter_num, blocks, stride=1):
        res_block = Sequential()
        # Only the first unit of a stage may downsample
        res_block.add(BasicBlock_3(filter_num, stride))
        for _ in range(1, blocks):
            res_block.add(BasicBlock_3(filter_num, stride=1))
        return res_block


def resnet50():
    # Factory imported by the training script; block counts follow the
    # [3, 4, 6, 3] comment on the constructor above
    return ResNet_50([3, 4, 6, 3])
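
  • To check how a 128×128 input moves through the model above, the feature-map sizes can be traced by hand. The sketch below assumes only what the code shows: 'same' padding gives an output size of ceil(input / stride), only the first unit of each stage uses the stage stride, and each unit expands the channels to 4 × filter_num. The helper names are hypothetical.

```python
import math


def same_padding_output(size, stride):
    """Spatial size after a 'same'-padded convolution or pooling layer."""
    return math.ceil(size / stride)


def trace_shapes(input_size=128):
    """Follow (name, height/width, channels) through the stem and four stages."""
    shapes = []
    size = same_padding_output(input_size, 2)  # stem convolution, stride 2
    size = same_padding_output(size, 2)        # stem max pooling, stride 2
    shapes.append(("stem", size, 64))
    stage_filters = (64, 128, 256, 512)
    stage_strides = (1, 2, 2, 2)
    for i, (filters, stride) in enumerate(zip(stage_filters, stage_strides), start=1):
        size = same_padding_output(size, stride)         # first unit of the stage
        shapes.append((f"layer{i}", size, 4 * filters))  # channels expand 4x
    return shapes


def count_weighted_layers(layer_dims=(3, 4, 6, 3)):
    """1 stem conv + 3 main-path convs per unit + 1 dense layer."""
    return 1 + 3 * sum(layer_dims) + 1


shapes = trace_shapes()
for name, size, channels in shapes:
    print(f"{name}: {size}x{size}x{channels}")
print("weighted layers:", count_weighted_layers())
```

  • Under these assumptions the last stage outputs a 4×4×2048 feature map, which global average pooling reduces to a 2048-dimensional vector before the dense layer. Counting the stem convolution, the three main-path convolutions in each of the 16 units, and the dense layer gives 50 layers; the 1×1 shortcut convolutions are not counted.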
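  • First, the raw SAR images of all classes are brought to a uniform size of 128×128 (the appendix code does this with tf.image.resize), and power-law enhancement (the Gamma transform) is used to bring out detail in the dark regions. The Gamma transform has the standard form:
    $$ s = c \, r^{\gamma} $$
    where r and s are the input and output pixel values, c is a constant gain, and γ is the exponent; the appendix code uses γ = 0.6.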

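
  • The effect of the Gamma transform on dark and bright pixels can be sketched in a few lines. This assumes the standard power-law form s = c * r^γ applied to pixel values normalized to [0, 1], with γ = 0.6 as passed to tf.image.adjust_gamma in the appendix code; the function name gamma_transform is hypothetical.

```python
def gamma_transform(pixel, gamma=0.6, gain=1.0):
    """Power-law transform s = gain * r**gamma for a pixel r in [0, 1]."""
    if not 0.0 <= pixel <= 1.0:
        raise ValueError("pixel must be normalized to [0, 1]")
    return gain * pixel ** gamma


dark, mid, bright = 0.05, 0.5, 0.95
for r in (dark, mid, bright):
    print(f"r = {r:.2f} -> s = {gamma_transform(r):.3f}")

# How much each pixel is lifted by the transform
dark_lift = gamma_transform(dark) - dark
bright_lift = gamma_transform(bright) - bright
print(f"lift of the dark pixel:   {dark_lift:.3f}")
print(f"lift of the bright pixel: {bright_lift:.3f}")
```

  • With γ below 1, dark pixels are lifted much more than bright ones while 0 and 1 stay fixed, which is why the transform brings out detail in the dark background of SAR images.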

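  • The CNN model code uses the GPU build of TensorFlow 2.2. The experiment uses the adaptive learning-rate optimizer Adam, with a Batch-Size of 16 and 50 training epochs. The learning rate is stated as 0.001, although the code below passes 0.0001 to Adam. TensorBoard is used to monitor training and store the log files. The training code is as follows: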
model = resnet50()
model.build(input_shape=(None, 128, 128, 1))
model.summary()

# One TensorBoard log directory per run, named by start time
current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
log_dir = 'logs/ResNet50_epoch50_Mstar_' + current_time
tb_callback = callbacks.TensorBoard(log_dir=log_dir)

# The model outputs logits, so the loss applies Softmax itself (from_logits=True)
model.compile(optimizer=optimizers.Adam(learning_rate=0.0001),
              loss=tf.losses.CategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])
# model.fit(train_db, epochs=1, validation_data=test_db, validation_freq=1)
model.fit(train_db, epochs=50, validation_data=test_db, validation_freq=1, callbacks=[tb_callback])

model.evaluate(test_db)
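
  • For a sense of scale, the number of parameter updates implied by these settings can be worked out directly. The sketch assumes the training-set size of 2747 from the dataset table and that the last, partial batch of each epoch is kept, which is the default behaviour of batch(16) in the appendix code.

```python
import math

train_images = 2747  # training-set size from the dataset table
batch_size = 16
epochs = 50

# The last batch of each epoch is smaller than batch_size but still used
steps_per_epoch = math.ceil(train_images / batch_size)
last_batch_size = train_images - (steps_per_epoch - 1) * batch_size
total_updates = steps_per_epoch * epochs

print("steps per epoch:", steps_per_epoch)
print("size of the last batch:", last_batch_size)
print("total parameter updates:", total_updates)
```

  • This gives 172 steps per epoch, with 11 images in the last batch, and 8600 updates over the 50 epochs.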
  • The training Loss and Accuracy curves are shown below:

[Figure: Loss curve]

[Figure: Accuracy curve]

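  • In the figures, the red label denotes the training set and the blue label the validation set (the validation data is the test set, as passed to model.fit above). As the epoch count increases, the Loss falls steadily and the Accuracy rises, with no sign of overfitting; the network trains well and behaves as expected.
  • The MSTAR test set is then run through the trained CNN model, and the model is evaluated with the classification metrics from Section 5.2. The confusion matrix of the ten SAR target classes, drawn from the predictions, is shown in the figure below: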

[Figure: confusion matrix of the ten SAR target classes]

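  • The horizontal axis is the predicted class and the vertical axis is the true class, so each diagonal value is the fraction of that class predicted correctly. The confusion matrix is drawn with the Python Seaborn library. The designed CNN model achieves a high recognition rate for every class.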
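
  • How the diagonal of a row-normalized confusion matrix turns into per-class recognition rates can be shown with a small pure-Python sketch. The labels below are made up for three classes and are not results from this experiment; the appendix code uses sklearn's confusion_matrix for the real data.

```python
def confusion_matrix_counts(true_labels, pred_labels, num_classes):
    """Count matrix with true classes as rows and predicted classes as columns."""
    matrix = [[0] * num_classes for _ in range(num_classes)]
    for t, p in zip(true_labels, pred_labels):
        matrix[t][p] += 1
    return matrix


def row_normalize(matrix):
    """Divide each row by its sum so the diagonal holds the per-class rate."""
    normalized = []
    for row in matrix:
        total = sum(row)
        normalized.append([c / total if total else 0.0 for c in row])
    return normalized


# Made-up labels for three classes, purely for illustration
y_true = [0, 0, 0, 0, 1, 1, 1, 2, 2, 2]
y_pred = [0, 0, 0, 1, 1, 1, 1, 2, 0, 2]

counts = confusion_matrix_counts(y_true, y_pred, num_classes=3)
rates = row_normalize(counts)
per_class_accuracy = [rates[i][i] for i in range(3)]
overall_accuracy = sum(counts[i][i] for i in range(3)) / len(y_true)

print(counts)
print(per_class_accuracy)
print(overall_accuracy)
```

  • In this toy case class 0 is recognized 3 times out of 4, class 1 every time, and class 2 in 2 cases out of 3, for an overall accuracy of 0.8. Normalizing by row rather than by column is what makes the diagonal read as "fraction of the true class recognized".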

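4. Conclusion

  • In this graduation-project experiment, a ResNet network was trained and evaluated on the MSTAR dataset, and the confusion matrix and recognition rates were reported. Later articles will apply two traditional feature extraction approaches, LBP+SVM and CLBP+SVM, to the same MSTAR dataset and compare the three methods.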

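Appendix (Code)

  • The code that builds ResNet was given above and is not repeated here. The code below reads the dataset, preprocesses it, and runs the training.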
import tensorflow as tf 
import numpy as np
from ResNet import resnet50
from tensorflow import keras
from tensorflow.keras import datasets, layers, optimizers, Sequential, callbacks
from sklearn.metrics import accuracy_score, confusion_matrix, recall_score, precision_score, f1_score, fbeta_score
from sklearn.metrics import roc_auc_score, roc_curve, auc, classification_report
from sklearn.preprocessing import label_binarize
from sklearn.model_selection import train_test_split
from matplotlib import pyplot as plt
import pathlib
import datetime
import seaborn as sns
def load_one_from_path_label(path, label):
    '''Read and preprocess a single image; returns a batch of size 1.'''
    images = np.zeros((1, 128, 128, 1))
    labels = tf.one_hot(label, depth=10)
    labels = tf.cast(labels, dtype=tf.int32)

    image = tf.io.read_file(path)
    image = tf.image.decode_jpeg(image, channels=1)  # force grayscale
    image = tf.image.adjust_gamma(image, 0.6)  # Gamma transform
    image = tf.image.resize(image, [128, 128])  # resize to (128, 128)
    image = tf.cast(image, dtype=tf.float32) / 255.0  # normalize to [0, 1]

    images[0, :, :, :] = image
    return images, labels

def load_from_path_label2(all_image_paths, all_image_labels):
    '''Read and preprocess all images into one in-memory array.'''
    image_count = len(all_image_paths)
    images = np.zeros((image_count, 128, 128, 1))
    # Integer labels (not one-hot): confusion_matrix() in main() needs class indices
    labels = tf.cast(all_image_labels, dtype=tf.int32)

    for i in range(0, image_count):
        image = tf.io.read_file(all_image_paths[i])
        image = tf.image.decode_jpeg(image, channels=1)  # force grayscale
        image = tf.image.adjust_gamma(image, 0.6)  # Gamma transform
        image = tf.image.resize(image, [128, 128])  # resize to (128, 128)
        image = tf.cast(image, dtype=tf.float32) / 255.0  # normalize to [0, 1]

        images[i, :, :, :] = image

    return images, labels


def load_from_path_label(path, label):
    '''Read one image and one-hot encode its label.'''
    image = tf.io.read_file(path)
    image = tf.image.decode_jpeg(image, channels=1)  # force grayscale
    label = tf.one_hot(label, depth=10)
    label = tf.cast(label, dtype=tf.int32)
    return image, label

def preprocess(image, label):
    '''Preprocess one image.'''

    image = tf.image.adjust_gamma(image, 0.6)  # Gamma transform
    image = tf.image.resize(image, [128, 128])  # resize to (128, 128)
    image = tf.cast(image, dtype=tf.float32) / 255.0  # normalize to [0, 1]

    return image, label

def show_image(db, row, col, title, is_preprocess=True):
    '''Show a row x col grid of sample images, one class per row.'''
    plt.figure()
    plt.suptitle(title, fontsize=14)
    j = 0
    for i, (image, label) in enumerate(db):
        if j == row * col:
            break
        # Row k of the grid only accepts images of class k
        if int(tf.argmax(label)) == int(j / col):
            if is_preprocess:
                image = image * 255  # back from [0, 1] to [0, 255]
            plt.subplot(row, col, j + 1)
            plt.title("class" + str(int(tf.argmax(label))), fontsize=8)
            plt.imshow(image, cmap='gray')
            plt.axis('off')
            j = j + 1
    plt.tight_layout()

def get_datasets(path, train=True):
    '''Build the dataset from a directory with one sub-folder per class.'''
    data_path = pathlib.Path(path)
    # Paths of all images, one level below the class folders
    all_image_paths = sorted(str(p) for p in data_path.glob('*/*'))
    # Number of images in the dataset
    image_count = len(all_image_paths)
    # Class names, sorted so that TRAIN and TEST map names to the same indices
    label_names = sorted(item.name for item in data_path.iterdir() if item.is_dir())
    # Map each class name to an integer label
    label_index = dict((name, index) for index, name in enumerate(label_names))
    print(label_index)
    print(label_names)
    print(image_count)
    # Integer label of every image, taken from its parent folder name
    all_image_labels = [label_index[pathlib.Path(p).parent.name] for p in all_image_paths]
    for image, label in zip(all_image_paths[:5], all_image_labels[:5]):
        print(image, ' --->  ', label)
    # In-memory arrays (preprocessed images, integer labels) used for prediction
    images, labels = load_from_path_label2(all_image_paths, all_image_labels)
    # tf.data pipeline (one-hot labels) used for training and evaluation
    db = tf.data.Dataset.from_tensor_slices((all_image_paths, all_image_labels))
    db = db.map(load_from_path_label)
    prefix = '(Train)' if train else '(Test)'
    show_image(db, 5, 5, prefix + ' Raw SAR Image', False)
    db = db.map(preprocess)
    show_image(db, 5, 5, prefix + ' Preprocessed SAR Image', True)

    db = db.shuffle(1000).batch(16)
    return db, images, labels, label_names


def get_train_valid_datasets(path, train=True):
    '''Build training and validation datasets from one directory (80/20 split).'''
    data_path = pathlib.Path(path)
    # Paths of all images, one level below the class folders
    all_image_paths = sorted(str(p) for p in data_path.glob('*/*'))
    # Number of images in the dataset
    image_count = len(all_image_paths)
    # Class names, sorted so that the name-to-index mapping is deterministic
    label_names = sorted(item.name for item in data_path.iterdir() if item.is_dir())
    # Map each class name to an integer label
    label_index = dict((name, index) for index, name in enumerate(label_names))
    print(label_index)
    print(label_names)
    print(image_count)
    # Integer label of every image, taken from its parent folder name
    all_image_labels = [label_index[pathlib.Path(p).parent.name] for p in all_image_paths]
    for image, label in zip(all_image_paths[:5], all_image_labels[:5]):
        print(image, ' --->  ', label)

    train_images, valid_images, train_labels, valid_labels = train_test_split(
        all_image_paths, all_image_labels, test_size=0.2, random_state=0)

    print('train counts -----> ', len(train_images))
    print('valid counts -----> ', len(valid_images))
    train_db = tf.data.Dataset.from_tensor_slices((train_images, train_labels))
    train_db = train_db.map(load_from_path_label)

    valid_db = tf.data.Dataset.from_tensor_slices((valid_images, valid_labels))
    valid_db = valid_db.map(load_from_path_label)

    if train:
        show_image(train_db, 5, 5, '(Train) Raw SAR Image', False)
        show_image(valid_db, 5, 5, '(Valid) Raw SAR Image', False)

    # Always preprocess: the raw images must be resized before they can be batched
    train_db = train_db.map(preprocess)
    valid_db = valid_db.map(preprocess)

    if train:
        show_image(train_db, 5, 5, '(Train) Preprocessed SAR Image', True)
        show_image(valid_db, 5, 5, '(Valid) Preprocessed SAR Image', True)

    train_db = train_db.shuffle(1000).batch(16)
    valid_db = valid_db.shuffle(1000).batch(16)
    return train_db, valid_db

def plot_confusion_matrix(matrix, class_labels, normalize=False):
    '''Plot the confusion matrix as a heat map.'''
    if normalize:
        # Row-normalize: each row (true class) sums to 1
        matrix = matrix.astype('float') / matrix.sum(axis=1)[:, np.newaxis]
        A = np.around(matrix, decimals=5)
        A = A * 100  # percentages, printed for reference
        print(A)
        matrix = np.around(matrix, decimals=2)
    sns.set()
    f, ax = plt.subplots()
    # Tick positions at the centre of each heat-map cell
    tick_marks = np.arange(len(class_labels)) + 0.5
    sns.heatmap(matrix, annot=True, cmap="Blues", ax=ax)  # draw the heat map
    ax.set_title('confusion matrix')  # title

    plt.xticks(tick_marks, class_labels, rotation=45)
    plt.yticks(tick_marks, class_labels, rotation=0)
    ax.set_xlabel('Predict')  # x axis: predicted class
    ax.set_ylabel('True')  # y axis: true class
    plt.tight_layout()

def main():
    '''Train ResNet50 on MSTAR, evaluate it, and plot the confusion matrix.'''
    train_db, train_images, train_labels, train_label_names = get_datasets('E:\\SARimage\\TRAIN', True)
    test_db, test_images, test_labels, test_label_names = get_datasets('E:\\SARimage\\TEST', False)
    # ResNet50
    model = resnet50()
    model.build(input_shape=(None, 128, 128, 1))
    model.summary()

    # One TensorBoard log directory per run, named by start time
    current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
    log_dir = 'logs/ResNet50_epoch50_Mstar_' + current_time
    tb_callback = callbacks.TensorBoard(log_dir=log_dir)

    model.compile(optimizer=optimizers.Adam(learning_rate=0.0001),
                  loss=tf.losses.CategoricalCrossentropy(from_logits=True),
                  metrics=['accuracy'])
    # model.fit(train_db, epochs=1, validation_data=test_db, validation_freq=1)

    model.fit(train_db, epochs=50, validation_data=test_db, validation_freq=1, callbacks=[tb_callback])
    model.evaluate(test_db)

    model.save_weights('./checkpoint/ResNet50_epoch50_weights.ckpt')
    print('save weights')

    # Predict on the in-memory test images and take the most likely class
    pred_labels = model.predict(test_images)
    pred_labels = tf.argmax(pred_labels, axis=1)
    con_matrix = confusion_matrix(test_labels, pred_labels, labels=[0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
    print(con_matrix)
    plot_confusion_matrix(con_matrix, test_label_names, normalize=True)
    np.savetxt('./checkpoint/ResNet50_epoch50_confusion_matrix.txt', con_matrix)
    plt.show()


if __name__ == '__main__':
    main()
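
  • The listing above derives integer labels from the class folder names. Because the training and test directories are scanned separately, both scans need to produce the same name-to-index mapping, and sorting the folder names is one simple way to guarantee that. The sketch below shows the idea with the standard library only, using a throw-away directory tree; the function names and the empty placeholder files are assumptions made for illustration.

```python
import pathlib
import tempfile


def build_label_index(root):
    """Map each class sub-directory name to an integer, in sorted order."""
    class_dirs = sorted(p.name for p in pathlib.Path(root).iterdir() if p.is_dir())
    return {name: index for index, name in enumerate(class_dirs)}


def list_labelled_files(root, label_index):
    """Return (path, label) pairs for every file one level below root."""
    pairs = []
    for path in sorted(pathlib.Path(root).glob("*/*")):
        if path.is_file():
            pairs.append((str(path), label_index[path.parent.name]))
    return pairs


# Build a tiny temporary tree; the folders are created in unsorted order on purpose
with tempfile.TemporaryDirectory() as tmp:
    for class_name in ("T72", "2S1", "BMP2"):
        class_dir = pathlib.Path(tmp) / class_name
        class_dir.mkdir()
        (class_dir / "sample.jpg").write_bytes(b"")  # empty placeholder file
    label_index = build_label_index(tmp)
    pairs = list_labelled_files(tmp, label_index)
    labels = [label for _, label in pairs]

print(label_index)
print(labels)
```

  • The mapping depends only on the folder names, not on the order in which the file system lists them, so the same class gets the same index in both the training and the test directory as long as the folder names match.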