Intro
工作中用到了meanshift,不追溯太复杂的原理以及各种算法变体,原始paper等等。只从概念上,对原理做简要的总结和介绍。主要逻辑,参考sklearn的源码。
和常用k-means一样,meanshift也是一个迭代算法。我们关注的无非以下几点:
- 迭代的逻辑,按照什么方式进行迭代
- 迭代终止的条件
- 怎么给样本打label
主要逻辑
sklearn的代码逻辑如下:
- 初始化:生成bandwidth和seeds。如果没有指定bandwidth和seeds,会根据样本生成
- 并行化完成所有seeds的迭代,每个seed相当于是一个单独的迭代过程
- 迭代终止条件:max_iter=300或者阈值小于1e-3 * bandwidth
- 首次迭代时,seed就是聚类中心点,获取中心点bandwidth范围内的所有样本点
- 重新计算中心点(即求均值),再执行上一步,直到满足迭代停止条件
- 聚类中心点合并
- 按照每个中心点覆盖样本数排序,依次计算各个中心点之间的距离
- 在bandwidth以内的两个中心点,保留覆盖样本数高的中心点
- 给样本打label
- 找到每个样本附近最近的中心点,则改样本就归属于该类
- cluster_all=False,则bandwidth之外的点命名为-1,否则还是最近的聚类中心
Ref
[1] https://scikit-learn.org/stable/modules/generated/sklearn.cluster.MeanShift.html
2021-03-29 于南京市江宁区九龙湖