I recently looked at some K-means implementations built on MapReduce that are floating around online. The examples work, but they have too few comments and too many parameters, which makes them hard for a newcomer to follow. So I wrote a simpler version based on my own understanding and added detailed comments.

The overall procedure is:

1. For every record it reads, Map compares the record against each center, finds the closest one, and emits the record with that center's ID as the key and the record itself as the value.
2. Reduce's grouping brings all values with the same key together, i.e. all the records assigned to one center; it then averages those records and outputs the average as that cluster's new center.
3. Compare the averages produced by Reduce with the old centers. If they differ, overwrite the center file with the Reduce output (the centers live in a file on HDFS; see the sample formats below), delete the Reduce output directory so the next run can write there again, and run the job again.
4. If the averages match the old centers, delete the Reduce output directory and run one final job without a Reduce phase that emits each center ID paired with its records.
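Both input files are plain text: one record per line, fields comma-separated, which is exactly what Utils.textToArray (below) parses. As a purely illustrative sketch (these values are made up, not taken from the wine data), a 3-dimensional data file and a centers.txt seeded with k = 3 of its lines could look like this:

wine.txt (the data, one record per line):

1.0,2.0,3.0
4.1,5.2,6.3
7.0,8.5,9.9
2.0,2.5,3.5

centers.txt (k initial centers, same number of fields as a data record):

1.0,2.0,3.0
4.1,5.2,6.3
7.0,8.5,9.9

Copying k distinct lines out of the data file is the simplest way to seed the initial centers.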
package MyKmeans;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;

import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.LongWritable;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.mapreduce.Job;
import org.apache.hadoop.mapreduce.Mapper;
import org.apache.hadoop.mapreduce.Reducer;
import org.apache.hadoop.mapreduce.lib.input.FileInputFormat;
import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat;

public class MapReduce {

    public static class Map extends Mapper<LongWritable, Text, IntWritable, Text> {

        // The current set of cluster centers
        ArrayList<ArrayList<Double>> centers = null;
        // The number of centers, k
        int k = 0;

        // Load the centers from HDFS once, before any records are processed
        protected void setup(Context context) throws IOException,
                InterruptedException {
            centers = Utils.getCentersFromHDFS(context.getConfiguration().get("centersPath"), false);
            k = centers.size();
        }

        /**
         * 1. Compare each input record against every center and assign it to the closest one.
         * 2. Emit the center's ID as the key and the record as the value
         *    (e.g. "1  0.2": 1 is the cluster center's ID, 0.2 is a record value near that center).
         */
        protected void map(LongWritable key, Text value, Context context)
                throws IOException, InterruptedException {
            // Parse one comma-separated line into a list of doubles
            ArrayList<Double> fields = Utils.textToArray(value);
            int sizeOfFields = fields.size();

            double minDistance = Double.MAX_VALUE;
            int centerIndex = 0;

            // Measure the current record against each of the k centers in turn
            for (int i = 0; i < k; i++) {
                double currentDistance = 0;
                for (int j = 0; j < sizeOfFields; j++) {
                    double centerPoint = Math.abs(centers.get(i).get(j));
                    double field = Math.abs(fields.get(j));
                    currentDistance += Math.pow((centerPoint - field) / (centerPoint + field), 2);
                }
                // Remember the ID of the closest center seen so far
                if (currentDistance < minDistance) {
                    minDistance = currentDistance;
                    centerIndex = i;
                }
            }
            // Emit the record unchanged, keyed by its closest center
            context.write(new IntWritable(centerIndex + 1), value);
        }

    }

    // Reduce's grouping collects the records belonging to each center under one key
    public static class Reduce extends Reducer<IntWritable, Text, Text, Text> {

        /**
         * 1. The key is a cluster center's ID; the values are the records assigned to that center.
         * 2. Compute the element-wise average of all the records to obtain the new center, and output it.
         */
        protected void reduce(IntWritable key, Iterable<Text> value, Context context)
                throws IOException, InterruptedException {
            ArrayList<ArrayList<Double>> fieldsList = new ArrayList<ArrayList<Double>>();

            // Read the records one by one; each line becomes an ArrayList<Double>
            for (Iterator<Text> it = value.iterator(); it.hasNext();) {
                ArrayList<Double> tempList = Utils.textToArray(it.next());
                fieldsList.add(tempList);
            }

            // Compute the new center:
            // number of elements per record
            int fieldSize = fieldsList.get(0).size();
            double[] avg = new double[fieldSize];
            for (int i = 0; i < fieldSize; i++) {
                // average of column i
                double sum = 0;
                int size = fieldsList.size();
                for (int j = 0; j < size; j++) {
                    sum += fieldsList.get(j).get(i);
                }
                avg[i] = sum / size;
            }
            context.write(new Text(""), new Text(Arrays.toString(avg).replace("[", "").replace("]", "")));
        }

    }

    public static void run(String centerPath, String dataPath, String newCenterPath, boolean runReduce)
            throws IOException, ClassNotFoundException, InterruptedException {

        Configuration conf = new Configuration();
        conf.set("centersPath", centerPath);

        Job job = Job.getInstance(conf, "mykmeans");
        job.setJarByClass(MapReduce.class);

        job.setMapperClass(Map.class);

        job.setMapOutputKeyClass(IntWritable.class);
        job.setMapOutputValueClass(Text.class);

        if (runReduce) {
            job.setReducerClass(Reduce.class);
            job.setOutputKeyClass(Text.class);
            job.setOutputValueClass(Text.class);
        } else {
            // The final labeling pass needs no reduce phase: with zero reduce
            // tasks, the map output (center ID, record) is written directly
            job.setNumReduceTasks(0);
        }

        FileInputFormat.addInputPath(job, new Path(dataPath));

        FileOutputFormat.setOutputPath(job, new Path(newCenterPath));

        System.out.println(job.waitForCompletion(true));
    }

    public static void main(String[] args) throws ClassNotFoundException, IOException, InterruptedException {
        String centerPath = "hdfs://localhost:9000/input/centers.txt";
        String dataPath = "hdfs://localhost:9000/input/wine.txt";
        String newCenterPath = "hdfs://localhost:9000/out/kmean";

        int count = 0;

        while (true) {
            run(centerPath, dataPath, newCenterPath, true);
            System.out.println("Iteration " + ++count);
            if (Utils.compareCenters(centerPath, newCenterPath)) {
                // Centers have converged: run the final map-only pass to label each record
                run(centerPath, dataPath, newCenterPath, false);
                break;
            }
        }
    }

}
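One detail worth noting: the mapper does not use plain Euclidean distance. Each field contributes ((|c_j| - |x_j|) / (|c_j| + |x_j|))², a squared relative difference (similar in spirit to the Canberra distance), so every field contributes a value between 0 and 1 and columns with large absolute values cannot dominate the assignment. If you would rather use textbook squared Euclidean distance, a minimal sketch of a drop-in helper for the Map class might look like this (euclideanDistance is my own name, not part of the original code):

    // Hypothetical alternative to the relative distance used in map():
    // plain squared Euclidean distance between a center and a record.
    private static double euclideanDistance(ArrayList<Double> center, ArrayList<Double> record) {
        double sum = 0;
        for (int j = 0; j < record.size(); j++) {
            double diff = center.get(j) - record.get(j);
            sum += diff * diff; // comparing squared distances is enough, so skip the sqrt
        }
        return sum;
    }

With this helper, the inner loop in map() collapses to currentDistance = euclideanDistance(centers.get(i), fields); everything else stays the same. If you do swap the mapper's distance, compareCenters in Utils (below) should use the same measure.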
package MyKmeans;

import java.io.IOException;
import java.util.ArrayList;
import java.util.List;

import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FSDataInputStream;
import org.apache.hadoop.fs.FSDataOutputStream;
import org.apache.hadoop.fs.FileStatus;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.IOUtils;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.util.LineReader;

public class Utils {

    // Read the centers from a file (or from every file in a directory) on HDFS
    public static ArrayList<ArrayList<Double>> getCentersFromHDFS(String centersPath, boolean isDirectory) throws IOException {

        ArrayList<ArrayList<Double>> result = new ArrayList<ArrayList<Double>>();

        Path path = new Path(centersPath);

        Configuration conf = new Configuration();

        FileSystem fileSystem = path.getFileSystem(conf);

        if (isDirectory) {
            // A directory: recurse into each file it contains
            FileStatus[] listFile = fileSystem.listStatus(path);
            for (int i = 0; i < listFile.length; i++) {
                result.addAll(getCentersFromHDFS(listFile[i].getPath().toString(), false));
            }
            return result;
        }

        FSDataInputStream fsis = fileSystem.open(path);
        LineReader lineReader = new LineReader(fsis, conf);

        Text line = new Text();

        // One center per line
        while (lineReader.readLine(line) > 0) {
            ArrayList<Double> tempList = textToArray(line);
            result.add(tempList);
        }
        lineReader.close();
        return result;
    }

    // Delete a file or directory (recursively)
    public static void deletePath(String pathStr) throws IOException {
        Configuration conf = new Configuration();
        Path path = new Path(pathStr);
        FileSystem hdfs = path.getFileSystem(conf);
        hdfs.delete(path, true);
    }

    // Parse one comma-separated line of numbers into a list of doubles
    public static ArrayList<Double> textToArray(Text text) {
        ArrayList<Double> list = new ArrayList<Double>();
        String[] fields = text.toString().split(",");
        for (int i = 0; i < fields.length; i++) {
            list.add(Double.parseDouble(fields[i]));
        }
        return list;
    }

    public static boolean compareCenters(String centerPath, String newPath) throws IOException {

        List<ArrayList<Double>> oldCenters = Utils.getCentersFromHDFS(centerPath, false);
        List<ArrayList<Double>> newCenters = Utils.getCentersFromHDFS(newPath, true);

        // Accumulate the same relative squared difference the mapper uses
        int size = oldCenters.size();
        int fieldSize = oldCenters.get(0).size();
        double distance = 0;
        for (int i = 0; i < size; i++) {
            for (int j = 0; j < fieldSize; j++) {
                double t1 = Math.abs(oldCenters.get(i).get(j));
                double t2 = Math.abs(newCenters.get(i).get(j));
                distance += Math.pow((t1 - t2) / (t1 + t2), 2);
            }
        }

        if (distance == 0.0) {
            // The centers did not move: delete the new-center directory
            // so the final labeling pass can write its output there
            Utils.deletePath(newPath);
            return true;
        } else {
            // Overwrite the center file with the new centers,
            // then delete the new-center directory

            Configuration conf = new Configuration();
            Path outPath = new Path(centerPath);
            FileSystem fileSystem = outPath.getFileSystem(conf);

            // Create (and truncate) the center file once, then append every
            // reduce output file to it, skipping non-data files such as _SUCCESS
            FSDataOutputStream out = fileSystem.create(outPath, true);

            Path inPath = new Path(newPath);
            FileStatus[] listFiles = fileSystem.listStatus(inPath);
            for (int i = 0; i < listFiles.length; i++) {
                if (!listFiles[i].getPath().getName().startsWith("part")) {
                    continue;
                }
                FSDataInputStream in = fileSystem.open(listFiles[i].getPath());
                IOUtils.copyBytes(in, out, 4096, false); // keep out open for the next file
                in.close();
            }
            out.close();
            // Delete the new-center directory so the next job run can write its output there
            Utils.deletePath(newPath);
        }

        return false;
    }
}
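The convergence test in compareCenters requires the centers to match exactly (distance == 0.0). Exact equality on doubles is fragile: depending on the data, the averages can keep shifting by tiny amounts and the loop would never terminate. Below is a minimal sketch of a relaxed test, assuming approximate convergence is acceptable; EPSILON and its value are my own additions, not part of the original code:

    // In Utils, replace the exact test
    //     if (distance == 0.0) {
    // with a tolerance-based one:
    private static final double EPSILON = 1e-9; // tune to the scale of your data

    if (distance < EPSILON) {
        // centers are effectively unchanged: stop iterating
        Utils.deletePath(newPath);
        return true;
    }

Capping the number of iterations in main() is another cheap safeguard against a run that never quite settles.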
Dataset: http://archive.ics.uci.edu/ml/machine-learning-databases/wine/wine.data
Note that each line of the raw file is comma-separated and its first field is the wine's class label (1-3), not a measurement; you may want to strip that field before clustering the remaining 13 attributes.