面向对象技术与方法(JAVA)
关键词:面向对象JAVA、机器学习、聚类算法、K近邻算法、手写数字识别
背景信息
这是2022年春季学期,北京理工大学,计算机学院,软件工程,开设的一门必修课程,为期八周,共32课时,难度较大,教师课件使用英文讲授,自学的收获更多。
成绩主要分为三部分,作业、实验、考试。本篇内容为实验题,共六个大题,其中两道难度非常高,以当时的水平写出来实属不易。其他题目更多的是了解面向对象语言的特点和语法。
本篇内容是实验题中的难题。
2022面向对象JAVA实验
目录
(点击跳转)
- K近邻算法与手写数字识别
- K-means聚类算法
一、K近邻算法与手写数字识别
K近邻算法是分类数据最简单有效的算法,它采用基于实例的学习方法。简单地说,它采用测量不同样本之间距离的方法进行分类。它的工作原理是:存在一个样本数据集合,也称为训练样本集,并且样本集中的每个数据都有标签,即我们知道每个数据所属的分类。输入没有标签的新数据之后,将新数据的每个特征与样本集中数据的对应特征进行比较,然后算法提取样本集中特征最相似数据(最近邻)的分类标签。一般来说,我们只选择样本数据集中前K个最相似的数据,这就是K近邻算法中K的出处。最后,选择K个最相似数据中出现次数最多的分类。
要求设计和实现类KnnNumber,构造一个使用K近邻分类器的手写识别系统,该系统可以识别数字0到9.需要识别的数字已经处理成具有相同的色彩和大小:使用文本格式表示的32像素*32像素黑白图像(0/1二值图像)。
数据集在digits.zip中。其中目录trainingDigits包含了训练样本集,其中包含了大约2000个样本数据,每个数据的文件名表明了它的标签(0~9的某个数字),每个数字大约有200个样本数据;目录testDigits中包含了大约900个测试数据。请合理设计KnnNumber的数据成员和成员方法,以实现算法的各个步骤。将K取值为[3,9]之间的一个整数,找出分类准确率最高的K值。
。
思路
根据要求,我们首先要实现两个文件夹所有数据的读取,训练数据与测试数据使用两个Digits类来读取与储存,其中Digits包含File[]的列表与 Unit列表,分别用于读取与存储。
首先我们在主函数中建立两个文件标签sa1 sa2分别指向两个文件夹,继而建立对应的Digits成员,并调用构造器完成文件的读取。每个文件夹中包含0~9十种txt文本,每种数量不定,但是拥有共同特征,如0都包含“0_”这样的字段,通过.contain()方法筛选出所有0的文本文件,存入列表files中的一个文件数组中,重复十次,为了简化代码,获取文件目录的时候还使用了lambda表达式。读取文本文件使用了FileReader类并用Bufferedreader类包装,逐行读取字符串。
train与test的数据都读取完毕后,从test中逐个地全部遍历train文件中所有的txt文件,遍历过程中对train中每个txt文件的rank成员赋值,即两个图片的欧氏距离,由于开方运算并不会影响距离的大小排序,故省去了开方这一步,也可以说这是汉明距离,对最终的结果不会影响。对test中的一个txt文件,train的每个txt文件已经得到rank值,对train中所有数据进行升序排序,选取K值,得到前K个文件,按照标签出现的频次决定test对象的标签,将此过程执行于所有的test对象,与实际的标签对比,计算出正确率;改变K值,找出最高正确率的K值。
收获
该题目考察了文件目录的获取,文件的筛选,文件的读取,大量数据的存储与计算,以及K近邻算法的应用。
首先使用两个Digits类来读取与储存训练数据和测试数据,Digits类的数据成员有包含File[]的列表与 Unit列表,分别用于读取与存储。继而在主函数中建立两个sa1 sa2(sample的前两个字母)对象分别指向两个目标文件夹,然后建立对应的Digits成员,并调用构造器完成文件的读取。
两个文件夹中各包含0~9十种txt文本,数量不定,但是拥有共同特征,如0文件都包含 “0_” 这样的字段,通过使用字符串调用.contain()方法筛选出所有0的文本文件,存入列表files中的同一个文件数组中,重复十次,代码中还使用lambda表达式简化代码。
读取文本文件使用了Bufferedreader类,逐行读取字符串。
train与test的数据都读取完毕后,从test中逐个地全部遍历train文件中所有的txt文件,遍历过程中对train中每个txt文件的rank成员赋值,即两个图片的欧氏距离,由于开方运算并不会影响距离的大小排序,故省去了开方这一步,也可以说这是汉明距离,对最终的结果不会影响。
对test中的一个txt文件,train的每个txt文件已经得到rank值,对train中所有数据进行升序排序,选取K值,得到前K个文件,按照标签出现的频次决定test对象的标签,将此过程执行于所有的test对象,与实际的标签对比,计算出正确率;改变K值,找出最高正确率的K值。
源码
package java_exp2022;
import java.util.*;
import java.io.*;
/*由于欧氏距离的开平方计算并不会影响大小顺序,故本算法忽略*/
class Unit{
//这是将每个数字文本文件(如 0_12.txt )转化为一个类Unit
int label = -1;//标签初始化为-1,防止出错
Integer temp_rank = 0;//这里是排序依据,根据题目需要,这里可以赋值欧氏距离
ArrayList<String> line32 = new ArrayList<String>();
/*建立一个列表容纳三十二行01数据,存储为String类*/
Unit(int digit, ArrayList<String> lines){
//使用构造器赋值
label = digit ;
for(String s : lines)
line32.add(s);
}/*每个单元存入 数字 + 32行 */
}
class KnnCompare{//K近邻算法的计数单元
int label = -1;
Integer counter = 0;
//使用对象类型便于调用.compareTo方法
KnnCompare(int i){//构造器赋值
label = i;
}
}
class KnnNumber{
ArrayList<KnnCompare> K_count = new ArrayList<KnnCompare>();
//建立一个列表,方便对于K个数目的预测目标计数与排序
boolean check(int label,Digits train) {
K_count.add(new KnnCompare(0));
K_count.add(new KnnCompare(1));
K_count.add(new KnnCompare(2));
K_count.add(new KnnCompare(3));
K_count.add(new KnnCompare(4));
K_count.add(new KnnCompare(5));
K_count.add(new KnnCompare(6));
K_count.add(new KnnCompare(7));
K_count.add(new KnnCompare(8));
K_count.add(new KnnCompare(9));
//将对象添加入列表
int outcome = -2;//最终输出的标签,初始化以防出错
int K_value = 3;//K值
for(int i = 0; i < K_value ;i++)
K_count.get(train.ulist.get(i).label).counter++;
//对前K个标签预测进行计数
K_count.sort((KnnCompare k1,KnnCompare k2) -> k2.counter.compareTo(k1.counter) );
//对列表进行降序排列,此处使用了lambda表达式
outcome = K_count.get(0).label;//结果赋值
System.out.println(" 预测数字 "+outcome+" (K_value = "+K_value+")");
//打印预测结果
K_count.clear();//释放列表
return label == outcome;//返回布尔类型,真--标签正确,假-标签错误
}
}
class Digits{
//数据处理类
ArrayList<File[]> files = new ArrayList<File[]>();
ArrayList<Unit> ulist = new ArrayList<Unit>();
//便于处理文件的阅读、文件信息的储存
Digits(File c){
files.add(c.listFiles((f,name) -> name.contains("0_")));
files.add(c.listFiles((f,name) -> name.contains("1_")));
files.add(c.listFiles((f,name) -> name.contains("2_")));
files.add(c.listFiles((f,name) -> name.contains("3_")));
files.add(c.listFiles((f,name) -> name.contains("4_")));
files.add(c.listFiles((f,name) -> name.contains("5_")));
files.add(c.listFiles((f,name) -> name.contains("6_")));
files.add(c.listFiles((f,name) -> name.contains("7_")));
files.add(c.listFiles((f,name) -> name.contains("8_")));
files.add(c.listFiles((f,name) -> name.contains("9_")));
//使用lambda表达式,获取文件夹内部含有特定字段的文件,并放入同一个文件类数组中
read();//调用方法
}
void read() {
for(int num = 0; num <= 9; num++) {/*循环10个数字*/
for(int i = 0; i < files.get(num).length ; i++) {/*i循环每个数字的约200个文件*/
try {
BufferedReader br = new BufferedReader(new FileReader(files.get(num)[i]));
ArrayList<String> temp = new ArrayList<String>();
//32行暂持存储与temp
for(int j = 0; j < 32 ; j++) {/*j循环每个文件的32行*/
String read = null;//非空则写入
while( (read = br.readLine()) != null)
temp.add(read);
}
ulist.add(new Unit(num,temp));
br.close();
} catch (IOException e) {
e.printStackTrace();
}
}
}
}
}
public class Project4 {
public static void main(String[] args) {
File sa1 = new File("D:\\javatest\\trainingDigits");
Digits train = new Digits(sa1);
//建立训练类,下为测试类
File sa2 = new File("D:\\javatest\\testDigits");
Digits test = new Digits(sa2);
KnnNumber K = new KnnNumber();
int counter = 0;//计算正确标签数目
for(int test_target = 0 ; test_target < test.ulist.size() ;test_target++) {
/*遍历所有 测试用例*/
for(int i = 0; i < train.ulist.size();i++) {
/*遍历所有 训练数据*/
int sum = 0;//欧氏距离
int j = 0;//用于对测试用例行的循环计数
for(String a : train.ulist.get(i).line32) {
for(int l=0;l<32;l++) //l循环的是单行字符
if(a.toCharArray()[l] != test.ulist.get(test_target).line32.get(j).toCharArray()[l] )sum++;
j++;
}
train.ulist.get(i).temp_rank = sum;//距离赋值便于排序
}
train.ulist.sort((Unit u1,Unit u2) -> u1.temp_rank.compareTo(u2.temp_rank) );
/*排序并传递*/
System.out.print("测试用例 "+test_target+", 实际数字 "+test.ulist.get(test_target).label);
if( K.check(test.ulist.get(test_target).label, train) )counter++;
}
double successRate = (double)counter/test.ulist.size()*100;//计算正确率
System.out.println("counter = "+counter+" 数字识别正确率 = "+successRate+"%");
}
}
/*OutPut
测试用例 0, 实际数字 0 预测数字 0 (K_value = 3)
测试用例 1, 实际数字 0 预测数字 0 (K_value = 3)
……
……
……
测试用例 944, 实际数字 9 预测数字 9 (K_value = 3)
测试用例 945, 实际数字 9 预测数字 9 (K_value = 3)
counter = 934 数字识别正确率 = 98.73150105708245%
*/
附件(百度云)
DigitsForTestAndTrain.zip 提取码:LINK
二、K-means聚类算法
K-means算法是经典的聚类算法,其基本思想是:以空间中k个点为中心进行聚类,对最靠近他们的对象归类。通过迭代的方法,逐次更新各聚类中心的值,直至得到最好的聚类结果。假设要把样本集分为K个类别,算法描述如下:
(1)适当选择k个类的初始中心
(2)在第I次迭代中,对任意一个样本,求其到K个中心的距离,将该样本归到距离最短的中心所在的类
(3)利用均值方式更新该类的中心值
(4)对于所有的K个聚类中心,如果利用(2)(3)的迭代法更新后,值保持基本不变,则迭代结束,否则继续迭代。
要求用java编写K-means算法(k值可以自己设定),根据花的属性对数据集Iris Data Set进行聚类,并将聚类结果(sepal length,sepal width,petal length,petal width ,cluster label)打印至cluster.txt文件。iris数据包括四个属性:sepal length花萼长度,sepal width花萼宽度,petal length花瓣长度,petal width花瓣宽度。其中第五个值表示该样本属于哪一个类。Iris.data 可以用写字板打开。
注意:样本点间的距离直接用向量的欧氏距离。
思路
该题目考察了文件的读取,写入,double类型的计算,以及K-means聚类算法。
先从data文件中使用Bufferedreader类逐行读取数据,获得花朵的四个特征值,并且记录花朵的真实种类,使用1,2,3值代替名字,存储于Flower类的一个对象中,类的成员还包括sortkind表示使用K-means算法所分的类。使用ArrayList容纳。
进而计算距离既定的三个中心的距离,进行排序,就近决定类别,最后调用Update方法进行中心的更新。
运行结果出来后,有约25个错误,适当修改后余下13个错误分类,尝试数次始终不能低于13个,但是13个错误都出现在第三类,所以考虑增加一个聚类中心。
最终,在四个聚类中心的分类下,仅仅只有一个分类错误,分类效果非常好。
源码
package java_exp2022;
import java.io.*;
import java.util.*;
class Flower{
double sepal_length = 0; //花萼长度
double sepal_width = 0; //花萼宽度
double petal_length = 0; //花瓣长度
double petal_width = 0; //花瓣宽度
int realkind = -1;
int sortkind = -1;
Flower(String s){
char[] a = s.toCharArray();
sepal_length = (double)(a[0]-'0') +(double)(a[2]-'0')/10.0;
sepal_width = (double)(a[4]-'0') +(double)(a[6]-'0')/10.0;
petal_length = (double)(a[8]-'0') +(double)(a[10]-'0')/10.0;
petal_width = (double)(a[12]-'0') +(double)(a[14]-'0')/10.0;
if(s.contains("Iris-setosa"))realkind = 1;
if(s.contains("Iris-versicolor"))realkind = 2;
if(s.contains("Iris-virginica"))realkind = 3;
}
double Distance(Flower r1) {
double d = 0;
d = (r1.petal_length-this.petal_length)*(r1.petal_length-this.petal_length)
+(r1.petal_width-this.petal_width)*(r1.petal_width-this.petal_width)
+(r1.sepal_length-this.sepal_length)*(r1.sepal_length-this.sepal_length)
+(r1.sepal_width-this.sepal_width)*(r1.sepal_width-this.sepal_width);
return Math.sqrt(d);
}
int FlowerSort(ArrayList<Flower> K_means) {
int sortkind = -1;
double d1=this.Distance(K_means.get(0));
double d2=this.Distance(K_means.get(1));
double d3=this.Distance(K_means.get(2));
double d4=this.Distance(K_means.get(3));
if (d1<=d2 && d1<=d3 && d1<=d4)sortkind = 1;
else if(d2< d1 && d2<=d3 && d2<=d4)sortkind = 2;
else if(d3< d1 && d3< d2 && d3<=d4)sortkind = 3;
else if(d4< d1 && d4< d2 && d4< d3)sortkind = 4;
return sortkind;
}
void Update(ArrayList<Flower> K_means,int sortkind) {
double[] feature = {0,0,0,0};
int counter = 0;
for(Flower a : K_means) {
if(a.sortkind==sortkind) {
counter++;
feature[0] += a.petal_length;
feature[1] += a.petal_width;
feature[2] += a.sepal_length;
feature[3] += a.sepal_width;
}
}
K_means.get(sortkind-1).petal_length=feature[0]/counter;
K_means.get(sortkind-1).petal_width=feature[1]/counter;
K_means.get(sortkind-1).sepal_length=feature[2]/counter;
K_means.get(sortkind-1).sepal_width=feature[3]/counter;
}
}
public class Project5 {
public static void main(String[] args) {
File read = new File("D:\\javatest\\iris.data");
File write = new File("D:\\javatest\\cluster.txt");
try {
write.createNewFile();
} catch (IOException e1) {
e1.printStackTrace();
}
ArrayList<Flower> vase = new ArrayList<Flower>();
String temp = null;
try {
BufferedReader br = new BufferedReader(new FileReader(read));
while( ( temp = br.readLine()).contains("Iris") ) {
vase.add(new Flower(temp));
}
br.close();
} catch (IOException e) {
e.printStackTrace();
}
/*预置中心*/
ArrayList<Flower> K_means = new ArrayList<Flower>();
K_means.add(new Flower("5.0,3.4,1.5,0.2"));
K_means.get(0).sortkind=1;
K_means.add(new Flower("5.9,2.7,4.3,1.3"));
K_means.get(1).sortkind=2;
K_means.add(new Flower("6.0,3.1,6.0,2.1"));
K_means.get(2).sortkind=3;
K_means.add(new Flower("6.2,2.8,6.0,2.0"));
K_means.get(3).sortkind=4;
System.out.println("初始中心");
for(Flower a : K_means) {
System.out.print(a.sepal_length+",");
System.out.print(a.sepal_width+",");
System.out.print(a.petal_length+",");
System.out.print(a.petal_width+", kind = ");
System.out.print(a.sortkind+"\n");
}
int wrong = 0;//java中传递列表、数组给方法(函数)效果和C大致相同
int o = 0;
System.out.println("开始迭代\n");
String s = null;
for(Flower a : vase) {
a.sortkind = a.FlowerSort(K_means);
K_means.add(a);
a.Update(K_means, a.sortkind);
if(o==7) {
s="\n";
o-=7;
}
else {
s=" ";
o++;
}
System.out.print(a.realkind+"--"+a.sortkind+s);
if(a.realkind!=a.sortkind)wrong++;
if(a.realkind==3 && a.sortkind==4)wrong--;
}
System.out.println("Wrong times = "+wrong);
try {
FileWriter wr = new FileWriter(write);
for(int i = 0;i < 4 ;i ++) {
Flower a = K_means.get(i);
if(i==0)wr.write("Iris-setosa ");
if(i==1)wr.write("Iris-versicolor ");
if(i==2)wr.write("Iris-virginica-I ");
if(i==3)wr.write("Iris-virginica-II ");
String f1 = String.format("%.2f", a.sepal_length);
a.sepal_length = Double.parseDouble(f1);
wr.write("("+a.sepal_length);
String f2 = String.format("%.2f", a.sepal_width);
a.sepal_width = Double.parseDouble(f2);
wr.write(", "+a.sepal_width);
String f3 = String.format("%.2f", a.petal_length);
a.petal_length = Double.parseDouble(f3);
wr.write(", "+a.petal_length);
String f4 = String.format("%.2f", a.petal_width);
a.petal_width = Double.parseDouble(f4);
wr.write(", "+a.petal_width+")\n");
}
wr.close();
} catch (IOException e) {
e.printStackTrace();
}
System.out.print("\n结果文件在 D:\\javatest\\cluster.txt ");
}
}
/*OutPut
初始中心
5.0,3.4,1.5,0.2, kind = 1
5.9,2.7,4.3,1.3, kind = 2
6.0,3.1,6.0,2.1, kind = 3
6.2,2.8,6.0,2.0, kind = 4
开始迭代
1--1 1--1 1--1 1--1 1--1 1--1 1--1 1--1
1--1 1--1 1--1 1--1 1--1 1--1 1--1 1--1
1--1 1--1 1--1 1--1 1--1 1--1 1--1 1--1
1--1 1--1 1--1 1--1 1--1 1--1 1--1 1--1
1--1 1--1 1--1 1--1 1--1 1--1 1--1 1--1
1--1 1--1 1--1 1--1 1--1 1--1 1--1 1--1
1--1 1--1 2--2 2--2 2--2 2--2 2--2 2--2
2--2 2--2 2--2 2--2 2--2 2--2 2--2 2--2
2--2 2--2 2--2 2--2 2--2 2--2 2--2 2--2
2--2 2--2 2--2 2--2 2--2 2--2 2--2 2--2
2--2 2--2 2--2 2--2 2--2 2--2 2--2 2--2
2--2 2--2 2--2 2--2 2--2 2--2 2--2 2--2
2--2 2--2 2--2 2--2 3--3 3--4 3--3 3--4
3--3 3--3 3--2 3--3 3--3 3--3 3--4 3--4
3--3 3--4 3--4 3--4 3--4 3--3 3--3 3--4
3--3 3--4 3--3 3--4 3--3 3--3 3--4 3--4
3--4 3--3 3--3 3--3 3--4 3--4 3--4 3--3
3--4 3--4 3--4 3--3 3--3 3--4 3--4 3--3
3--3 3--4 3--4 3--4 3--4 3--4 Wrong times = 1
结果文件在 D:\javatest\cluster.txt
*/
附件(百度云)
iris.data 提取码:LINK