基于Python的高光谱图像监督分类与非监督分类
- 高光谱数据:点击此处可下载
- 用到的库:matplotlib、scipy、spectral、numpy
- 主要内容:监督分类(最大似然法)与非监督分类(K-means)及精度评定
欢迎有兴趣的朋友交流指点。最后,废话不多说直接上代码😄
import matplotlib.pyplot as plt
from scipy.io import loadmat
import spectral as spy
import numpy as np
def unsupervised_demo(src): # K-means法 迭代方法生成聚类
m, c = spy.kmeans(src, nclusters=6, max_iterations=30) # 分为6类,最大迭代30次
spy.imshow(classes=m) # 显示分类结果
plt.figure()
for i in range(c.shape[0]): # 显示分类后的各光谱曲线
plt.plot(c[i])
plt.pause(60)
def supervised_demo(src,gt):
classes = spy.create_training_classes(src, gt) # 创建训练类集合
gmlc = spy.GaussianClassifier(classes) # 高斯的最大似然分类法
clmap = gmlc.classify_image(src)
spy.imshow(classes=clmap)
gt_results = clmap * (gt != 0) # 为分好类的图像设定一个Mask
gt_right = gt_results * (gt_results == gt)
gt_errors = gt_results * (gt_results != gt)
spy.imshow(classes=gt_right, title="right") # 分类正确的部分
spy.imshow(classes=gt_errors, title="errors") # 分类错误的部分
precision_evaluation(gt_results ,gt) # 精度评定
plt.pause(60)
def precision_evaluation(cla, gt): # 精度评定
def count_number(src): # 统计分类数据
dict_k = {}
for row in range(src.shape[0]):
for col in range(src.shape[1]):
if src[row][col] not in dict_k:
dict_k[src[row][col]] = 0
dict_k[src[row][col]] += 1
dict_k = dict(sorted(dict_k.items()))
del dict_k[0] # 键为0的是未归类的部分,所以去掉
class_sum = sum(dict_k.values())
return dict_k, class_sum
cla_dic, cla_sum = count_number(cla) # 分类后的
gt_dic, gt_sum = count_number(gt) # 真实的
gt_right = cla * (cla == gt)
gt_right_dic, gt_right_sum = count_number(gt_right) # 分类正确的
p0 = gt_right_sum / gt_sum
pe = 0
for gt_key in gt_dic:
if gt_key not in cla_dic:
cla_dic[gt_key] = 0
gt_right_dic[gt_key] = 0
print("类别%s的用户精度为:0.0000,生产者精度为:0.0000" % gt_key)
else:
print("类别%s的用户精度为:%.4f," % (gt_key, gt_right_dic[gt_key] / cla_dic[gt_key]), end='')
print("生产者精度为:%.4f" % (gt_right_dic[gt_key] / gt_dic[gt_key]))
pe += gt_dic[gt_key] * cla_dic[gt_key]
pe = pe / (gt_sum * gt_sum)
kappa = (p0 - pe) / (1 - pe)
overall_accuracy = gt_right_sum / gt_sum
print("-" * 36)
print("Kappa=", kappa)
print("overall_accuracy", overall_accuracy)
# 加载mat格式的数据。loadmat输出的是dict,所以需要进行定位
input_image = loadmat('D:/Hyper/Indian_pines_corrected.mat')['indian_pines_corrected']
# input_image1 = loadmat('D:/Hyper/Salinas_corrected.mat')['salinas_corrected'] # 其他数据同理
gt = loadmat("D:/Hyper/Indian_pines_gt.mat")['indian_pines_gt'] # 加载真实类别
# 可视化影像
view = spy.imshow(data=input_image, bands=[69, 27, 11], figsize=(6, 6))
plt.pause(60)
# 分类
unsupervised_demo(input_image)
supervised_demo(input_image,gt)
可视化结果:
非监督分类结果:
类别光谱结果:
监督分类结果:
分类正确部分:
分类错误部分:
Q:为什么会有椒盐斑点?
A:感觉应该是噪声影响。
分类结果精度:
D:\Software\Anaconda3\python.exe D:/Software/temp/test.py
类别1的用户精度为:0.0000,生产者精度为:0.0000
类别2的用户精度为:0.9008,生产者精度为:0.9475
类别3的用户精度为:0.9878,生产者精度为:0.9783
类别4的用户精度为:1.0000,生产者精度为:1.0000
类别5的用户精度为:0.9979,生产者精度为:0.9979
类别6的用户精度为:0.9455,生产者精度为:0.9973
类别7的用户精度为:0.0000,生产者精度为:0.0000
类别8的用户精度为:0.8918,生产者精度为:1.0000
类别9的用户精度为:0.0000,生产者精度为:0.0000
类别10的用户精度为:0.8642,生产者精度为:0.9949
类别11的用户精度为:0.9815,生产者精度为:0.9059
类别12的用户精度为:0.8938,生产者精度为:0.9933
类别13的用户精度为:1.0000,生产者精度为:1.0000
类别14的用户精度为:0.9953,生产者精度为:0.9960
类别15的用户精度为:0.9948,生产者精度为:0.9896
类别16的用户精度为:0.0000,生产者精度为:0.0000
------------------------------------
Kappa= 0.9409099615877274
overall_accuracy 0.9480924968289589
Process finished with exit code 0