PaddleRS:高光谱卫星影像场景分类
使用PaddleRS对天宫二号高光谱图像进行场景分类。
1. 数据准备
数据来自天宫二号遥感图像自然场景分类数据集(NaSC-TG2),AI Studio链接。该数据集由中国科学院空间应用工程与技术中心发布,数据来自中国的天宫二号卫星,相关论文介绍。
天宫二号搭载的宽波段成像仪是新一代宽波段、宽视场和“图谱合一”的光学遥感器,在轨期间获取了海量的高质量对地观测影像。针对当前已开放的对地观测图像场景分类数据集存在的数量有限、场景相似、数据源单一等问题,面向天宫二号宽波段影像,提取自然场景类别,并精心考虑不同区域、不同时相,裁切、选取并标注天宫二号自然场景数据集。 图像格式: RGB三波段真彩色图像(jpg)和14波段的多光谱图像(tif) 场景类别:海滩、圆形农田、云、荒漠、林地、山脉、矩形农田、建筑区、河流和雪山 图像尺寸:128×128 图像数量:每种类别有2000张,共计20000张(jpg和tif均20000张)
In [1]
# 解压数据集
! mkdir -p dataset
! unzip -oq data/data86451/NaSC-TG2.zip -d dataset
In [3]
# 划分数据集
import os
import os.path as osp
import random
def create_datalist(data_folder):
random.seed(666)
clases = os.listdir(data_folder)
if ".ipynb_checkpoints" in clases:
clases.remove(".ipynb_checkpoints")
clases = sorted(clases)
train_list = []
val_list = []
with open(osp.join(data_folder, "label_list.txt"), "w") as lf:
for i_clas, clas in enumerate(clases):
lf.write(str(i_clas) + " " + clas + "\n")
names = os.listdir(osp.join(data_folder, clas))
if ".ipynb_checkpoints" in names:
names.remove(".ipynb_checkpoints")
random.shuffle(names)
for i, name in enumerate(names):
if i % 10 == 0: # 训练集:验证集 = 9:1
val_list.append(osp.join(clas, name) + " " + str(i_clas) + "\n")
else:
train_list.append(osp.join(clas, name) + " " + str(i_clas) + "\n")
with open(osp.join(data_folder, "train_list.txt"), "w") as tf:
random.shuffle(train_list)
for train in train_list:
tf.write(train)
with open(osp.join(data_folder, "val_list.txt"), "w") as vf:
random.shuffle(val_list)
for val in val_list:
vf.write(val)
print("Finished!")
create_datalist("dataset/TIF")
Finished!
2. 包安装
这里主要需要进行两件事:
- 克隆PaddleRS并安装所需要的库
- 安装GDAL
PaddleRS是基于飞桨开发的遥感处理平台,支持遥感图像分类,目标检测,图像分割,以及变化检测等常用遥感任务,帮助开发者更便捷地完成从训练到部署全流程遥感深度学习应用,目前还在开发中,Github链接。
GDAL(Geospatial Data Abstraction Library)是一个在X/MIT许可协议下的开源栅格空间数据转换库。它利用抽象数据模型来表达所支持的各种文件格式。它还有一系列命令行工具来进行数据转换和处理,Github链接。
In [ ]
# 克隆项目
! git clone https://github.com/PaddleCV-SIG/PaddleRS.git
In [ ]
# 安装GDAL(可参考:https://aistudio.baidu.com/aistudio/datasetdetail/136010)
%cd data/data136010
! mv GDAL-3.4.1-cp37-cp37m-manylinux.whl GDAL-3.4.1-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.whl
! pip install GDAL-3.4.1-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.whl
%cd ~
# 安装PaddleRS的所需
! pip install -q -r PaddleRS/requirements.txt
In [1]
import sys
sys.path.append("PaddleRS")
(可选)关闭gdal的警告
这个高光谱数据会使得GDAL一直弹出如下警告:
TIFFRadDivectory:Sum of Photometric type-related color channels and ExtraSsmples doesn'tmatch SamplesPerPixel. Defining non-color channels as ExtraSamples.
但关闭后会使得所有的GDAL相关的报错和警告不再显示。
In [2]
# 关闭gdal的警告
from osgeo import gdal
gdal.PushErrorHandler("CPLQuietErrorHandler")
3. 构建数据集
由于原始数据共有可见光近红外波段共14个,出于速度考虑,我们可以使用BandSelecting
选择其中部分波段进行训练和预测。这里为了显示多光谱的处理能力,选择了4(>3)个波段使用,参考天宫二号宽波段成像仪青藏高原专题数据集中有关天宫二号数据描述部分,选择8、10、12三个通道(可合成的近似真彩色影像),以及与Landset8 OLI对应的近红外波段波长范围相同的4通道组成了4个波段的数据。
In [3]
import os
import os.path as osp
from paddlers.datasets import ClasDataset
from paddlers import transforms as T
# 定义数据增强
train_transforms = T.Compose([
T.BandSelecting([8, 10, 12, 4]), # (天宫二号)近似真彩色 + 近红外
T.Resize(target_size=128),
T.RandomVerticalFlip(),
T.RandomHorizontalFlip(),
T.Normalize(mean=[0.5] * 4, std=[0.5] * 4)
])
eval_transforms = T.Compose([
T.BandSelecting([8, 10, 12, 4]),
T.Resize(target_size=128),
T.Normalize(mean=[0.5] * 4, std=[0.5] * 4)
])
# 定义数据集
data_dir = "dataset/TIF"
train_file_list = osp.join(data_dir, 'train_list.txt')
val_file_list = osp.join(data_dir, 'val_list.txt')
label_file_list = osp.join(data_dir, 'label_list.txt')
train_dataset = ClasDataset(
data_dir=data_dir,
file_list=train_file_list,
label_list=label_file_list,
transforms=train_transforms,
shuffle=True
)
eval_dataset = ClasDataset(
data_dir=data_dir,
file_list=val_file_list,
label_list=label_file_list,
transforms=eval_transforms,
num_workers=0,
shuffle=False
)
2022-04-26 14:53:07 [INFO] 18000 samples in file dataset/TIF/train_list.txt 2022-04-26 14:53:07 [INFO] 2000 samples in file dataset/TIF/val_list.txt
(可选)查看数据
由train_dataset
中读取到的数据,在训练前我们可以确认下数据是否正常。下面将会打印第一个数据的形状(确认是否为4个通道)以及标签。并按照真彩色和标准假彩色进行合成显示。
标准假彩色是将近红外、红、绿三个波段分别显示为红、绿、蓝三种颜色。由于植被的近红外反射率特别高的缘故,在标准假彩色合成的显示中,植被显示为红色。
In [47]
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
data = train_dataset.__getitem__(25)
img = data["image"]
print("data shape: {} and label is {}.".format(img.shape, data["label"]))
rgb = ((img[:, :, :3] * 0.5 + 0.5) * 255).astype("uint8") # 由于进行了normal所以需要denormal
plt.subplot(121);plt.imshow(rgb);plt.title("R-G-B")
nirrg = np.stack([img[:, :, -1], img[:, :, 0], img[:, :, 1]], axis=-1) # 波段组合 NIR-R-G
nirrg = ((nirrg * 0.5 + 0.5) * 255).astype("uint8")
plt.subplot(122);plt.imshow(nirrg);plt.title("NIR-R-G")
plt.show()
data shape: (128, 128, 4) and label is 8.
<Figure size 432x288 with 2 Axes>
4. 训练及评估
由于目前PaddleRS中的分类网络部分采用的PaddleClas的网络,所以在输入通道上都默认使用了3个通道。我们添加了一个CondenseNetV2可通过设置in_channels=4
来接收4个通道的图像输入。
CondenseNetV2是CVPR2021中清华大学提出的网络,Github链接。
针对DenseNet的特征复用冗余,CondenseNet提出利用可学习分组卷积来裁剪掉冗余连接。然而,DenseNet和CondenseNet中特征一旦产生将不再发生任何更改,这就导致了部分特征的潜在价值被严重忽略。CondenseNetV2的想法是与其直接删掉冗余,不妨给冗余特征一个“翻身”机会。因此他们提出一种可学习的稀疏特征重激活的方法,来有选择地更新冗余特征,从而增强特征的复用效率。CondenseNetV2在CondenseNet的基础上引入了稀疏特征重激活,对冗余特征同时进行了裁剪和更新,有效提升了密集连接网络的特征复用效率,在图像分类和检测任务上取得的出色表现。
In [49]
from paddlers.tasks.classifier import CondenseNetV2_b
num_classes = len(train_dataset.labels)
model = CondenseNetV2_b(
in_channels=4,
num_classes=num_classes
)
In [ ]
model.train(
num_epochs=120,
train_dataset=train_dataset,
train_batch_size=1024,
eval_dataset=eval_dataset,
save_interval_epochs=4,
log_interval_steps=5,
pretrain_weights="output/best_model/model.pdparams",
save_dir="output",
learning_rate=3e-4,
early_stop=True,
use_vdl=True
)
In [5]
from paddlers.tasks.load_model import load_model
model = load_model("output/best_model")
2022-04-26 14:53:19 [INFO] Model[CondenseNetV2_b] loaded.
In [ ]
# 评估
model.evaluate(eval_dataset)
OrderedDict([('top1', 0.795), ('top5', 0.993)])
5. 推理
In [6]
# 推理
# PR: https://github.com/PaddleCV-SIG/PaddleRS/pull/75
# 暂未合并,推理时需要修改一下
import matplotlib.pyplot as plt
from osgeo import gdal
from paddlers.transforms.functions import select_bands, to_uint8
gdal.PushErrorHandler("CPLQuietErrorHandler")
%matplotlib inline
img_path = "dataset/TIF/cloud/cloud_0043.tif"
img = gdal.Open(img_path).ReadAsArray().transpose((1, 2, 0))
print(img.shape)
img = to_uint8(select_bands(img, [8, 10, 12])) # 近似真彩色
print(img.shape)
plt.imshow(img)
plt.show()
pred = model.predict(img_path, eval_transforms)
print(pred)
(128, 128, 14) (128, 128, 3)
<Figure size 432x288 with 1 Axes>
{'class_ids': 2, 'scores': 0.91036, 'label_names': 'cloud'}