今天把自己写的一个机器学习算法库中的K-means算法整理了一下。这个算法与其他算法相比相对独立,可以单独贴出来,不会引用太多其他的类(仍有少量引用,但都是些简单的功能,看类名就知道是什么意思)。
基本功能和规则为:
1.当然是进行k-means算法,对数据集(这里使用二维数组来表示数据集,行数为数据总数,列数为数据维度)进行N维聚类
2.可以指定收敛的阈值(convergenceDis默认为0.0001)
3.为避免陷入局部最优,可以通过设定replicates的数值来指定重复运行次数;默认为0(取值不大于1时),即只运行一次聚类过程
4.测试数据格式为每一行代表一个输入,用空格分隔输入的各个维度,为了计算结果不出太大意外,建议对原始数据进行归一化
首先上骨架代码:
View Code
package org.tadoo.ml.cluster.kmeans;
import
java.util.Random;
import
org.tadoo.ml.exception.ClusterException;
import
org.tadoo.ml.util.ArrayCompute;
import
org.tadoo.ml.util.Utils;
/**
* 使用K-means方法进行聚类
*
* <p>time:2011-6-1</p>
* @author
T. QIN
*/
public
class
KmeansCluster
{
private
double
[][] dataSet
=
null
;
private
int
k
=
0
;
private
double
[][] centers
=
null
;
private
double
totalSumOfdistances
=
0
;
private
boolean
convergence
=
false
;
private
int
iter;
private
double
convergenceDis
=
0.0001
;
private
int
replicates
=
0
;
private
KMCResult[] kmcresults
=
null
;
public
KmeansCluster(
double
[][] x,
int
k)
throws
ClusterException
{
if
(x
==
null
||
x.length
==
0
)
{
throw
new
ClusterException(
"
输入数据不可为空。
"
);
}
this
.dataSet
=
x;
this
.k
=
k;
this
.centers
=
new
double
[k][dataSet[
0
].length];
}
private
void
initKCenters()
{
Random r =
new
Random();
int
rn
=
r.nextInt(dataSet.length);
for
(
int
i
=
0
; i
<
this
.k; i
++
)
//
初始化k个中心
{
for
(
int
j
=
0
; j
<
dataSet[
0
].length; j
++
)
{
centers[i][j] =
dataSet[rn][j];
}
rn =
r.nextInt(dataSet.length);
}
}
public
void
train()
{
if
(replicates
>
1
)
{
kmcresults =
new
KMCResult[replicates];
for
(
int
i
=
0
; i
<
replicates; i
++
)
{
beginTrain();
kmcresults[i] =
new
KMCResult();
kmcresults[i].centers =
this
.centers;
kmcresults[i].sum =
this
.totalSumOfdistances;
kmcresults[i].iters =
this
.iter;
this
.centers
=
new
double
[k][dataSet[
0
].length];
}
}
else
{
beginTrain();
}
}
private
void
beginTrain()
{
int
rows
=
dataSet.length;
int
cols
=
dataSet[
0
].length;
int
[] c
=
new
int
[rows];
//
保存每个数据属于哪个中心
int
vote
=
0
;
//
如果某一中心收敛,则投票数可加一
iter
=
0
;
initKCenters();
convergence =
false
;
while
(
!
convergence)
{
double
minDistance
=
Double.MAX_VALUE;
double
currentDis
=
0.0
;
int
count
=
0
;
int
changedCenterNumber
=
0
;
double
[] temp
=
new
double
[cols];
totalSumOfdistances =
0
;
for
(
int
i
=
0
; i
<
rows; i
++
)
{
for
(
int
j
=
0
; j
<
this
.k; j
++
)
{
currentDis =
Utils.distance(dataSet[i], centers[j]);
if
(currentDis
<
minDistance)
{
minDistance =
currentDis;
c[i] =
j;
}
}
totalSumOfdistances +=
minDistance;
minDistance =
Double.MAX_VALUE;
}
for
(
int
i
=
0
; i
<
this
.k; i
++
)
{
for
(
int
j
=
0
; j
<
c.length; j
++
)
{
if
(c[j]
==
i)
{
temp =
Utils.add(temp, dataSet[j]);
count ++
;
}
}
if
(count
!=
0
)
{
temp =
ArrayCompute.devideC(temp, count);
if
(isCenterConvergence(centers[i], temp))
{
vote ++
;
}
centers[i] =
temp;
changedCenterNumber ++
;
}
count =
0
;
temp =
new
double
[cols];
}
iter ++
;
if
(vote
==
changedCenterNumber)
{
convergence =
true
;
}
vote =
0
;
changedCenterNumber =
0
;
}
}
/**
* 判断某中心是否收敛
*
* @param
center
* @param
pCenter
* @return
* @see
:
*/
private
boolean
isCenterConvergence(
double
[] center,
double
[] pCenter)
{
boolean
result
=
true
;
double
[] distance
=
ArrayCompute.minus(center, pCenter);
for
(
int
i
=
0
; i
<
distance.length; i
++
)
{
if
(Math.abs(distance[i])
>
convergenceDis)
{
result =
false
;
}
}
return
result;
}
/**
* dataSet的 get() 方法
* @return
double[][] dataSet.
*/
public
double
[][] getDataSet()
{
return
dataSet;
}
/**
* dataSet的 set() 方法
* @param
dataSet The dataSet to set.
*/
public
void
setDataSet(
double
[][] dataSet)
{
this
.dataSet
=
dataSet;
}
/**
* k的 get() 方法
* @return
int k.
*/
public
int
getK()
{
return
k;
}
/**
* k的 set() 方法
* @param
k The k to set.
*/
public
void
setK(
int
k)
{
this
.k
=
k;
}
/**
* centers的 get() 方法
* @return
double[][] centers.
*/
public
double
[][] getCenters()
{
return
centers;
}
/**
* centers的 set() 方法
* @param
centers The centers to set.
*/
public
void
setCenters(
double
[][] centers)
{
this
.centers
=
centers;
}
/**
* totalSumOfdistances的 get() 方法
* @return
double totalSumOfdistances.
*/
public
double
getTotalSumOfdistances()
{
return
totalSumOfdistances;
}
/**
* totalSumOfdistances的 set() 方法
* @param
totalSumOfdistances The totalSumOfdistances to set.
*/
public
void
setTotalSumOfdistances(
double
totalSumOfdistances)
{
this
.totalSumOfdistances
=
totalSumOfdistances;
}
/**
* iter的 get() 方法
* @return
int iter.
*/
public
int
getIter()
{
return
iter;
}
/**
* convergenceDis的 get() 方法
* @return
double convergenceDis.
*/
public
double
getConvergenceDis()
{
return
convergenceDis;
}
/**
* convergenceDis的 set() 方法
* @param
convergenceDis The convergenceDis to set.
*/
public
void
setConvergenceDis(
double
convergenceDis)
{
this
.convergenceDis
=
convergenceDis;
}
/**
* replicates的 get() 方法
* @return
int replicates.
*/
public
int
getReplicates()
{
return
replicates;
}
/**
* replicates的 set() 方法
* @param
replicates The replicates to set.
*/
public
void
setReplicates(
int
replicates)
{
this
.replicates
=
replicates;
}
/**
* kmcresults的 get() 方法
* @return
KMCResult[] kmcresults.
*/
public
KMCResult[] getKmcresults()
{
return
kmcresults;
}
/**
* kmcresults的 set() 方法
* @param
kmcresults The kmcresults to set.
*/
public
void
setKmcresults(KMCResult[] kmcresults)
{
this
.kmcresults
=
kmcresults;
}
/**
* 聚类运行的结果
*
* <p>time:2011-6-2</p>
* @author
T. QIN
*/
public
class
KMCResult
{
public
double
[][] centers;
public
double
sum;
public
int
iters;
}
}
然后相关类:
View Code
package org.tadoo.ml.exception;
/**
 * Unchecked exception raised when clustering cannot be set up or carried out.
 *
 * <p>time:2011-5-25</p>
 * @author T. QIN
 */
public class ClusterException extends RuntimeException
{
    /** Fixed serial version so the serialized form stays stable across recompiles. */
    private static final long serialVersionUID = 1L;

    /**
     * Creates an exception without a detail message.
     */
    public ClusterException()
    {
        super();
    }

    /**
     * Creates an exception with a detail message.
     *
     * @param s the detail message
     */
    public ClusterException(String s)
    {
        super(s);
    }

    /**
     * Creates an exception with a detail message and the underlying cause,
     * so the original stack trace is not lost when wrapping.
     *
     * @param s the detail message
     * @param cause the exception that triggered this one
     */
    public ClusterException(String s, Throwable cause)
    {
        super(s, cause);
    }

    /**
     * Creates an exception that wraps an underlying cause.
     *
     * @param cause the exception that triggered this one
     */
    public ClusterException(Throwable cause)
    {
        super(cause);
    }
}
View Code
package org.tadoo.ml.util;
/**
* 简单数组计算
*
* <p>time:2011-5-27</p>
* @author T. QIN
*/
public class
ArrayCompute
{
/**
* 数组相加
*
* @param x1
* @param x2
* @return
* @see :
*/
public static
double
[] add(
final
double
[] x1,
final
double
[] x2)
{
if (x1.length !=
x2.length)
{
System.err.print( " 向量长度不等不能相加! "
);
System.exit( 0 );
}
double [] result =
new
double
[x1.length];
for ( int
i
=
0
; i
<
result.length; i
++
)
{
result[i] = x1[i] +
x2[i];
}
return result;
}
/**
* 数组相减
*
* @param x1
* @param x2
* @return
* @see :
*/
public static
double
[] minus(
final
double
[] x1,
final
double
[] x2)
{
if (x1.length !=
x2.length)
{
System.err.print( " 向量长度不等不能相减! "
);
System.exit( 0 );
}
double [] result =
new
double
[x1.length];
for ( int
i
=
0
; i
<
result.length; i
++
)
{
result[i] = x1[i] -
x2[i];
}
return result;
}
/**
* 数组乘以一个常数
*
* @param x1
* @param c
* @return
* @see :
*/
public static
double
[] multiplyC(
final
double
[] x1,
final
double
c)
{
double [] ret =
new
double
[x1.length];
for ( int
i
=
0
; i
<
x1.length; i
++
)
{
ret[i] = x1[i] *
c;
}
return ret;
}
/**
* 数组除以一个常数
*
* @param x1
* @param c
* @return
* @see :
*/
public static
double
[] devideC(
final
double
[] x1,
final
double
c)
{
double [] ret =
new
double
[x1.length];
for ( int
i
=
0
; i
<
x1.length; i
++
)
{
ret[i] = x1[i] /
c;
}
return ret;
}
}
View Code
package org.tadoo.ml.util;
import java.io.BufferedReader;
import java.io.File;
import java.io.FileNotFoundException;
import java.io.FileReader;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.StringTokenizer;
import org.tadoo.ml.Matrix;
/**
 * Small vector helpers shared by the learning algorithms.
 *
 * <p>time:2011-3-28</p>
 * @author T. QIN
 */
public class Utils
{
    /**
     * Computes the Euclidean distance between two points.
     *
     * @param x1 first point
     * @param x2 second point, same dimension as {@code x1}
     * @return the square root of the summed squared coordinate differences
     * @throws IllegalArgumentException if the dimensions differ
     */
    public static double distance(double[] x1, double[] x2)
    {
        if (x1.length != x2.length)
        {
            // Previously the extra coordinates of a longer x2 were silently ignored.
            throw new IllegalArgumentException(
                    "dimension mismatch: " + x1.length + " != " + x2.length);
        }
        double r = 0.0;
        for (int i = 0; i < x1.length; i++)
        {
            double diff = x1[i] - x2[i];
            r += diff * diff;
        }
        return Math.sqrt(r);
    }

    /**
     * Adds two arrays element by element.
     *
     * @param x1 first operand
     * @param x2 second operand, same length as {@code x1}
     * @return a new array holding {@code x1[i] + x2[i]}
     * @throws IllegalArgumentException if the lengths differ
     */
    public static double[] add(final double[] x1, final double[] x2)
    {
        if (x1.length != x2.length)
        {
            // A library must not terminate the JVM; report the misuse to the caller.
            throw new IllegalArgumentException(
                    "向量长度不等不能相加! (" + x1.length + " != " + x2.length + ")");
        }
        double[] result = new double[x1.length];
        for (int i = 0; i < result.length; i++)
        {
            result[i] = x1[i] + x2[i];
        }
        return result;
    }
}
Matrix
package org.tadoo.ml;
import java.io.PrintStream;
import org.tadoo.ml.exception.MatrixComputeException;
/**
* 矩阵结构
*
* <p>time:2011-3-23</p>
* @author T. QIN
*/
public class Matrix
{
private int rowNum;
private int colNum;
private double value[][];
/**
* 构造器方法
*
* @param rows 行数
* @param cols 列数
* @see :
* @author : T. QIN
*/
public Matrix( int rows, int
cols)
{
this .rowNum = rows;
this .colNum = cols;
this .value = new
double
[rows][cols];
}
/**
* 构造器方法
*
* @param rows 行数
* @param cols 列数
* @param isInitialMemory 是否初始化权值矩阵
* @see :
* @author : T. QIN
*/
public Matrix( int rows, int
cols,
boolean
isInitialMemory)
{
this .rowNum = rows;
this .colNum = cols;
if (isInitialMemory)
{
this .value = new
double
[rows][cols];
}
}
/**
* 替换矩阵值
*
* @param v
* @throws MatrixComputeException
* @see :
*/
public void changeWholeValue( double
v[][])
throws
MatrixComputeException
{
if (v.length != this
.rowNum
&&
v[
0
].length
!=
this
.colNum)
{
throw new MatrixComputeException( "
矩阵大小不拟合
"
);
}
this .value = v;
}
public void print(PrintStream ps)
{
if (ps == null
)
{
ps = System.out;
}
for ( int i =
0
; i
<
rowNum; i
++
)
{
for ( int j =
0
; j
<
colNum; j
++
)
{
ps.print(value[i][j] + " \t "
);
}
ps.println();
}
}
/**
* overwrite
*
* @return
* @see :
*/
public String toString()
{
StringBuffer sb = new StringBuffer();
for ( int i =
0
; i
<
rowNum; i
++
)
{
for ( int j =
0
; j
<
colNum; j
++
)
{
sb.append(value[i][j] + " \t "
);
}
sb.append( " \n " );
}
return sb.toString();
}
/**
* rowNum的 get() 方法
* @return int rowNum.
*/
public int getRowNum()
{
return rowNum;
}
/**
* rowNum的 set() 方法
* @param rowNum The rowNum to set.
*/
public void setRowNum( int
rowNum)
{
this .rowNum = rowNum;
}
/**
* colNum的 get() 方法
* @return int colNum.
*/
public int getColNum()
{
return colNum;
}
/**
* colNum的 set() 方法
* @param colNum The colNum to set.
*/
public void setColNum( int
colNum)
{
this .colNum = colNum;
}
/**
* value的 get() 方法
* @return double[][] value.
*/
public double [][] getValue()
{
return value;
}
/**
* value的 set() 方法
* @param value The value to set.
*/
public void setValue( double
[][] value)
{
this .value = value;
}
}
MatrixComputeException
package org.tadoo.ml.exception;
/**
 * Checked exception raised when a matrix operation cannot be carried out,
 * for example because the operand shapes do not fit.
 *
 * <p>time:2011-3-23</p>
 * @author T. QIN
 */
public class MatrixComputeException extends Exception
{
    /** Fixed serial version so the serialized form stays stable across recompiles. */
    private static final long serialVersionUID = 1L;

    /**
     * Creates an exception with a detail message.
     *
     * @param s the detail message
     */
    public MatrixComputeException(String s)
    {
        super(s);
    }

    /**
     * Creates an exception with a detail message and the underlying cause,
     * so the original stack trace is not lost when wrapping.
     *
     * @param s the detail message
     * @param cause the exception that triggered this one
     */
    public MatrixComputeException(String s, Throwable cause)
    {
        super(s, cause);
    }
}
DataUtil
package org.tadoo.ml.util;
import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.File;
import java.io.FileNotFoundException;
import java.io.FileReader;
import java.io.FileWriter;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
/**
 * Reads and writes whitespace-separated numeric data files: one sample per
 * line, one value per column.
 *
 * <p>time:2011-5-31</p>
 * @author T. QIN
 */
public class DataUtil
{
    /**
     * Loads a numeric table from a text file.
     *
     * <p>Values on a line are separated by any run of whitespace; blank lines
     * are skipped. The width of the table is the number of values on the
     * first non-blank line; shorter lines are padded with zeros.</p>
     *
     * @param filePath path of the file to read
     * @return the table (zero rows for a file without data), or null when the
     *         file could not be opened or read; the I/O error is printed to
     *         standard error in that case
     * @throws IllegalArgumentException if a line holds more values than the first one
     * @throws NumberFormatException if a token is not a valid number
     */
    public static double[][] load(String filePath)
    {
        BufferedReader reader = null;
        List<String[]> container = new ArrayList<String[]>();
        double[][] result = null;
        try
        {
            reader = new BufferedReader(new FileReader(new File(filePath)));
            String line = null;
            while ((line = reader.readLine()) != null)
            {
                String trimmed = line.trim();
                if (trimmed.length() == 0)
                {
                    // A blank line would otherwise yield one empty, unparsable token.
                    continue;
                }
                container.add(trimmed.split("\\s+"));
            }
            int ys = container.size(); // number of samples
            int xs = (ys == 0) ? 0 : container.get(0).length;
            result = new double[ys][xs];
            for (int i = 0; i < ys; i++)
            {
                String[] strings = container.get(i);
                if (strings.length > xs)
                {
                    throw new IllegalArgumentException("data row " + (i + 1) + " has "
                            + strings.length + " values, expected at most " + xs);
                }
                for (int j = 0; j < strings.length; j++)
                {
                    result[i][j] = Double.parseDouble(strings[j]);
                }
            }
        }
        catch (FileNotFoundException e)
        {
            e.printStackTrace();
        }
        catch (IOException e)
        {
            e.printStackTrace();
        }
        finally
        {
            // Always release the file handle, including on parse errors.
            if (reader != null)
            {
                try
                {
                    reader.close();
                }
                catch (IOException e)
                {
                    e.printStackTrace();
                }
            }
        }
        return result;
    }

    /**
     * Writes selected columns of a table to a text file, one row per line,
     * values separated by a single space.
     *
     * <p>The selected columns are written in ascending index order. The
     * caller's {@code columns} array is not modified. I/O errors are printed
     * to standard error.</p>
     *
     * @param data the table to write
     * @param saveFilename path of the file to create or overwrite
     * @param columns indices of the columns to write; null writes every column
     */
    public static void save(double[][] data, String saveFilename, int[] columns)
    {
        BufferedWriter fp_saver = null;
        int[] ordered = null;
        if (columns != null)
        {
            // Sort a private copy so the caller's array keeps its order.
            ordered = columns.clone();
            Arrays.sort(ordered);
        }
        try
        {
            fp_saver = new BufferedWriter(new FileWriter(saveFilename));
            for (int i = 0; i < data.length; i++)
            {
                if (ordered == null)
                {
                    for (int j = 0; j < data[i].length; j++)
                    {
                        fp_saver.write(String.valueOf(data[i][j]) + " ");
                    }
                }
                else
                {
                    for (int j = 0; j < ordered.length; j++)
                    {
                        fp_saver.write(String.valueOf(data[i][ordered[j]]) + " ");
                    }
                }
                // Plain newline: padding it with spaces left a whitespace-only
                // last line that could not be parsed back.
                fp_saver.write("\n");
            }
            fp_saver.flush();
        }
        catch (IOException e)
        {
            e.printStackTrace();
        }
        finally
        {
            // The writer is null when the file could not be opened.
            if (fp_saver != null)
            {
                try
                {
                    fp_saver.close();
                }
                catch (IOException e)
                {
                    e.printStackTrace();
                }
            }
        }
    }
}
然后是测试:
View Code
package org.tadoo.ml.test;
import junit.framework.TestCase;
import org.tadoo.ml.Matrix;
import org.tadoo.ml.cluster.kmeans.KmeansCluster;
import org.tadoo.ml.util.DataUtil;
import org.tadoo.ml.util.Utils;
/**
* 测试K-means聚类器
*
* <p>time:2011-6-2</p>
* @author T. QIN
*/
public class TestKmeansCluster extends TestCase
{
Matrix dataSet = null ;
double [][] ds = null ;
protected void setUp()
{
dataSet = Utils.uniformFileInputIntoFeatures( " D:\\test.s.txt " );
ds = DataUtil.load( " D:\\data1.txt " );
}
/**
* 测试用K-means选取中心节点
*
* @see :
*/
public void testKmeansCenters()
{
KmeansCluster kmc = new KmeansCluster(dataSet.getValue(), 2 );
kmc.train();
System.out.println(kmc.getTotalSumOfdistances());
System.out.println(kmc.getIter());
double [][] centers = kmc.getCenters();
for ( int i = 0 ; i
<
centers.length; i
++
)
{
for ( int j = 0 ; j
<
centers[i].length; j
++
)
{
System.out.print(centers[i][j] + " \t " );
}
System.out.println();
}
}
public void testKmeansReplicate(){
KmeansCluster kmc = new KmeansCluster(dataSet.getValue(), 11 );
kmc.setReplicates( 12 );
kmc.train();
KmeansCluster.KMCResult[] kmcr = kmc.getKmcresults();
for ( int i = 0 ; i
<
kmcr.length; i
++
)
{
System.out.println( " iters: " + kmcr[i].iters + " \tSum:
"
+
kmcr[i].sum);
}
}
}