Intro

A hands-on example of using MeanShift for clustering.

Loading the data

from sklearn.cluster import MeanShift, estimate_bandwidth
import matplotlib.pyplot as plt
from itertools import cycle
import numpy as np
import warnings
warnings.filterwarnings("ignore")
%matplotlib inline
from sklearn.datasets import load_iris
import pandas as pd
pd.set_option('display.max_rows', 500)     # max rows to display
pd.set_option('display.max_columns', 500)  # max columns to display
# check_array validates the input and converts it to an ndarray if necessary
from sklearn.utils import check_array
from sklearn.utils import check_random_state
from sklearn.neighbors import NearestNeighbors

iris_df = pd.DataFrame(
    load_iris()["data"],
    columns=["sepal_length", "sepal_width", "petal_length", "petal_width"])
iris_df["target"] = load_iris()["target"]
iris_df.head()



   sepal_length  sepal_width  petal_length  petal_width  target
0           5.1          3.5           1.4          0.2       0
1           4.9          3.0           1.4          0.2       0
2           4.7          3.2           1.3          0.2       0
3           4.6          3.1           1.5          0.2       0
4           5.0          3.6           1.4          0.2       0

iris_df.groupby(by="target").describe()



target  feature       count   mean      std       min    25%    50%    75%    max
0       sepal_length   50.0  5.006  0.352490   4.3  4.800   5.00  5.200   5.8
0       sepal_width    50.0  3.428  0.379064   2.3  3.200   3.40  3.675   4.4
0       petal_length   50.0  1.462  0.173664   1.0  1.400   1.50  1.575   1.9
0       petal_width    50.0  0.246  0.105386   0.1  0.200   0.20  0.300   0.6
1       sepal_length   50.0  5.936  0.516171   4.9  5.600   5.90  6.300   7.0
1       sepal_width    50.0  2.770  0.313798   2.0  2.525   2.80  3.000   3.4
1       petal_length   50.0  4.260  0.469911   3.0  4.000   4.35  4.600   5.1
1       petal_width    50.0  1.326  0.197753   1.0  1.200   1.30  1.500   1.8
2       sepal_length   50.0  6.588  0.635880   4.9  6.225   6.50  6.900   7.9
2       sepal_width    50.0  2.974  0.322497   2.2  2.800   3.00  3.175   3.8
2       petal_length   50.0  5.552  0.551895   4.5  5.100   5.55  5.875   6.9
2       petal_width    50.0  2.026  0.274650   1.4  1.800   2.00  2.300   2.5

From these statistics, petal_length and petal_width show the largest differences across the three species, so we use them for plotting.
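A rough way to put a number on this (just a sanity-check sketch, not part of the clustering itself) is to compare, for each feature, the spread of the per-class means with the average within-class standard deviation; the petal columns should come out clearly ahead:

# Rough per-feature separation score: spread of the class means divided by the
# average within-class standard deviation (larger = classes easier to separate).
grouped = iris_df.groupby("target")
separation = (grouped.mean().max() - grouped.mean().min()) / grouped.std().mean()
print(separation.sort_values(ascending=False))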

# colors = cycle('bgrcmykbgrcmykbgrcmykbgrcmyk')
colors = ["red", "yellow", "blue"]
marker = ["o", "*", "+"]
# plot each species (target 0, 1, 2) in its own color
for k, col, mark in zip(range(3), colors, marker):
    sub_data = iris_df.query("target==%s" % k)
    plt.plot(sub_data.petal_length, sub_data.petal_width, "o",
             markerfacecolor=col, markeredgecolor='k', markersize=5)
plt.show()

[Figure: scatter of petal_length vs petal_width, one color per iris species]

The red points are clearly separated from the rest, while some of the blue and yellow points overlap.

Clustering with default parameters

# ms = MeanShift(bin_seeding=True, cluster_all=False)
bandwidth = 0.726  # roughly the value estimate_bandwidth suggests (see below)
ms = MeanShift(bandwidth=bandwidth, bin_seeding=True)
ms.fit(iris_df[["petal_length", "petal_width"]])
labels = ms.labels_
cluster_centers = ms.cluster_centers_
labels_unique = np.unique(labels)
n_clusters_ = len(labels_unique)

print("number of estimated clusters : %d" % n_clusters_)

# #############################################################################
# Plot result
import matplotlib.pyplot as plt
from itertools import cycle

plt.figure(1)
plt.clf()

# colors = cycle('bgrcmykbgrcmykbgrcmykbgrcmyk')
colors = ["yellow", "red", "blue"]
marker = ["o", "*", "+"]
for k, col, mark in zip(range(n_clusters_), colors, marker):
    my_members = labels == k
    cluster_center = cluster_centers[k]
    # points assigned to cluster k
    plt.plot(iris_df[my_members].petal_length,
             iris_df[my_members].petal_width,
             ".",
             markerfacecolor=col,
             markeredgecolor='k',
             markersize=6)
    # the cluster center
    plt.plot(cluster_center[0],
             cluster_center[1],
             'o',
             markerfacecolor=col,
             markeredgecolor='k',
             markersize=14)
    # circle of radius bandwidth around the center
    circle = plt.Circle((cluster_center[0], cluster_center[1]),
                        bandwidth,
                        color='black',
                        fill=False)
    plt.gcf().gca().add_artist(circle)
plt.title('Estimated number of clusters: %d' % n_clusters_)
plt.show()
number of estimated clusters : 3

[Figure: MeanShift clustering result; cluster centers marked with large dots and circles of radius bandwidth]

In the plot, the red cluster stands cleanly on its own and is recovered well, while the blue and yellow clusters overlap; points in the overlap region are labelled by the nearest cluster center.
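For a fitted model, MeanShift.predict assigns a point to the nearest learned cluster center, which is exactly this nearest-center rule. A minimal sketch (the two query points below are made up purely for illustration):

# predict() labels each point with the index of the nearest cluster center,
# which is how the overlapping blue/yellow points get their labels.
import numpy as np

query_points = np.array([[1.5, 0.3],   # well inside the isolated (red) cluster
                         [4.9, 1.6]])  # in the blue/yellow overlap region
print(ms.predict(query_points))        # cluster indices of the nearest centers
print(ms.cluster_centers_)             # the centers themselves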

The estimate_bandwidth method

Given the raw data to be clustered, it suggests a bandwidth. Roughly, its logic is:

  • draw a sample of the points (by default all of them are used)
  • for each sampled point, find its k nearest neighbors, where k is a quantile (0.3 by default) of the number of points, and take the distance to the farthest of those neighbors
  • average those distances

Seen this way, it picks a fairly generous radius, large enough for each point's neighborhood to cover a good share of the data. A minimal sketch of this logic follows.
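The sketch below re-creates that logic by hand on the iris features used above; it mirrors what scikit-learn's estimate_bandwidth does with its default quantile=0.3, but it is an illustrative re-implementation rather than the library's own code:

# Hand-rolled approximation of estimate_bandwidth with the default quantile=0.3:
# for every point, take the distance to its k-th nearest neighbor (k = 30% of
# the dataset), then average those distances over all points.
import numpy as np
from sklearn.neighbors import NearestNeighbors

X = iris_df[["petal_length", "petal_width"]].to_numpy()
k = int(len(X) * 0.3)                       # quantile = 0.3 by default

nn = NearestNeighbors(n_neighbors=k).fit(X)
dist, _ = nn.kneighbors(X)                  # distances to each point's k nearest neighbors
approx_bandwidth = dist.max(axis=1).mean()  # farthest of the k, averaged over points
print(approx_bandwidth)                     # should be close to the 0.7266 reported below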

estimate_bandwidth(iris_df[["petal_length", "petal_width"]])
0.7266371274126329

Double-checking by computing the distances by hand

from sklearn.neighbors import NearestNeighbors
nbrs = NearestNeighbors(n_neighbors=len(iris_df), n_jobs=-1)
nbrs.fit(iris_df.iloc[:, [2, 3]])
NearestNeighbors(algorithm='auto', leaf_size=30, metric='minkowski',
                 metric_params=None, n_jobs=-1, n_neighbors=150, p=2,
                 radius=1.0)
d, index = nbrs.kneighbors(iris_df.iloc[:, [2, 3]], return_distance=True)
from functools import reduce  # python 3
# drop column 0 (each point's zero distance to itself) and flatten the remaining
# pairwise distances into one long list
total_distance = reduce(lambda x, y: x + y,
                        np.array(pd.DataFrame(d).iloc[:, 1:150]).tolist())
from scipy import stats
stats.describe(total_distance)
DescribeResult(nobs=22350, minmax=(0.0, 6.262587324740471), mean=2.185682454621745, variance=2.6174775533104904, skewness=0.3422940721262964, kurtosis=-1.1637573960810108)
pd.DataFrame({"total_distance":total_distance}).describe()



       total_distance
count    22350.000000
mean         2.185682
std          1.617862
min          0.000000
25%          0.640312
50%          1.941649
75%          3.544009
max          6.262587

Numerically, the suggested bandwidth (about 0.727) comes out fairly close to the 25% quantile of these pairwise distances (about 0.64).
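As one last quick check (just an illustration reusing the variables computed above), we can put the suggested bandwidth next to a few quantiles of the pairwise-distance distribution:

# Where does the suggested bandwidth sit among the pairwise distances?
import numpy as np

print(np.quantile(total_distance, [0.25, 0.30, 0.50]))               # reference quantiles
print(estimate_bandwidth(iris_df[["petal_length", "petal_width"]]))  # suggested bandwidth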

That wraps up this brief introduction to MeanShift. In some business scenarios the algorithm works very well, but as always, it has to be judged against the specific problem.

                                2021-03-31, Jiulong Lake, Jiangning District, Nanjing