面向对象技术与方法(JAVA)


  关键词:面向对象JAVA、机器学习、聚类算法、K近邻算法、手写数字识别


背景信息


  这是2022年春季学期,北京理工大学,计算机学院,软件工程,开设的一门必修课程,为期八周,共32课时,难度较大,教师课件使用英文讲授,自学的收获更多。

  成绩主要分为三部分,作业、实验、考试。本篇内容为实验题,共六个大题,其中两道难度非常高,以当时的水平写出来实属不易。其他题目更多的是了解面向对象语言的特点和语法。

  本篇内容是实验题中的难题。


2022面向对象JAVA实验

目录

 (点击跳转)

  1. K近邻算法与手写数字识别
  2. K-means聚类算法


一、K近邻算法与手写数字识别

  K近邻算法是分类数据最简单有效的算法,它采用基于实例的学习方法。简单地说,它采用测量不同样本之间距离的方法进行分类。它的工作原理是:存在一个样本数据集合,也称为训练样本集,并且样本集中的每个数据都有标签,即我们知道每个数据所属的分类。输入没有标签的新数据之后,将新数据的每个特征与样本集中数据的对应特征进行比较,然后算法提取样本集中特征最相似数据(最近邻)的分类标签。一般来说,我们只选择样本数据集中前K个最相似的数据,这就是K近邻算法中K的出处。最后,选择K个最相似数据中出现次数最多的分类。
  要求设计和实现类KnnNumber,构造一个使用K近邻分类器的手写识别系统,该系统可以识别数字0到9.需要识别的数字已经处理成具有相同的色彩和大小:使用文本格式表示的32像素*32像素黑白图像(0/1二值图像)。
  数据集在digits.zip中。其中目录trainingDigits包含了训练样本集,其中包含了大约2000个样本数据,每个数据的文件名表明了它的标签(0~9的某个数字),每个数字大约有200个样本数据;目录testDigits中包含了大约900个测试数据。请合理设计KnnNumber的数据成员和成员方法,以实现算法的各个步骤。将K取值为[3,9]之间的一个整数,找出分类准确率最高的K值。

思路

  根据要求,我们首先要实现两个文件夹所有数据的读取,训练数据与测试数据使用两个Digits类来读取与储存,其中Digits包含File[]的列表与 Unit列表,分别用于读取与存储。
  首先我们在主函数中建立两个文件标签sa1 sa2分别指向两个文件夹,继而建立对应的Digits成员,并调用构造器完成文件的读取。每个文件夹中包含0~9十种txt文本,每种数量不定,但是拥有共同特征,如0都包含“0_”这样的字段,通过.contain()方法筛选出所有0的文本文件,存入列表files中的一个文件数组中,重复十次,为了简化代码,获取文件目录的时候还使用了lambda表达式。读取文本文件使用了FileReader类并用Bufferedreader类包装,逐行读取字符串。
  train与test的数据都读取完毕后,从test中逐个地全部遍历train文件中所有的txt文件,遍历过程中对train中每个txt文件的rank成员赋值,即两个图片的欧氏距离,由于开方运算并不会影响距离的大小排序,故省去了开方这一步,也可以说这是汉明距离,对最终的结果不会影响。对test中的一个txt文件,train的每个txt文件已经得到rank值,对train中所有数据进行升序排序,选取K值,得到前K个文件,按照标签出现的频次决定test对象的标签,将此过程执行于所有的test对象,与实际的标签对比,计算出正确率;改变K值,找出最高正确率的K值。

收获

  该题目考察了文件目录的获取,文件的筛选,文件的读取,大量数据的存储与计算,以及K近邻算法的应用。
  首先使用两个Digits类来读取与储存训练数据和测试数据,Digits类的数据成员有包含File[]的列表与 Unit列表,分别用于读取与存储。继而在主函数中建立两个sa1 sa2(sample的前两个字母)对象分别指向两个目标文件夹,然后建立对应的Digits成员,并调用构造器完成文件的读取。
  两个文件夹中各包含0~9十种txt文本,数量不定,但是拥有共同特征,如0文件都包含 “0_” 这样的字段,通过使用字符串调用.contain()方法筛选出所有0的文本文件,存入列表files中的同一个文件数组中,重复十次,代码中还使用lambda表达式简化代码。
读取文本文件使用了Bufferedreader类,逐行读取字符串。
  train与test的数据都读取完毕后,从test中逐个地全部遍历train文件中所有的txt文件,遍历过程中对train中每个txt文件的rank成员赋值,即两个图片的欧氏距离,由于开方运算并不会影响距离的大小排序,故省去了开方这一步,也可以说这是汉明距离,对最终的结果不会影响。
  对test中的一个txt文件,train的每个txt文件已经得到rank值,对train中所有数据进行升序排序,选取K值,得到前K个文件,按照标签出现的频次决定test对象的标签,将此过程执行于所有的test对象,与实际的标签对比,计算出正确率;改变K值,找出最高正确率的K值。

源码

package java_exp2022;
import java.util.*;
import java.io.*;

/*由于欧氏距离的开平方计算并不会影响大小顺序,故本算法忽略*/
class Unit{
	//这是将每个数字文本文件(如 0_12.txt )转化为一个类Unit
	int label = -1;//标签初始化为-1,防止出错
	Integer temp_rank = 0;//这里是排序依据,根据题目需要,这里可以赋值欧氏距离
	ArrayList<String> line32 = new ArrayList<String>();
	/*建立一个列表容纳三十二行01数据,存储为String类*/
	Unit(int digit, ArrayList<String> lines){
		//使用构造器赋值
		label = digit ;
		for(String s : lines) 
			line32.add(s);
	}/*每个单元存入  数字 + 32行   */
}
class KnnCompare{//K近邻算法的计数单元
	int label = -1;
	Integer counter = 0;
	//使用对象类型便于调用.compareTo方法
	KnnCompare(int i){//构造器赋值
		label = i;
	}
}
class KnnNumber{
	ArrayList<KnnCompare> K_count = new ArrayList<KnnCompare>();
	//建立一个列表,方便对于K个数目的预测目标计数与排序
	boolean check(int label,Digits train) {
		K_count.add(new KnnCompare(0));
		K_count.add(new KnnCompare(1));
		K_count.add(new KnnCompare(2));
		K_count.add(new KnnCompare(3));
		K_count.add(new KnnCompare(4));
		K_count.add(new KnnCompare(5));
		K_count.add(new KnnCompare(6));
		K_count.add(new KnnCompare(7));
		K_count.add(new KnnCompare(8));
		K_count.add(new KnnCompare(9));
		//将对象添加入列表
		int outcome = -2;//最终输出的标签,初始化以防出错
		int K_value = 3;//K值
	    for(int i = 0; i < K_value ;i++) 
	    	K_count.get(train.ulist.get(i).label).counter++;
	    //对前K个标签预测进行计数
	    K_count.sort((KnnCompare k1,KnnCompare k2) -> k2.counter.compareTo(k1.counter)  );
	    //对列表进行降序排列,此处使用了lambda表达式
	    outcome = K_count.get(0).label;//结果赋值
	    System.out.println(" 预测数字 "+outcome+" (K_value = "+K_value+")");
	    //打印预测结果
	    K_count.clear();//释放列表
	    return label == outcome;//返回布尔类型,真--标签正确,假-标签错误
	}
	
}

class Digits{
	//数据处理类
	ArrayList<File[]> files = new ArrayList<File[]>();
	ArrayList<Unit> ulist = new ArrayList<Unit>();
	//便于处理文件的阅读、文件信息的储存
	Digits(File c){
		files.add(c.listFiles((f,name) -> name.contains("0_")));
		files.add(c.listFiles((f,name) -> name.contains("1_")));
		files.add(c.listFiles((f,name) -> name.contains("2_")));
		files.add(c.listFiles((f,name) -> name.contains("3_")));
		files.add(c.listFiles((f,name) -> name.contains("4_")));
		files.add(c.listFiles((f,name) -> name.contains("5_")));
		files.add(c.listFiles((f,name) -> name.contains("6_")));
		files.add(c.listFiles((f,name) -> name.contains("7_")));
		files.add(c.listFiles((f,name) -> name.contains("8_")));
		files.add(c.listFiles((f,name) -> name.contains("9_")));
		//使用lambda表达式,获取文件夹内部含有特定字段的文件,并放入同一个文件类数组中
		read();//调用方法
	}
    void read() {
    	for(int num = 0; num <= 9; num++) {/*循环10个数字*/
    		for(int i = 0; i < files.get(num).length ; i++) {/*i循环每个数字的约200个文件*/
    			try {
					BufferedReader br = new BufferedReader(new FileReader(files.get(num)[i]));
					ArrayList<String> temp = new ArrayList<String>();
					//32行暂持存储与temp
					for(int j = 0; j < 32 ; j++) {/*j循环每个文件的32行*/
				    	String read = null;//非空则写入
				    	while( (read = br.readLine()) != null)
				    		temp.add(read);
				    }
					ulist.add(new Unit(num,temp));
    			    br.close();
    			} catch (IOException e) {
					e.printStackTrace();
				}
    		}
    	}
    }
}

public class Project4 {
	public static void main(String[] args) {
		File sa1 = new File("D:\\javatest\\trainingDigits");
		Digits train = new Digits(sa1);
		//建立训练类,下为测试类
		File sa2 = new File("D:\\javatest\\testDigits");
		Digits test = new Digits(sa2);
		KnnNumber K = new KnnNumber();
		int counter = 0;//计算正确标签数目
		for(int test_target = 0 ; test_target < test.ulist.size()  ;test_target++) {
			/*遍历所有 测试用例*/
			for(int i = 0; i < train.ulist.size();i++) {
				/*遍历所有 训练数据*/
				int sum = 0;//欧氏距离
				int j = 0;//用于对测试用例行的循环计数
			    for(String a : train.ulist.get(i).line32) {
				    for(int l=0;l<32;l++) //l循环的是单行字符
				        if(a.toCharArray()[l] != test.ulist.get(test_target).line32.get(j).toCharArray()[l] )sum++;
		            j++;
				
			    }
				train.ulist.get(i).temp_rank = sum;//距离赋值便于排序
			}
			train.ulist.sort((Unit u1,Unit u2) -> u1.temp_rank.compareTo(u2.temp_rank) );
			/*排序并传递*/
			System.out.print("测试用例 "+test_target+",  实际数字 "+test.ulist.get(test_target).label);	
			if(   K.check(test.ulist.get(test_target).label, train)  )counter++;
		}
		double successRate = (double)counter/test.ulist.size()*100;//计算正确率
		System.out.println("counter = "+counter+" 数字识别正确率 = "+successRate+"%");
	}
}

/*OutPut
测试用例 0,  实际数字 0 预测数字 0 (K_value = 3)
测试用例 1,  实际数字 0 预测数字 0 (K_value = 3)
……
……
……
测试用例 944,  实际数字 9 预测数字 9 (K_value = 3)
测试用例 945,  实际数字 9 预测数字 9 (K_value = 3)
counter = 934 数字识别正确率 = 98.73150105708245%
*/

附件(百度云)

DigitsForTestAndTrain.zip 提取码:LINK


二、K-means聚类算法


  K-means算法是经典的聚类算法,其基本思想是:以空间中k个点为中心进行聚类,对最靠近他们的对象归类。通过迭代的方法,逐次更新各聚类中心的值,直至得到最好的聚类结果。假设要把样本集分为K个类别,算法描述如下:
(1)适当选择k个类的初始中心
(2)在第I次迭代中,对任意一个样本,求其到K个中心的距离,将该样本归到距离最短的中心所在的类
(3)利用均值方式更新该类的中心值
(4)对于所有的K个聚类中心,如果利用(2)(3)的迭代法更新后,值保持基本不变,则迭代结束,否则继续迭代。
  要求用java编写K-means算法(k值可以自己设定),根据花的属性对数据集Iris Data Set进行聚类,并将聚类结果(sepal length,sepal width,petal length,petal width ,cluster label)打印至cluster.txt文件。iris数据包括四个属性:sepal length花萼长度,sepal width花萼宽度,petal length花瓣长度,petal width花瓣宽度。其中第五个值表示该样本属于哪一个类。Iris.data 可以用写字板打开。
注意:样本点间的距离直接用向量的欧氏距离。

思路

  该题目考察了文件的读取,写入,double类型的计算,以及K-means聚类算法。
  先从data文件中使用Bufferedreader类逐行读取数据,获得花朵的四个特征值,并且记录花朵的真实种类,使用1,2,3值代替名字,存储于Flower类的一个对象中,类的成员还包括sortkind表示使用K-means算法所分的类。使用ArrayList容纳。
  进而计算距离既定的三个中心的距离,进行排序,就近决定类别,最后调用Update方法进行中心的更新。
  运行结果出来后,有约25个错误,适当修改后余下13个错误分类,尝试数次始终不能低于13个,但是13个错误都出现在第三类,所以考虑增加一个聚类中心。
  最终,在四个聚类中心的分类下,仅仅只有一个分类错误,分类效果非常好。

源码

package java_exp2022;
import java.io.*;
import java.util.*;

class Flower{
	double sepal_length = 0; //花萼长度
	double sepal_width = 0;  //花萼宽度
	double petal_length = 0; //花瓣长度
	double petal_width = 0;  //花瓣宽度
    int realkind = -1;
    int sortkind = -1;
    Flower(String s){
		char[] a = s.toCharArray();
		sepal_length = (double)(a[0]-'0') +(double)(a[2]-'0')/10.0;
		sepal_width = (double)(a[4]-'0') +(double)(a[6]-'0')/10.0;
		petal_length = (double)(a[8]-'0') +(double)(a[10]-'0')/10.0;
		petal_width = (double)(a[12]-'0') +(double)(a[14]-'0')/10.0;
		if(s.contains("Iris-setosa"))realkind = 1;
		if(s.contains("Iris-versicolor"))realkind = 2;
		if(s.contains("Iris-virginica"))realkind = 3;
	}
    double Distance(Flower r1) {
    	double d = 0;
    	d = (r1.petal_length-this.petal_length)*(r1.petal_length-this.petal_length)
    	   +(r1.petal_width-this.petal_width)*(r1.petal_width-this.petal_width)
    	   +(r1.sepal_length-this.sepal_length)*(r1.sepal_length-this.sepal_length)
    	   +(r1.sepal_width-this.sepal_width)*(r1.sepal_width-this.sepal_width);
    	
    	return Math.sqrt(d);
    }
    int FlowerSort(ArrayList<Flower> K_means) {
    	int sortkind = -1;
    	double d1=this.Distance(K_means.get(0));
    	double d2=this.Distance(K_means.get(1));
    	double d3=this.Distance(K_means.get(2));
    	double d4=this.Distance(K_means.get(3));
    	if     (d1<=d2 && d1<=d3 && d1<=d4)sortkind = 1;
    	else if(d2< d1 && d2<=d3 && d2<=d4)sortkind = 2;
    	else if(d3< d1 && d3< d2 && d3<=d4)sortkind = 3;
    	else if(d4< d1 && d4< d2 && d4< d3)sortkind = 4;
    	return sortkind;
    }
    void Update(ArrayList<Flower> K_means,int sortkind) {
    	double[] feature = {0,0,0,0};
		int counter = 0;
		for(Flower a : K_means) {
			if(a.sortkind==sortkind) {
				counter++;
				feature[0] += a.petal_length;
				feature[1] += a.petal_width;
				feature[2] += a.sepal_length;
				feature[3] += a.sepal_width;
			}
		}
		K_means.get(sortkind-1).petal_length=feature[0]/counter;
		K_means.get(sortkind-1).petal_width=feature[1]/counter;
		K_means.get(sortkind-1).sepal_length=feature[2]/counter;
		K_means.get(sortkind-1).sepal_width=feature[3]/counter;
    }
}
public class Project5 {
	public static void main(String[] args) {
		File read = new File("D:\\javatest\\iris.data");
		File write = new File("D:\\javatest\\cluster.txt");
		try {
			write.createNewFile();
		} catch (IOException e1) {
			e1.printStackTrace();
		}
		ArrayList<Flower> vase = new ArrayList<Flower>();
		String temp = null;
		try {
			BufferedReader  br = new BufferedReader(new FileReader(read));
			while( ( temp = br.readLine()).contains("Iris") ) {
				vase.add(new Flower(temp));
			}
			br.close();
		} catch (IOException e) {
			e.printStackTrace();
		}
		/*预置中心*/
		ArrayList<Flower> K_means = new ArrayList<Flower>();
		K_means.add(new Flower("5.0,3.4,1.5,0.2"));
		K_means.get(0).sortkind=1;
		K_means.add(new Flower("5.9,2.7,4.3,1.3"));
		K_means.get(1).sortkind=2;
		K_means.add(new Flower("6.0,3.1,6.0,2.1"));
		K_means.get(2).sortkind=3;
		K_means.add(new Flower("6.2,2.8,6.0,2.0"));
		K_means.get(3).sortkind=4;
		
		System.out.println("初始中心");
		for(Flower a : K_means) {
			System.out.print(a.sepal_length+",");
			System.out.print(a.sepal_width+",");
			System.out.print(a.petal_length+",");
			System.out.print(a.petal_width+", kind = ");
			System.out.print(a.sortkind+"\n");
		}
		int wrong = 0;//java中传递列表、数组给方法(函数)效果和C大致相同
		int o = 0;
		System.out.println("开始迭代\n");
		String s = null;
		for(Flower a : vase) {
			a.sortkind = a.FlowerSort(K_means);
			K_means.add(a);
			a.Update(K_means, a.sortkind);
			
			if(o==7) {
				s="\n";
			    o-=7;
			}
			else {
				s="  ";
				o++;
			}
			System.out.print(a.realkind+"--"+a.sortkind+s);
			if(a.realkind!=a.sortkind)wrong++;
			if(a.realkind==3 && a.sortkind==4)wrong--;
		}
		System.out.println("Wrong times = "+wrong);
		try {
			FileWriter wr = new FileWriter(write);
			for(int i = 0;i < 4 ;i ++) {
				Flower a = K_means.get(i);
				if(i==0)wr.write("Iris-setosa ");
				if(i==1)wr.write("Iris-versicolor ");
				if(i==2)wr.write("Iris-virginica-I  ");
				if(i==3)wr.write("Iris-virginica-II ");
				
				String f1 = String.format("%.2f", a.sepal_length);
				a.sepal_length = Double.parseDouble(f1);
				wr.write("("+a.sepal_length);
				
				String f2 = String.format("%.2f", a.sepal_width);
				a.sepal_width = Double.parseDouble(f2);
				wr.write(", "+a.sepal_width);
				
				String f3 = String.format("%.2f", a.petal_length);
				a.petal_length = Double.parseDouble(f3);
				wr.write(", "+a.petal_length);
				
				String f4 = String.format("%.2f", a.petal_width);
				a.petal_width = Double.parseDouble(f4);
				wr.write(", "+a.petal_width+")\n");
			}
			wr.close();
		} catch (IOException e) {
			e.printStackTrace();
		}
		System.out.print("\n结果文件在 D:\\javatest\\cluster.txt ");	
	}
}
/*OutPut

初始中心
5.0,3.4,1.5,0.2, kind = 1
5.9,2.7,4.3,1.3, kind = 2
6.0,3.1,6.0,2.1, kind = 3
6.2,2.8,6.0,2.0, kind = 4
开始迭代

1--1  1--1  1--1  1--1  1--1  1--1  1--1  1--1
1--1  1--1  1--1  1--1  1--1  1--1  1--1  1--1
1--1  1--1  1--1  1--1  1--1  1--1  1--1  1--1
1--1  1--1  1--1  1--1  1--1  1--1  1--1  1--1
1--1  1--1  1--1  1--1  1--1  1--1  1--1  1--1
1--1  1--1  1--1  1--1  1--1  1--1  1--1  1--1
1--1  1--1  2--2  2--2  2--2  2--2  2--2  2--2
2--2  2--2  2--2  2--2  2--2  2--2  2--2  2--2
2--2  2--2  2--2  2--2  2--2  2--2  2--2  2--2
2--2  2--2  2--2  2--2  2--2  2--2  2--2  2--2
2--2  2--2  2--2  2--2  2--2  2--2  2--2  2--2
2--2  2--2  2--2  2--2  2--2  2--2  2--2  2--2
2--2  2--2  2--2  2--2  3--3  3--4  3--3  3--4
3--3  3--3  3--2  3--3  3--3  3--3  3--4  3--4
3--3  3--4  3--4  3--4  3--4  3--3  3--3  3--4
3--3  3--4  3--3  3--4  3--3  3--3  3--4  3--4
3--4  3--3  3--3  3--3  3--4  3--4  3--4  3--3
3--4  3--4  3--4  3--3  3--3  3--4  3--4  3--3
3--3  3--4  3--4  3--4  3--4  3--4  Wrong times = 1

结果文件在 D:\javatest\cluster.txt 

*/

附件(百度云)

iris.data 提取码:LINK