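MeanShift Clustering 02: A Python Example
Original
© Copyright belongs to the author. This is an original work by 51CTO blog author 维格堂406小队; please contact the author for reprint permission, otherwise legal liability will be pursued.
Intro
A hands-on example of using MeanShift.
Loading the data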
from sklearn.cluster import MeanShift, estimate_bandwidth
import matplotlib.pyplot as plt
from itertools import cycle
import numpy as np
import warnings
warnings.filterwarnings("ignore")
# Jupyter notebook magic: render plots inline
%matplotlib inline
from sklearn.datasets import load_iris
import pandas as pd
pd.set_option('display.max_rows', 500)  # maximum number of rows to print
pd.set_option('display.max_columns', 500)  # maximum number of columns to print
# check_array validates the input and converts it to an array if it is not one already
from sklearn.utils import check_array
from sklearn.utils import check_random_state
from sklearn.neighbors import NearestNeighbors
iris_df = pd.DataFrame(
load_iris()["data"],
columns=["sepal_length", "sepal_width", "petal_length", "petal_width"])
iris_df["target"] = load_iris()["target"]
iris_df.head()
   sepal_length  sepal_width  petal_length  petal_width  target
0           5.1          3.5           1.4          0.2       0
1           4.9          3.0           1.4          0.2       0
2           4.7          3.2           1.3          0.2       0
3           4.6          3.1           1.5          0.2       0
4           5.0          3.6           1.4          0.2       0
iris_df.groupby(by="target").describe()
        sepal_length
        count   mean       std  min    25%  50%  75%  max
target
0        50.0  5.006  0.352490  4.3  4.800  5.0  5.2  5.8
1        50.0  5.936  0.516171  4.9  5.600  5.9  6.3  7.0
2        50.0  6.588  0.635880  4.9  6.225  6.5  6.9  7.9

        sepal_width
        count   mean       std  min    25%  50%    75%  max
target
0        50.0  3.428  0.379064  2.3  3.200  3.4  3.675  4.4
1        50.0  2.770  0.313798  2.0  2.525  2.8  3.000  3.4
2        50.0  2.974  0.322497  2.2  2.800  3.0  3.175  3.8

        petal_length
        count   mean       std  min  25%   50%    75%  max
target
0        50.0  1.462  0.173664  1.0  1.4  1.50  1.575  1.9
1        50.0  4.260  0.469911  3.0  4.0  4.35  4.600  5.1
2        50.0  5.552  0.551895  4.5  5.1  5.55  5.875  6.9

        petal_width
        count   mean       std  min  25%  50%  75%  max
target
0        50.0  0.246  0.105386  0.1  0.2  0.2  0.3  0.6
1        50.0  1.326  0.197753  1.0  1.2  1.3  1.5  1.8
2        50.0  2.026  0.274650  1.4  1.8  2.0  2.3  2.5
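The statistics show that petal_length and petal_width differ the most across the three classes, because their class means lie far apart relative to the within-class standard deviations. We therefore use these two features for plotting.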
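One way to put a number on "differ the most" is the one-way ANOVA F ratio per feature, which divides the between-class variance by the within-class variance. This is an added check and not part of the original walkthrough. The sketch computes it purely from the per-class means and standard deviations in the table above, assuming 50 samples per class and sample standard deviations as printed by describe(). The names `summary` and `f_ratio` are illustrative.

```python
# One-way ANOVA F ratio per feature, computed only from the per-class
# (mean, std) values in the describe() table (n = 50 per class).
N_PER_CLASS = 50

# (mean, std) for target 0, 1, 2, copied from the table above
summary = {
    "sepal_length": [(5.006, 0.352490), (5.936, 0.516171), (6.588, 0.635880)],
    "sepal_width":  [(3.428, 0.379064), (2.770, 0.313798), (2.974, 0.322497)],
    "petal_length": [(1.462, 0.173664), (4.260, 0.469911), (5.552, 0.551895)],
    "petal_width":  [(0.246, 0.105386), (1.326, 0.197753), (2.026, 0.274650)],
}

def f_ratio(groups, n=N_PER_CLASS):
    """Between-class mean square divided by within-class mean square."""
    k = len(groups)
    grand_mean = sum(m for m, _ in groups) / k  # valid because group sizes are equal
    ss_between = sum(n * (m - grand_mean) ** 2 for m, _ in groups)
    ss_within = sum((n - 1) * s ** 2 for _, s in groups)  # s is the sample std (ddof=1)
    return (ss_between / (k - 1)) / (ss_within / (k * n - k))

scores = {name: f_ratio(groups) for name, groups in summary.items()}
for name, score in sorted(scores.items(), key=lambda kv: -kv[1]):
    print(f"{name:>12}: F = {score:8.1f}")
```

With these numbers, petal_length and petal_width score far above the two sepal features, which is consistent with plotting the petal features.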
# colors = cycle('bgrcmykbgrcmykbgrcmykbgrcmyk')
colors = ["red", "yellow", "blue"]
for k, col in zip(range(3), colors):
    sub_data = iris_df.query("target==%s" % k)
    plt.plot(sub_data.petal_length, sub_data.petal_width, "o", markerfacecolor=col,
             markeredgecolor='k', markersize=5, label="target=%d" % k)
plt.xlabel("petal_length")
plt.ylabel("petal_width")
plt.legend()
plt.show()
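The red points (target 0) sit far from the rest, while some of the yellow and blue points overlap.
Clustering with MeanShift (bandwidth = 0.726, bin_seeding=True)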
# ms = MeanShift(bin_seeding=True, cluster_all=False)
bandwidth = 0.726
ms = MeanShift(bandwidth=bandwidth, bin_seeding=True)
ms.fit(iris_df[["petal_length", "petal_width"]])
labels = ms.labels_
cluster_centers = ms.cluster_centers_
labels_unique = np.unique(labels)
n_clusters_ = len(labels_unique)
print("number of estimated clusters : %d" % n_clusters_)
# #############################################################################
# Plot result
plt.figure(1)
plt.clf()
# colors = cycle('bgrcmykbgrcmykbgrcmykbgrcmyk')
colors = ["yellow", "red", "blue"]
for k, col in zip(range(n_clusters_), colors):
    my_members = labels == k
    cluster_center = cluster_centers[k]
    # members of cluster k as small dots
    plt.plot(iris_df[my_members].petal_length,
             iris_df[my_members].petal_width,
             ".",
             markerfacecolor=col,
             markeredgecolor='k',
             markersize=6)
    # the cluster center as a large circle marker
    plt.plot(cluster_center[0],
             cluster_center[1],
             'o',
             markerfacecolor=col,
             markeredgecolor='k',
             markersize=14)
    # outline of radius = bandwidth around the center (the flat-kernel window)
    circle = plt.Circle((cluster_center[0], cluster_center[1]),
                        bandwidth,
                        color='black',
                        fill=False)
    plt.gcf().gca().add_artist(circle)
plt.title('Estimated number of clusters: %d' % n_clusters_)
plt.show()
number of estimated clusters : 3
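In the plot, the red cluster stands on its own and is clustered well. The yellow and blue clusters overlap, and every point, including those in the overlap, is labeled with its nearest cluster center.
The estimate_bandwidth function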
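The shifting and the nearest-center labeling can be sketched as follows. sklearn documents MeanShift as using a flat kernel, and this sketch does too: each seed repeatedly moves to the mean of all points within `bandwidth` of it. Unlike sklearn, it seeds from every point (no bin_seeding) and merges converged centers naively in seed order instead of ranking them by how many points they cover. Treat `mean_shift_flat` (a hypothetical name) as an illustration, not as sklearn's implementation.

```python
import numpy as np

def mean_shift_flat(X, bandwidth, max_iter=300, tol=1e-3):
    """Simplified flat-kernel mean shift (sketch only, not sklearn's code)."""
    X = np.asarray(X, dtype=float)
    modes = []
    for seed in X:  # assumption: every point is a seed
        center = seed.copy()
        for _ in range(max_iter):
            # flat kernel: all points within `bandwidth` get equal weight
            in_window = np.linalg.norm(X - center, axis=1) <= bandwidth
            new_center = X[in_window].mean(axis=0)
            shift = np.linalg.norm(new_center - center)
            center = new_center
            if shift < tol:
                break
        modes.append(center)
    # naive merge: keep a mode only if no kept mode is within `bandwidth`
    centers = []
    for m in modes:
        if all(np.linalg.norm(m - c) > bandwidth for c in centers):
            centers.append(m)
    centers = np.array(centers)
    # label every point with its nearest center (like cluster_all=True)
    dists = np.linalg.norm(X[:, None, :] - centers[None, :, :], axis=2)
    labels = dists.argmin(axis=1)
    return centers, labels

# Toy usage: two well-separated blobs
X_toy = np.array([[0, 0], [0.1, 0], [0, 0.1], [5, 5], [5.1, 5], [5, 5.1]])
centers, labels = mean_shift_flat(X_toy, bandwidth=1.0)
print(centers)
print(labels)  # → [0 0 0 1 1 1]
```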
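estimate_bandwidth derives a suggested bandwidth from the raw data to be clustered. The basic logic:
- Optionally take a random sample of the points (parameter n_samples; by default all points are used)
- For each point, find the distance to its k-th nearest neighbor among those points, where k = int(number of points × quantile) and quantile defaults to 0.3; the point itself counts as its own first neighbor, so this is the largest distance among its k nearest neighbors
- Average these distances
Intuitively, it looks for a distance large enough that a ball of that radius around a typical point covers about a quantile-sized share (30% by default) of the data.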
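The steps above can be written out in a few lines of NumPy. This is a simplified sketch. It assumes Euclidean distance and computes brute-force pairwise distances instead of running the NearestNeighbors search that sklearn uses. `estimate_bandwidth_sketch` is a hypothetical name. Its optional sampling uses NumPy's `default_rng`, so a sampled result will not reproduce sklearn's `random_state` draw.

```python
import numpy as np

def estimate_bandwidth_sketch(X, quantile=0.3, n_samples=None, seed=0):
    """Sketch of the estimate_bandwidth logic (hypothetical helper, not sklearn's code).

    Returns the mean distance from each (optionally sampled) point to its
    k-th nearest neighbor, with k = int(len(X) * quantile) and the point
    itself counted as its own first neighbor.
    """
    X = np.asarray(X, dtype=float)
    if n_samples is not None:
        # Assumption: NumPy's default_rng for sampling; sklearn's draw differs
        rng = np.random.default_rng(seed)
        X = X[rng.permutation(len(X))[:n_samples]]
    k = max(1, int(len(X) * quantile))
    # Brute-force pairwise Euclidean distances (fine for ~150 points)
    dists = np.linalg.norm(X[:, None, :] - X[None, :, :], axis=2)
    # After sorting each row, column 0 holds the self-distance 0,
    # so the k-th nearest neighbor sits in column k - 1
    kth_dist = np.sort(dists, axis=1)[:, k - 1]
    return kth_dist.mean()

# Five points on a line; quantile=0.6 gives k = 3
X_line = np.array([[0.0], [1.0], [2.0], [3.0], [4.0]])
print(estimate_bandwidth_sketch(X_line, quantile=0.6))  # → 1.4
```

On the iris data this would be called as `estimate_bandwidth_sketch(iris_df[["petal_length", "petal_width"]].to_numpy())`. With the default n_samples=None it should agree with sklearn's estimate_bandwidth on the same two columns.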
estimate_bandwidth(iris_df[["petal_length", "petal_width"]])
To check this, compute every pairwise distance with NearestNeighbors (already imported at the top):
nbrs = NearestNeighbors(n_neighbors=len(iris_df), n_jobs=-1)
nbrs.fit(iris_df.iloc[:,[2,3]])
NearestNeighbors(algorithm='auto', leaf_size=30, metric='minkowski',
metric_params=None, n_jobs=-1, n_neighbors=150, p=2,
radius=1.0)
d, index = nbrs.kneighbors(iris_df.iloc[:,[2,3]],return_distance=True)
from scipy import stats
# Drop column 0 (each point's distance to itself) and flatten: 150 * 149 = 22350 distances
total_distance = d[:, 1:].ravel()
stats.describe(total_distance)
DescribeResult(nobs=22350, minmax=(0.0, 6.262587324740471), mean=2.185682454621745, variance=2.6174775533104904, skewness=0.3422940721262964, kurtosis=-1.1637573960810108)
pd.DataFrame({"total_distance":total_distance}).describe()
      total_distance
count   22350.000000
mean        2.185682
std         1.617862
min         0.000000
25%         0.640312
50%         1.941649
75%         3.544009
max         6.262587
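These numbers put the bandwidth suggested by estimate_bandwidth fairly close to the 25% quantile of the pairwise distances (0.640312). That is roughly what one would expect. Each point contributes its distance to the neighbor at about the 30% position of its own sorted distance list, so the average should land in the lower part of the overall distance distribution.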
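To see exactly where a given bandwidth falls in this distribution, you can compute its percentile rank among the pairwise distances directly. `distance_percentile` is a hypothetical helper. It uses brute-force Euclidean distances and drops each point's distance to itself, so for 150 points it sees the same 22350 distances as the check above, with both (i, j) and (j, i) counted.

```python
import numpy as np

def distance_percentile(X, value):
    """Share of pairwise distances (self-distances excluded) that are <= value.

    Hypothetical helper: brute-force Euclidean distances; both (i, j) and
    (j, i) are counted, i.e. n * (n - 1) distances for n points.
    """
    X = np.asarray(X, dtype=float)
    dists = np.linalg.norm(X[:, None, :] - X[None, :, :], axis=2)
    not_self = ~np.eye(len(X), dtype=bool)  # mask out each point's distance to itself
    return float(np.mean(dists[not_self] <= value))

X_pts = np.array([[0.0], [1.0], [2.0], [3.0], [4.0]])
# 8 of the 20 ordered pairs are exactly 1 apart, none closer
print(distance_percentile(X_pts, 1.0))  # → 0.4
```

For example, `distance_percentile(iris_df[["petal_length", "petal_width"]].to_numpy(), 0.726)` would report the share of distances at or below the bandwidth used in the clustering step.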
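That concludes this brief introduction to MeanShift. In some business scenarios the algorithm works quite well, but each problem still needs to be analyzed on its own terms.
2021-03-31, Jiulong Lake, Jiangning District, Nanjing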