ResNet-Based Target Classification on the MSTAR Dataset
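Table of Contents
- ResNet-Based Target Classification on the MSTAR Dataset
- Preface
- 1. Introduction to the MSTAR Dataset
- 2. SAR Target Classification Network
- 3. ResNet Code and Training
- 4. Conclusion
- Appendix (Code)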
Preface
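- The previous two posts covered the principles of CNNs and the evolution of CNN architectures. My graduation project uses ResNet as the model and the MSTAR dataset as the SAR image data, implemented with TensorFlow 2.2 (GPU). This post describes the implementation in detail.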
1. Introduction to the MSTAR Dataset
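- The experiments use the publicly released U.S. Moving and Stationary Target Acquisition and Recognition (MSTAR) dataset. It can be found through a Baidu search (if you really cannot find it, feel free to message me). The experimental data are SAR images of ten vehicle target classes under Standard Operating Conditions (SOC): 2S1, BMP2, BRDM2, BTR60, BTR70, D7, T62, T72, ZIL131, and ZSU23/4. The optical and SAR images of each class are shown in the figure below: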
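- Every class covers target aspect angles from 0° to 360°. SAR images acquired at a 17° depression angle are used as the training set and those acquired at 15° as the test set; because the two sets differ in depression angle, this split tests how well the feature extraction method generalizes. The target types and sample counts of the training and test sets are listed in the table below.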
| Class | Training samples (17°) | Test samples (15°) |
|---|---|---|
| 2S1 | 299 | 274 |
| BMP2 | 233 | 196 |
| BRDM2 | 298 | 274 |
| BTR60 | 256 | 195 |
| BTR70 | 233 | 196 |
| D7 | 299 | 274 |
| T62 | 299 | 273 |
| T72 | 232 | 196 |
| ZIL131 | 299 | 274 |
| ZSU23/4 | 299 | 274 |
| Total | 2747 | 2426 |
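Before training, it can be worth checking that a local copy of the data matches the table. The sketch below is a hypothetical helper, not part of the original code: it assumes the same `<root>/<class folder>/<image file>` layout that `get_datasets()` in the appendix reads and counts the files in each class folder. The two lists re-check the column totals of the table.

```python
from pathlib import Path

# Per-class counts copied from the table above (same row order).
TRAIN_17 = [299, 233, 298, 256, 233, 299, 299, 232, 299, 299]
TEST_15 = [274, 196, 274, 195, 196, 274, 273, 196, 274, 274]
assert sum(TRAIN_17) == 2747 and sum(TEST_15) == 2426  # matches the "Total" row


def count_images_per_class(root):
    """Hypothetical helper: count the files in each class folder of <root>/<class>/<image>.

    Class folders are returned in sorted name order; non-directory entries in <root> are ignored.
    """
    counts = {}
    for class_dir in sorted(p for p in Path(root).iterdir() if p.is_dir()):
        counts[class_dir.name] = sum(1 for f in class_dir.iterdir() if f.is_file())
    return counts
```

For example, `sum(count_images_per_class(r"E:\SARimage\TRAIN").values())` should equal 2747 if the training folder matches the 17° column.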
2. SAR Target Classification Network
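- The Residual Network (ResNet) addresses the degradation problem of deep convolutional neural networks, which makes much deeper networks trainable and speeds up convergence. In addition, backpropagation in deep networks is prone to vanishing or exploding gradients: vanishing gradients leave the parameters of the lower layers poorly updated, while exploding gradients grow exponentially and destabilize training, and both effects become more pronounced as the network gets deeper. ResNet largely alleviates these problems by introducing shortcut connections. Extracting SAR target features requires a deep network, which inevitably runs into the problems above, so this section uses ResNet as the backbone.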
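- The basic residual unit of ResNet is shown in the figure below. An identity mapping is added alongside the regular convolutional layers as a shortcut, so the input $X$ of the unit is passed directly to its output and the unit learns $H(X) = F(X) + X$, where $F(X)$ is the output of the convolutional branch. During backpropagation, the gradient can likewise flow directly back to earlier layers through this shortcut, which alleviates the vanishing-gradient problem to some extent.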
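The gradient argument can be made explicit with the chain rule. Writing the training loss as $\mathcal{L}$ (a symbol introduced here) and differentiating $H(X) = F(X) + X$:

$$
\frac{\partial \mathcal{L}}{\partial X}
= \frac{\partial \mathcal{L}}{\partial H}\,\frac{\partial H}{\partial X}
= \frac{\partial \mathcal{L}}{\partial H}\left(\frac{\partial F}{\partial X} + I\right)
= \frac{\partial \mathcal{L}}{\partial H}\,\frac{\partial F}{\partial X} + \frac{\partial \mathcal{L}}{\partial H}
$$

The second term carries the upstream gradient through unchanged, so the gradient reaching $X$ does not vanish even when $\partial F / \partial X$ is small. This exact form holds for an identity shortcut; with the 1×1 projection shortcut used in this post, $I$ is replaced by the Jacobian of the projection.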
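- The architecture of the deep convolutional neural network designed in this section for ground military targets in SAR images is shown in the figure below: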
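- Each residual unit in the figure consists of three convolutional layers (1×1, 3×3, 1×1) and a shortcut connection. Each convolution is followed by a BN layer and a ReLU activation (in the implementation in Section 3, the third convolution is followed by BN only, and the ReLU is applied after the shortcut is added). The shortcut is a 1×1 convolution that changes the number of channels of the input features, and their spatial size when the stride is 2, so that they can be added to the output of the convolutional branch and passed on to the next layer. The BN layer normalizes the outputs of each layer; its main benefit is faster convergence, and it also has a mild regularizing effect that helps suppress overfitting and improve generalization.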
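As an illustration of what a BN layer computes during training, here is a minimal NumPy sketch (not TensorFlow's implementation) for NHWC feature maps: each channel is normalized with the mean and variance taken over the batch and spatial axes, then scaled by a learnable γ and shifted by β. The default `eps=1e-3` mirrors the Keras `BatchNormalization` default; the function name is hypothetical.

```python
import numpy as np


def batch_norm_train(x, gamma, beta, eps=1e-3):
    """Training-mode batch normalization of an NHWC tensor (illustrative sketch).

    Statistics are computed per channel over the batch and spatial axes (N, H, W),
    i.e. the channel axis is the last one, as with Keras' default axis=-1.
    """
    mean = x.mean(axis=(0, 1, 2), keepdims=True)  # per-channel mean
    var = x.var(axis=(0, 1, 2), keepdims=True)    # per-channel (biased) variance
    x_hat = (x - mean) / np.sqrt(var + eps)       # zero mean, roughly unit variance
    return gamma * x_hat + beta                   # learnable scale and shift, shape (C,)
```

At inference time Keras instead uses moving averages of the mean and variance accumulated during training, which is why the `training` flag matters for BN layers.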
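- SAR images usually need preprocessing before they are fed to the network, such as resizing to a uniform size and image enhancement, to make feature extraction by the CNN easier; the preprocessing used here is described in the experiment section. The preprocessed SAR images are fed into the network above and propagated forward through the convolution, pooling, and residual-addition (fusion) layers shown in the figure. The output of the final fully connected layer passes through a Softmax classifier, which produces a vector whose i-th element is the predicted probability that the image belongs to class i (in the code, the Softmax is folded into the loss through `from_logits=True`). The categorical cross-entropy loss is then computed, gradients are obtained by backpropagation, and the convolution kernel parameters are updated to minimize the loss. This cycle is repeated for a number of iterations until the loss stabilizes, at which point the SAR target classification network is trained.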
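The Softmax and categorical cross-entropy steps can be sketched in NumPy as follows. This is an illustrative re-implementation, not TensorFlow's code: it computes the loss directly from logits through a numerically stable log-softmax, which is the role of `CategoricalCrossentropy(from_logits=True)` applied to the un-activated output of the final `Dense` layer, and averages over the batch. Function names are hypothetical.

```python
import numpy as np


def softmax(logits):
    """Row-wise Softmax; subtracting the row max avoids overflow in exp."""
    z = logits - logits.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)


def categorical_cross_entropy_from_logits(one_hot, logits):
    """Mean over the batch of -sum_i y_i * log(softmax(logits)_i)."""
    z = logits - logits.max(axis=-1, keepdims=True)
    log_probs = z - np.log(np.exp(z).sum(axis=-1, keepdims=True))  # stable log-softmax
    return -(one_hot * log_probs).sum(axis=-1).mean()
```

The predicted class is the index of the largest probability, `np.argmax(softmax(logits), axis=-1)`; because Softmax is monotonic, this equals `np.argmax(logits, axis=-1)`, which is why the appendix can take the `argmax` of the raw model outputs.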
3. ResNet Code and Training
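- The code for the basic residual unit (a bottleneck block with 1×1, 3×3, and 1×1 convolutions and a projection shortcut) is as follows: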
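```python
# ResNet.py
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers, Sequential


class BasicBlock_3(layers.Layer):
    """Bottleneck residual unit: 1x1 -> 3x3 -> 1x1 convolutions plus a projection shortcut."""

    def __init__(self, filter_num, stride=1):
        super(BasicBlock_3, self).__init__()
        # 1x1 conv reduces the channels; it also carries the stride when downsampling
        self.conv1 = layers.Conv2D(filter_num, (1, 1), strides=stride, padding='same')
        self.bn1 = layers.BatchNormalization()
        self.relu1 = layers.Activation('relu')
        # 3x3 conv extracts spatial features
        self.conv2 = layers.Conv2D(filter_num, (3, 3), strides=1, padding='same')
        self.bn2 = layers.BatchNormalization()
        self.relu2 = layers.Activation('relu')
        # 1x1 conv expands the channels to 4 * filter_num
        self.conv3 = layers.Conv2D(4 * filter_num, (1, 1), strides=1, padding='same')
        self.bn3 = layers.BatchNormalization()
        # Shortcut: always a 1x1 conv + BN, so channels and spatial size match the main branch
        self.downsample = Sequential()
        self.downsample.add(layers.Conv2D(4 * filter_num, (1, 1), strides=stride))
        self.downsample.add(layers.BatchNormalization())

    def call(self, inputs, training=None):
        # inputs: [b, h, w, c]
        out = self.conv1(inputs)
        out = self.bn1(out, training=training)
        out = self.relu1(out)
        out = self.conv2(out)
        out = self.bn2(out, training=training)
        out = self.relu2(out)
        out = self.conv3(out)
        out = self.bn3(out, training=training)
        identity = self.downsample(inputs, training=training)
        # H(X) = F(X) + shortcut(X), followed by ReLU
        return tf.nn.relu(out + identity)
```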
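To cross-check the numbers printed by `model.summary()`, the parameter count of one `BasicBlock_3` can be worked out by hand. The sketch below assumes Keras defaults: `Conv2D` with a bias term, and four values per channel for `BatchNormalization` (γ, β, moving mean, moving variance, the last two non-trainable). The function name is hypothetical.

```python
def bottleneck_params(c_in, f):
    """Parameter count of one BasicBlock_3 with c_in input channels and filter_num=f."""

    def conv(k, cin, cout):
        return k * k * cin * cout + cout  # kernel weights + biases

    def bn(c):
        return 4 * c  # gamma, beta, moving mean, moving variance

    main = conv(1, c_in, f) + bn(f) + conv(3, f, f) + bn(f) + conv(1, f, 4 * f) + bn(4 * f)
    shortcut = conv(1, c_in, 4 * f) + bn(4 * f)  # 1x1 projection + BN
    return main + shortcut


print(bottleneck_params(64, 64))  # 76928
```

For the first block of `layer1`, the input has 64 channels (from the stem) and `filter_num=64`, giving 76,928 parameters; the 1×1 projection shortcut alone accounts for 17,664 of them.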
- The ResNet-50 model is built as follows:
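```python
# ResNet.py (continued): uses the imports and BasicBlock_3 defined above
class ResNet_50(keras.Model):
    def __init__(self, layer_dims, num_class=10):  # layer_dims = [3, 4, 6, 3] for ResNet-50
        super(ResNet_50, self).__init__()
        # Stem: 5x5 conv (stride 2) + 2x2 max pooling (stride 2)
        self.stem = Sequential([layers.Conv2D(64, (5, 5), strides=(2, 2), padding='same'),
                                layers.BatchNormalization(),
                                layers.Activation('relu'),
                                layers.MaxPool2D(pool_size=(2, 2), strides=(2, 2), padding='same')
                                ])
        # Four stages of bottleneck blocks; stages 2-4 halve the spatial size
        self.layer1 = self.build_resblock(64, layer_dims[0])
        self.layer2 = self.build_resblock(128, layer_dims[1], stride=2)
        self.layer3 = self.build_resblock(256, layer_dims[2], stride=2)
        self.layer4 = self.build_resblock(512, layer_dims[3], stride=2)
        self.avgpool = layers.GlobalAveragePooling2D()
        # No activation: the layer outputs logits (the loss uses from_logits=True)
        self.fc = layers.Dense(num_class)

    def call(self, inputs, training=None):
        x = self.stem(inputs, training=training)
        x = self.layer1(x, training=training)
        x = self.layer2(x, training=training)
        x = self.layer3(x, training=training)
        x = self.layer4(x, training=training)
        x = self.avgpool(x)  # [b, c]
        x = self.fc(x)       # [b, num_class]
        return x

    def build_resblock(self, filter_num, blocks, stride=1):
        res_block = Sequential()
        # Only the first block of a stage may downsample
        res_block.add(BasicBlock_3(filter_num, stride))
        for _ in range(1, blocks):
            res_block.add(BasicBlock_3(filter_num, stride=1))
        return res_block


def resnet50(num_class=10):
    # Factory imported by the training script (`from ResNet import resnet50`);
    # [3, 4, 6, 3] is the block configuration noted in ResNet_50.__init__
    return ResNet_50([3, 4, 6, 3], num_class)
```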
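The feature-map sizes through `ResNet_50` for a 128×128×1 input can be traced with a few lines of arithmetic. The sketch assumes every strided layer uses `padding='same'`, which maps a side length n to ⌈n/s⌉; the shortcut's 1×1 convolution uses the default `'valid'` padding, which for a 1×1 kernel gives ⌊(n−1)/s⌋+1, the same value. The helper name is hypothetical.

```python
import math


def trace_shapes(input_size=128, num_class=10):
    """Return (name, h, w, c) after the stem and each stage, then pooled and logit sizes."""
    shapes = []
    n = math.ceil(input_size / 2)  # stem Conv2D 5x5, stride 2, padding='same'
    n = math.ceil(n / 2)           # stem MaxPool2D 2x2, stride 2, padding='same'
    shapes.append(("stem", n, n, 64))
    stages = [(64, 1), (128, 2), (256, 2), (512, 2)]  # (filter_num, stride) per stage
    for i, (filter_num, stride) in enumerate(stages, start=1):
        # Only the first block of a stage has this stride; the rest keep the size
        n = math.ceil(n / stride)
        shapes.append((f"layer{i}", n, n, 4 * filter_num))
    shapes.append(("avgpool", 4 * 512))  # GlobalAveragePooling2D keeps only the channels
    shapes.append(("fc", num_class))
    return shapes


for s in trace_shapes():
    print(s)
```

For a 128×128 input this gives 32×32×64 after the stem, 4×4×2048 after `layer4`, and a 2048-dimensional vector after global average pooling, which the final `Dense` layer maps to 10 logits.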
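- First, the raw SAR images of all classes are brought to a uniform size of 128×128 (the appendix code resizes them with `tf.image.resize`), and a power-law (Gamma) transform is applied to enhance detail in dark regions. The Gamma transform is: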
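$$
s = c\,r^{\gamma}
$$
where $r \in [0, 1]$ is the normalized input gray level, $s$ is the output gray level, $c$ is a scaling constant, and $\gamma$ is the exponent. Choosing $\gamma < 1$ stretches the dark gray levels and brings out detail in dark regions; the code uses `tf.image.adjust_gamma(image, 0.6)`, i.e. $c = 1$ and $\gamma = 0.6$.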
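A NumPy sketch of the Gamma transform on an 8-bit image, assuming $c = 1$ and $\gamma = 0.6$ as in the appendix code. It scales gray levels to [0, 1], applies $s = c\,r^{\gamma}$, and maps the result back to 8 bits; `tf.image.adjust_gamma` follows the same scale, power, rescale idea, although its final rounding may differ slightly. The function name is hypothetical.

```python
import numpy as np


def gamma_transform(image_u8, gamma=0.6, c=1.0):
    """Apply s = c * r**gamma to an 8-bit image, with r = gray level / 255."""
    r = image_u8.astype(np.float32) / 255.0  # normalize to [0, 1]
    s = c * np.power(r, gamma)               # power-law transform
    return np.clip(np.round(s * 255.0), 0, 255).astype(np.uint8)  # back to 8 bits
```

Because $\gamma < 1$, dark pixels are lifted more than bright ones (for example, a gray level of 64 moves above 100, while 0 and 255 stay fixed), which is the intended enhancement of dark-region detail.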
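- The CNN model is implemented with TensorFlow 2.2 (GPU). Training uses the adaptive learning-rate optimizer Adam with a learning rate of 0.001 (the training code below passes 0.0001), a batch size of 16, and 50 epochs. TensorBoard is used to monitor training and store the log files. The training code is as follows: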
```python
import datetime

import tensorflow as tf
from tensorflow.keras import optimizers, callbacks

from ResNet import resnet50

# train_db / test_db are the batched tf.data pipelines built by get_datasets() in the appendix
model = resnet50()
model.build(input_shape=(None, 128, 128, 1))
model.summary()
current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
log_dir = 'logs/ResNet50_epoch50_Mstar_' + current_time
tb_callback = callbacks.TensorBoard(log_dir=log_dir)
# from_logits=True because the final Dense layer has no Softmax
model.compile(optimizer=optimizers.Adam(learning_rate=0.0001),
              loss=tf.losses.CategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])
model.fit(train_db, epochs=50, validation_data=test_db, validation_freq=1, callbacks=[tb_callback])
model.evaluate(test_db)
```
- The training loss and accuracy curves are shown below:
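- In the figure, red denotes the training set and blue the validation set (here the 15° test set, passed as `validation_data`). As the number of epochs increases, the loss decreases steadily and the accuracy rises, with no obvious sign of overfitting, so the network trains well and behaves as expected.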
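- The trained CNN model is then run on the MSTAR test set and evaluated with the classification metrics from Section 5.2. The confusion matrix of the ten SAR target classes, computed from the predictions, is shown in the figure below: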
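- The horizontal axis is the predicted class and the vertical axis the true class. Because each row is normalized, the diagonal entries are the per-class recognition rates, i.e. the fraction of each true class that is predicted correctly. The confusion matrix is plotted with Python's Seaborn library. The designed CNN model achieves a high recognition rate for every class.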
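Per-class recognition rates and overall accuracy can be read off a confusion matrix with NumPy. The sketch follows the orientation used above and by `sklearn.metrics.confusion_matrix` (rows are true classes, columns are predicted classes); function names are hypothetical.

```python
import numpy as np


def per_class_recognition_rate(cm):
    """Diagonal of the row-normalized confusion matrix (rows = true class)."""
    cm = np.asarray(cm, dtype=np.float64)
    return np.diag(cm) / cm.sum(axis=1)


def overall_accuracy(cm):
    """Correct predictions (trace) divided by the total number of samples."""
    cm = np.asarray(cm, dtype=np.float64)
    return np.trace(cm) / cm.sum()
```

Overall accuracy weights each class by its number of test samples, so it can differ slightly from the mean of the per-class rates when class sizes are unequal, as they are in this test set (195 to 274 images per class).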
4. Conclusion
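- In this graduation-project experiment, a ResNet was trained on the MSTAR dataset and used for prediction, and the confusion matrix and recognition rates were reported. Next, two traditional feature extraction methods, LBP+SVM and CLBP+SVM, will be applied to the same MSTAR dataset, and the three methods will be compared.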
Appendix (Code)
- The code that builds the ResNet was given above and is not repeated here. The code below reads the dataset, preprocesses it, and trains the model.
```python
import os
import pathlib
import datetime

import numpy as np
import tensorflow as tf
import seaborn as sns
from matplotlib import pyplot as plt
from tensorflow.keras import optimizers, callbacks
from sklearn.metrics import confusion_matrix
from sklearn.model_selection import train_test_split

from ResNet import resnet50


def load_one_from_path_label(path, label):
    '''Read and preprocess a single image; returns a [1, 128, 128, 1] batch and a one-hot label'''
    images = np.zeros((1, 128, 128, 1), dtype=np.float32)
    labels = tf.one_hot(label, depth=10)
    labels = tf.cast(labels, dtype=tf.int32)
    image = tf.io.read_file(path)
    image = tf.image.decode_jpeg(image, channels=1)  # force a single grayscale channel
    image = tf.image.adjust_gamma(image, 0.6)  # Gamma transform
    image = tf.image.resize(image, [128, 128])  # resize to (128, 128)
    image = tf.cast(image, dtype=tf.float32) / 255.0  # normalize to [0, 1]
    images[0, :, :, :] = image
    return images, labels


def load_from_path_label2(all_image_paths, all_image_labels):
    '''Read all images into memory (used for prediction and the confusion matrix)'''
    image_count = len(all_image_paths)
    images = np.zeros((image_count, 128, 128, 1), dtype=np.float32)
    # Integer class indices (not one-hot), as expected by sklearn's confusion_matrix
    labels = np.asarray(all_image_labels, dtype=np.int32)
    for i in range(image_count):
        image = tf.io.read_file(all_image_paths[i])
        image = tf.image.decode_jpeg(image, channels=1)
        image = tf.image.adjust_gamma(image, 0.6)  # Gamma transform
        image = tf.image.resize(image, [128, 128])  # resize to (128, 128)
        image = tf.cast(image, dtype=tf.float32) / 255.0  # normalize to [0, 1]
        images[i, :, :, :] = image
    return images, labels


def load_from_path_label(path, label):
    '''Read one image and one-hot encode its label (used in tf.data map)'''
    image = tf.io.read_file(path)
    image = tf.image.decode_jpeg(image, channels=1)
    label = tf.one_hot(label, depth=10)
    label = tf.cast(label, dtype=tf.int32)
    return image, label


def preprocess(image, label):
    '''Image preprocessing'''
    image = tf.image.adjust_gamma(image, 0.6)  # Gamma transform
    image = tf.image.resize(image, [128, 128])  # resize to (128, 128)
    image = tf.cast(image, dtype=tf.float32) / 255.0  # normalize to [0, 1]
    return image, label


def show_image(db, row, col, title, is_preprocess=True):
    '''Show a row x col grid: row r holds the first `col` images of class r'''
    plt.figure()
    plt.suptitle(title, fontsize=14)
    j = 0
    for image, label in db:
        if j == row * col:
            break
        cls = int(tf.argmax(label))
        if cls == j // col:  # relies on the dataset being ordered by class
            if is_preprocess:
                image = image * 255
            plt.subplot(row, col, j + 1)
            plt.title("class" + str(cls), fontsize=8)
            plt.imshow(image.numpy()[:, :, 0], cmap='gray')  # [h, w, 1] -> [h, w]
            plt.axis('off')
            j = j + 1
    plt.tight_layout()


def list_images(path):
    '''List image paths and integer labels; class indices follow sorted folder names'''
    data_path = pathlib.Path(path)
    # Sorting makes the class -> index mapping identical for TRAIN and TEST on any OS
    label_names = sorted(item.name for item in data_path.iterdir() if item.is_dir())
    label_index = dict((name, index) for index, name in enumerate(label_names))
    all_image_paths = sorted(str(p) for p in data_path.glob('*/*'))
    all_image_labels = [label_index[pathlib.Path(p).parent.name] for p in all_image_paths]
    print(label_index)
    print(label_names)
    print(len(all_image_paths))
    for image, label in zip(all_image_paths[:5], all_image_labels[:5]):
        print(image, ' ---> ', label)
    return all_image_paths, all_image_labels, label_names


def get_datasets(path, train=True):
    '''Build a batched tf.data pipeline plus in-memory images/labels for evaluation'''
    all_image_paths, all_image_labels, label_names = list_images(path)
    images, labels = load_from_path_label2(all_image_paths, all_image_labels)
    # Build the tf.data dataset
    db = tf.data.Dataset.from_tensor_slices((all_image_paths, all_image_labels))
    db = db.map(load_from_path_label)
    prefix = '(Train)' if train else '(Test)'
    show_image(db, 5, 5, prefix + ' Raw SAR Image', False)
    db = db.map(preprocess)
    show_image(db, 5, 5, prefix + ' Preprocessed SAR Image', True)
    db = db.shuffle(1000).batch(16)
    return db, images, labels, label_names


def get_train_valid_datasets(path, train=True):
    '''Split one folder into training/validation sets (80/20); not used by main()'''
    all_image_paths, all_image_labels, _ = list_images(path)
    train_images, valid_images, train_labels, valid_labels = train_test_split(
        all_image_paths, all_image_labels, test_size=0.2, random_state=0)
    print('train counts -----> ', len(train_images))
    print('valid counts -----> ', len(valid_images))
    train_db = tf.data.Dataset.from_tensor_slices((train_images, train_labels)).map(load_from_path_label)
    valid_db = tf.data.Dataset.from_tensor_slices((valid_images, valid_labels)).map(load_from_path_label)
    if train:
        show_image(train_db, 5, 5, '(Train) Raw SAR Image', False)
        show_image(valid_db, 5, 5, '(Valid) Raw SAR Image', False)
    # Preprocess in both cases so the returned datasets always match the model input
    train_db = train_db.map(preprocess)
    valid_db = valid_db.map(preprocess)
    if train:
        show_image(train_db, 5, 5, '(Train) Preprocessed SAR Image', True)
        show_image(valid_db, 5, 5, '(Valid) Preprocessed SAR Image', True)
    train_db = train_db.shuffle(1000).batch(16)
    valid_db = valid_db.shuffle(1000).batch(16)
    return train_db, valid_db


def plot_confusion_matrix(matrix, class_labels, normalize=False):
    '''Plot a confusion matrix as a heatmap (rows = true class, columns = predicted class)'''
    if normalize:
        matrix = matrix.astype('float') / matrix.sum(axis=1)[:, np.newaxis]  # row-normalize
        print(np.around(matrix, decimals=5) * 100)  # per-class rates in percent
        matrix = np.around(matrix, decimals=2)
    sns.set()
    f, ax = plt.subplots()
    tick_marks = np.arange(len(class_labels)) + 0.5  # centers of the heatmap cells
    sns.heatmap(matrix, annot=True, cmap="Blues", ax=ax)  # draw the heatmap
    ax.set_title('confusion matrix')
    plt.xticks(tick_marks, class_labels, rotation=45)
    plt.yticks(tick_marks, class_labels, rotation=0)
    ax.set_xlabel('Predict')
    ax.set_ylabel('True')
    plt.tight_layout()


def main():
    '''Main function'''
    train_db, train_images, train_labels, train_label_names = get_datasets('E:\\SARimage\\TRAIN', True)
    test_db, test_images, test_labels, test_label_names = get_datasets('E:\\SARimage\\TEST', False)
    # ResNet-50
    model = resnet50()
    model.build(input_shape=(None, 128, 128, 1))
    model.summary()
    current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
    log_dir = 'logs/ResNet50_epoch50_Mstar_' + current_time
    tb_callback = callbacks.TensorBoard(log_dir=log_dir)
    model.compile(optimizer=optimizers.Adam(learning_rate=0.0001),
                  loss=tf.losses.CategoricalCrossentropy(from_logits=True),
                  metrics=['accuracy'])
    model.fit(train_db, epochs=50, validation_data=test_db, validation_freq=1, callbacks=[tb_callback])
    model.evaluate(test_db)
    os.makedirs('./checkpoint', exist_ok=True)
    model.save_weights('./checkpoint/ResNet50_epoch50_weights.ckpt')
    print('save weights')
    # Predict on the in-memory test images and build the 10-class confusion matrix
    pred_labels = np.argmax(model.predict(test_images), axis=1)
    con_matrix = confusion_matrix(test_labels, pred_labels, labels=list(range(10)))
    print(con_matrix)
    plot_confusion_matrix(con_matrix, test_label_names, normalize=True)
    np.savetxt('./checkpoint/ResNet50_epoch50_confusion_matrix.txt', con_matrix, fmt='%d')
    plt.show()


if __name__ == '__main__':
    main()
```