注:本算法的实现仅仅适用于小规模数据集的实验与测试,不适合用于工程应用
算法假定训练数据各属性列的值均是离散类型的。若是非离散类型的数据,需要首先进行数据的预处理,将非离散型的数据离散化。
算法中使用到了DecimalCalculate类,该类是对java中BigDecimal类的封装与扩展,用于高精度浮点数的运算。该类的实现与本人转载的一篇博文《对BigDecimal常用方法的归类》中的Arith类相同。
算法实现的代码如下
[java] view plain copy print ?
1. package
2. import
3. import
4. import
5. import
6. /**
7. * 贝叶斯主体类
8. * @author Rowen
9. * @data 2011.03.15
13. */
14. public class
15. /**
16. * 将原训练元组按类别划分
17. * @param datas 训练元组
18. * @return Map<类别,属于该类别的训练元组>
19. */
20. Map<String, ArrayList<ArrayList<String>>> datasOfClass(ArrayList<ArrayList<String>> datas){
21. new
22. null;
23. "";
24. for (int i = 0; i < datas.size(); i++) {
25. t = datas.get(i);
26. 1);
27. if
28. map.get(c).add(t);
29. else
30. new
31. nt.add(t);
32. map.put(c, nt);
33. }
34. }
35. return
36. }
37.
38. /**
39. * 在训练数据的基础上预测测试元组的类别
40. * @param datas 训练元组
41. * @param testT 测试元组
42. * @return 测试元组的类别
43. */
44. public
45. this.datasOfClass(datas);
46. Object classes[] = doc.keySet().toArray();
47. double maxP = 0.00;
48. int maxPIndex = -1;
49. for (int i = 0; i < doc.size(); i++) {
50. String c = classes[i].toString();
51. ArrayList<ArrayList<String>> d = doc.get(c);
52. double pOfC = DecimalCalculate.div(d.size(), datas.size(), 3);
53. for (int j = 0; j < testT.size(); j++) {
54. double pv = this.pOfV(d, testT.get(j), j);
55. pOfC = DecimalCalculate.mul(pOfC, pv);
56. }
57. if(pOfC > maxP){
58. maxP = pOfC;
59. maxPIndex = i;
60. }
61. }
62. return
63. }
64. /**
65. * 计算指定属性列上指定值出现的概率
66. * @param d 属于某一类的训练元组
67. * @param value 列值
68. * @param index 属性列索引
69. * @return 概率
70. */
71. private double pOfV(ArrayList<ArrayList<String>> d, String value, int
72. double p = 0.00;
73. int count = 0;
74. int
75. null;
76. for (int i = 0; i < total; i++) {
77. if(d.get(i).get(index).equals(value)){
78. count++;
79. }
80. }
81. 3);
82. return
83. }
84. }
package Bayes;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Map;
import util.DecimalCalculate;
/**
 * Naive Bayes classifier over training tuples whose attribute values are all
 * discrete strings. The last element of every training tuple is its class label.
 *
 * @author Row
 * @date 2011.03.15
 */
public class Bayes {
    /**
     * Groups the training tuples by class label.
     * Tuples that are {@code null} or empty carry no class label and are skipped.
     *
     * @param datas training tuples; the last element of each tuple is its class label
     * @return map from class label to the training tuples belonging to that class
     */
    Map<String, ArrayList<ArrayList<String>>> datasOfClass(ArrayList<ArrayList<String>> datas) {
        Map<String, ArrayList<ArrayList<String>>> map = new HashMap<String, ArrayList<ArrayList<String>>>();
        for (int i = 0; i < datas.size(); i++) {
            ArrayList<String> t = datas.get(i);
            if (t == null || t.isEmpty()) {
                // No class label available; t.get(t.size() - 1) would throw.
                continue;
            }
            String c = t.get(t.size() - 1);
            ArrayList<ArrayList<String>> tuplesOfClass = map.get(c);
            if (tuplesOfClass == null) {
                tuplesOfClass = new ArrayList<ArrayList<String>>();
                map.put(c, tuplesOfClass);
            }
            tuplesOfClass.add(t);
        }
        return map;
    }

    /**
     * Predicts the class of a test tuple from the training data.
     * <p>
     * For every class the score is the class prior multiplied by the relative
     * frequency of each test attribute value inside that class. The class with
     * the highest score wins. When every class scores zero (for example because
     * the test tuple contains a value that never occurs in the training data),
     * the class with the most training tuples is returned instead of failing.
     *
     * @param datas training tuples
     * @param testT test tuple (attribute values only, in training column order)
     * @return the predicted class label
     * @throws IllegalArgumentException if the training data contains no usable tuple
     */
    public String predictClass(ArrayList<ArrayList<String>> datas, ArrayList<String> testT) {
        Map<String, ArrayList<ArrayList<String>>> doc = this.datasOfClass(datas);

        // Count only tuples that were actually assigned to a class.
        int total = 0;
        for (ArrayList<ArrayList<String>> group : doc.values()) {
            total += group.size();
        }
        if (total == 0) {
            throw new IllegalArgumentException("training data contains no usable tuples");
        }

        String bestClass = null;
        double maxP = 0.00;
        String majorityClass = null;
        int majoritySize = 0;
        for (Map.Entry<String, ArrayList<ArrayList<String>>> entry : doc.entrySet()) {
            String c = entry.getKey();
            ArrayList<ArrayList<String>> d = entry.getValue();

            // Remember the largest class as the fallback answer.
            if (d.size() > majoritySize) {
                majoritySize = d.size();
                majorityClass = c;
            }

            // NOTE(review): the third argument of DecimalCalculate.div is presumably
            // the number of decimal places kept - confirm against util.DecimalCalculate.
            double pOfC = DecimalCalculate.div(d.size(), total, 3);
            for (int j = 0; j < testT.size(); j++) {
                double pv = this.pOfV(d, testT.get(j), j);
                pOfC = DecimalCalculate.mul(pOfC, pv);
            }
            if (pOfC > maxP) {
                maxP = pOfC;
                bestClass = c;
            }
        }
        // bestClass stays null when no class reached a score above zero.
        return bestClass != null ? bestClass : majorityClass;
    }

    /**
     * Computes the relative frequency of a value in one attribute column of the
     * tuples of a single class.
     * Tuples that have no column at {@code index} are counted as not matching.
     *
     * @param d     training tuples belonging to one class
     * @param value attribute value to look for
     * @param index index of the attribute column
     * @return matching tuples divided by the number of tuples in {@code d}
     */
    private double pOfV(ArrayList<ArrayList<String>> d, String value, int index) {
        int count = 0;
        int total = d.size();
        for (int i = 0; i < total; i++) {
            ArrayList<String> row = d.get(i);
            if (index < row.size() && row.get(index).equals(value)) {
                count++;
            }
        }
        return DecimalCalculate.div(count, total, 3);
    }
}
算法测试类:
[java] view plain copy print ?
1. package
2. import
3. import
4. import
5. import
6. import
7. /**
8. * 贝叶斯算法测试类
9. * @author Rowen
1* @data 2011.03.15
14. */
15. public class
16. /**
17. * 读取测试元组
18. * @return 一条测试元组
19. * @throws IOException
20. */
21. public ArrayList<String> readTestData() throws
22. new
23. new BufferedReader(new
24. "";
25. while (!(str = reader.readLine()).equals("")) {
26. new
27. while
28. candAttr.add(tokenizer.nextToken());
29. }
30. }
31. return
32. }
33.
34. /**
35. * 读取训练元组
36. * @return 训练元组集合
37. * @throws IOException
38. */
39. public ArrayList<ArrayList<String>> readData() throws
40. new
41. new BufferedReader(new
42. "";
43. while (!(str = reader.readLine()).equals("")) {
44. new
45. new
46. while
47. s.add(tokenizer.nextToken());
48. }
49. datas.add(s);
50. }
51. return
52. }
53.
54. public static void
55. new
56. null;
57. null;
58. new
59. try
60. "请输入训练数据");
61. datas = tb.readData();
62. while (true) {
63. "请输入测试元组");
64. testT = tb.readTestData();
65. String c = bayes.predictClass(datas, testT);
66. "The class is: "
67. }
68. catch
69. e.printStackTrace();
70. }
71. }
72. }
package Bayes;
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.ArrayList;
import java.util.StringTokenizer;
/**
 * Console driver for the Bayes classifier: reads training tuples and then test
 * tuples from standard input and prints the predicted class of each test tuple.
 *
 * @date 2011.03.15
 */
public class TestBayes {
    /**
     * One reader shared by all read methods. Creating a new BufferedReader per
     * call would lose lines that an earlier reader had already buffered, which
     * breaks piped or redirected input.
     */
    private final BufferedReader reader = new BufferedReader(new InputStreamReader(System.in));

    /** Set once standard input has reached its end. */
    private boolean endOfInput = false;

    /**
     * Reads one test tuple: whitespace-separated tokens from consecutive lines,
     * up to the first empty line or the end of input.
     *
     * @return the tokens of one test tuple (empty if nothing was entered)
     * @throws IOException if reading from standard input fails
     */
    public ArrayList<String> readTestData() throws IOException {
        ArrayList<String> candAttr = new ArrayList<String>();
        String str;
        // readLine() returns null at end of input, so check that before comparing.
        while ((str = reader.readLine()) != null && !str.equals("")) {
            StringTokenizer tokenizer = new StringTokenizer(str);
            while (tokenizer.hasMoreTokens()) {
                candAttr.add(tokenizer.nextToken());
            }
        }
        if (str == null) {
            endOfInput = true;
        }
        return candAttr;
    }

    /**
     * Reads the training tuples: one tuple per line, tokens separated by
     * whitespace, up to the first empty line or the end of input.
     * Lines containing only whitespace are skipped.
     *
     * @return the list of training tuples
     * @throws IOException if reading from standard input fails
     */
    public ArrayList<ArrayList<String>> readData() throws IOException {
        ArrayList<ArrayList<String>> datas = new ArrayList<ArrayList<String>>();
        String str;
        while ((str = reader.readLine()) != null && !str.equals("")) {
            StringTokenizer tokenizer = new StringTokenizer(str);
            ArrayList<String> s = new ArrayList<String>();
            while (tokenizer.hasMoreTokens()) {
                s.add(tokenizer.nextToken());
            }
            // A whitespace-only line has no tokens and therefore no class label.
            if (!s.isEmpty()) {
                datas.add(s);
            }
        }
        if (str == null) {
            endOfInput = true;
        }
        return datas;
    }

    /**
     * Reads the training data once, then classifies test tuples until standard
     * input is exhausted.
     *
     * @param args unused
     */
    public static void main(String[] args) {
        TestBayes tb = new TestBayes();
        Bayes bayes = new Bayes();
        try {
            System.out.println("请输入训练数据");
            ArrayList<ArrayList<String>> datas = tb.readData();
            while (!tb.endOfInput) {
                System.out.println("请输入测试元组");
                ArrayList<String> testT = tb.readTestData();
                if (tb.endOfInput && testT.isEmpty()) {
                    // Input ended without another tuple; nothing left to classify.
                    break;
                }
                String c = bayes.predictClass(datas, testT);
                System.out.println("The class is: " + c);
            }
        } catch (IOException e) {
            e.printStackTrace();
        }
    }
}
训练数据:
[java] view plain copy print ?
1. youth high no fair no
2. youth high no excellent no
3. middle_aged high no fair yes
4. senior medium no fair yes
5. senior low yes fair yes
6. senior low yes excellent no
7. middle_aged low yes excellent yes
8. youth medium no fair no
9. youth low yes fair yes
10. senior medium yes fair yes
11. youth medium yes excellent yes
12. middle_aged medium no excellent yes
13. middle_aged high yes fair yes
14. senior medium no excellent no
youth high no fair no
youth high no excellent no
middle_aged high no fair yes
senior medium no fair yes
senior low yes fair yes
senior low yes excellent no
middle_aged low yes excellent yes
youth medium no fair no
youth low yes fair yes
senior medium yes fair yes
youth medium yes excellent yes
middle_aged medium no excellent yes
middle_aged high yes fair yes
senior medium no excellent no
对原训练数据进行测试,测试结果如下:
[c-sharp] view plain copy print ?
1. 请输入测试元组
2. youth high no fair
3. The class is: no
4. 请输入测试元组
5. youth high no excellent
6. The class is: no
7. 请输入测试元组
8. middle_aged high no fair
9. The class is: yes
10. 请输入测试元组
11. senior medium no fair
12. The class is: yes
13. 请输入测试元组
14. senior low yes fair
15. The class is: yes
16. 请输入测试元组
17. senior low yes excellent
18. The class is: yes
19. 请输入测试元组
20. middle_aged low yes excellent
21. The class is: yes
22. 请输入测试元组
23. youth medium no fair
24. The class is: no
25. 请输入测试元组
26. youth low yes fair
27. The class is: yes
28. 请输入测试元组
29. senior medium yes fair
30. The class is: yes
31. 请输入测试元组
32. youth medium yes excellent
33. The class is: yes
34. 请输入测试元组
35. middle_aged medium no excellent
36. The class is: yes
37. 请输入测试元组
38. middle_aged high yes fair
39. The class is: yes
40. 请输入测试元组
41. senior medium no excellent
42. The class is: no
请输入测试元组
youth high no fair
The class is: no
请输入测试元组
youth high no excellent
The class is: no
请输入测试元组
middle_aged high no fair
The class is: yes
请输入测试元组
senior medium no fair
The class is: yes
请输入测试元组
senior low yes fair
The class is: yes
请输入测试元组
senior low yes excellent
The class is: yes
请输入测试元组
middle_aged low yes excellent
The class is: yes
请输入测试元组
youth medium no fair
The class is: no
请输入测试元组
youth low yes fair
The class is: yes
请输入测试元组
senior medium yes fair
The class is: yes
请输入测试元组
youth medium yes excellent
The class is: yes
请输入测试元组
middle_aged medium no excellent
The class is: yes
请输入测试元组
middle_aged high yes fair
The class is: yes
请输入测试元组
senior medium no excellent
The class is: no
测试结果显示14个测试实例中有13个分类是正确的,正确率为93%,说明算法能够给出一个准确的预测与分类,但是算法还需改进以提高正确率。
改进的可选方法之一:
为避免单个属性值对分类结果的权重过大,例如当某属性值在某一类中出现0次时,该属性值就决定了测试实例不可能属于该类,这就可能会造成误差,因此在计算概率时可以进行如下改进:
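将原先的 $P(X_k \mid C_i) = \dfrac{|X_k|}{|C_i|}$ 改为 $P(X_k \mid C_i) = \dfrac{|X_k| + m\,p}{|C_i| + m}$,其中 $|X_k|$ 为类 $C_i$ 中该属性取值为 $X_k$ 的训练元组个数,$|C_i|$ 为类 $C_i$ 的训练元组个数,$m$ 可设定为训练元组的个数,$p$ 为等可能假设下的先验概率(例如该属性共有 $n$ 个可能取值时,$p = 1/n$)。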