0.k近邻算法
刚接触java,并且在学习机器学习的相关算法,knn又非常的易于实现,于是就有了这个小系统。
1.knn算法简介:
存在一个样本数据集合,也称为训练样本集,并且样本集中的每一个数据都有标签,即我们知道样本集中每一个数据的特征和对应的类别。当输入没有标签的新数据时,将新数据的每一个特征与样本集中每一个数据的对应特征进行比较(计算两个样本特征之间的距离),然后提取样本集中与新数据特征最相似的数据的类别标签。通常我们只关心前k个最相似的数据,这就是k近邻算法中k的出处。一般来说,我们只选择样本数据集中前k个最相似的数据,然后选择这k个最相似的数据中出现次数最多的类别,作为新数据的分类。
2.该程序的功能主要有如下几个,
功能1:可以在面板上手写输入数字
功能2:可以对特定的区域进行截屏,因为要获取用户手写的数字,保存为图像,然后使用算法进行分析
功能3:可以对图片进行缩放,保证图片的大小(维度)和数据集中的大小一样。
功能4:可以将彩色图片转化为二值图片
功能5:对图片中的手写数字使用KNN算法进行识别,也可以在测试集上计算算法的准确性。
(演示)
功能1的实现代码:手写板
创建一个JPanel类的子类,通过监听mouseDragged事件,调用Graphics绘图来实现手写板的功能。
/**
 * The 320x320 handwriting surface. Dragging the mouse stamps filled circles of diameter
 * {@code pencilWidth} directly onto the component's graphics; painting the component again
 * fills the writing area with white, which erases them.
 */
class Board extends JPanel implements MouseMotionListener {
    private final int boardWidth = 320;
    private final int boardHeight = 320;
    private final int boardX = 1;
    private final int boardY = 1;
    private int pencilWidth = 40;

    /** Paints the empty board: a black 3D frame around a white writing area. */
    @Override
    public void paint(Graphics graphics) {
        super.paint(graphics);
        graphics.setColor(Color.BLACK);
        graphics.draw3DRect(this.boardX - 1, this.boardY - 1, this.boardWidth + 1, this.boardHeight + 1, true);
        graphics.setColor(Color.WHITE);
        graphics.fill3DRect(this.boardX, this.boardY, this.boardWidth, this.boardHeight, true);
    }

    /**
     * Stamps one pen dot at the drag position.
     *
     * <p>The dot's bounding box starts at the cursor, so positions closer than
     * {@code pencilWidth} to the right or bottom edge are ignored to keep the whole dot
     * inside the writing area.
     */
    @Override
    public void mouseDragged(MouseEvent e) {
        int x = e.getX();
        int y = e.getY();
        boolean insideBoard = x > 1 && x < boardWidth - pencilWidth
                && y > 1 && y < boardHeight - pencilWidth;
        if (!insideBoard) {
            return;
        }
        // getGraphics() returns null while the component is not displayable, and each
        // context it hands out is a new one that must be released with dispose().
        Graphics graphics = this.getGraphics();
        if (graphics == null) {
            return;
        }
        try {
            graphics.fillOval(x, y, pencilWidth, pencilWidth);
        } finally {
            graphics.dispose();
        }
    }

    /** Moving the mouse without a button held does not draw anything. */
    @Override
    public void mouseMoved(MouseEvent e) {
    }
}
功能2的实现:可以对特定的区域进行截屏
/** Captures a rectangle of the screen and saves it as a PNG file in the working directory. */
class ScreenShot {
    private int startX;
    private int startY;
    private int width;
    private int height;
    private String saveTo;

    /**
     * Describes the screen area to capture.
     *
     * @param startX   x coordinate (screen pixels) of the capture's top-left corner
     * @param startY   y coordinate (screen pixels) of the capture's top-left corner
     * @param width    capture width in pixels
     * @param height   capture height in pixels
     * @param filename base name of the output file; ".png" is appended
     * @throws IllegalArgumentException if width or height is not positive
     */
    public ScreenShot(int startX, int startY, int width, int height, String filename) {
        if (width <= 0 || height <= 0) {
            throw new IllegalArgumentException("capture size must be positive: " + width + "x" + height);
        }
        this.startX = startX;
        this.startY = startY;
        this.width = width;
        this.height = height;
        // Use the platform separator: a hard-coded "\\" is an ordinary file-name character on
        // non-Windows systems and would not produce "./<filename>.png".
        this.saveTo = "." + File.separator + filename + ".png";
    }

    /**
     * Grabs the configured screen rectangle and writes it as PNG. Failures (no screen access,
     * write errors) are reported on stderr and leave any previous file untouched.
     */
    public void capture() {
        File file = new File(saveTo);
        try {
            BufferedImage bufferedImage = new Robot()
                    .createScreenCapture(new Rectangle(startX, startY, width, height));
            ImageIO.write(bufferedImage, "png", file);
            System.out.println("capture image has finish...");
        } catch (AWTException | IOException e) {
            e.printStackTrace();
        }
    }
}
功能3:可以对图片进行缩放,保证图片的大小(维度)和数据集中的大小一样。
/** Rescales an image file by a constant factor and writes the result as PNG. */
class ZoomImage {
    private String filename;
    private float scaling;
    private String outputFilename;

    /**
     * Scales {@code filename} by {@code scaling} and writes the result to
     * {@code ./zoominMaggie.png}.
     *
     * @param filename path of the source image
     * @param scaling  scale factor applied to both dimensions (e.g. 0.1f shrinks to 10%)
     * @throws IllegalArgumentException if scaling is not positive
     */
    public ZoomImage(String filename, float scaling) {
        this(filename, scaling, "." + File.separator + "zoominMaggie.png");
    }

    /**
     * Scales {@code filename} by {@code scaling} and writes the result to {@code outputFilename}.
     *
     * @param filename       path of the source image
     * @param scaling        scale factor applied to both dimensions
     * @param outputFilename path of the PNG file to write
     * @throws IllegalArgumentException if scaling is not positive (or is NaN)
     */
    public ZoomImage(String filename, float scaling, String outputFilename) {
        if (!(scaling > 0)) {
            throw new IllegalArgumentException("scaling must be positive: " + scaling);
        }
        this.filename = filename;
        this.scaling = scaling;
        this.outputFilename = outputFilename;
    }

    /**
     * Reads the source image, draws it into a {@code scaling}-sized canvas and saves it as PNG.
     * Each target dimension is truncated to an int but never drops below 1 pixel. Read and
     * write failures are reported on stderr.
     */
    public void zoom() {
        try {
            BufferedImage source = ImageIO.read(new File(filename));
            if (source == null) {
                // ImageIO.read returns null when no registered reader understands the file.
                System.err.println("unsupported image format: " + filename);
                return;
            }
            int targetWidth = Math.max(1, (int) (this.scaling * source.getWidth()));
            int targetHeight = Math.max(1, (int) (this.scaling * source.getHeight()));
            BufferedImage target = new BufferedImage(targetWidth, targetHeight, BufferedImage.TYPE_INT_BGR);
            Graphics graphics = target.createGraphics();
            try {
                graphics.drawImage(source, 0, 0, targetWidth, targetHeight, null);
            } finally {
                graphics.dispose();
            }
            // Platform separator instead of a hard-coded "\\" so the file lands in the working
            // directory on every OS.
            ImageIO.write(target, "png", new File(outputFilename));
            System.out.println("image has been zoomed...");
        } catch (IOException e) {
            e.printStackTrace();
        }
    }
}
功能4:可以将彩色图片转化为二值图片
/** Converts an image file into a row-major 0/1 bitmap (1 = dark "ink", 0 = light background). */
class RGB2binary {
    private String filename;
    // Sized for the 32x32 digits the classifier expects; rgb2binary() resizes it when the
    // image has a different number of pixels.
    private short[] userInputDigit = new short[32 * 32];

    /** Returns the bitmap produced by the last rgb2binary() call (all zeros before the first call). */
    public short[] getUserInputDigit() {
        return this.userInputDigit;
    }

    /** @param filename path of the image to binarize */
    public RGB2binary(String filename) {
        this.filename = filename;
    }

    /**
     * Reads the image, converts each pixel to a weighted gray level and thresholds it at 128.
     * Pixels with gray > 128 become 0 and all others become 1. The result is stored row by row
     * (index = row * width + column) and also printed to stdout. I/O failures are reported on
     * stderr and leave the previous bitmap unchanged.
     */
    public void rgb2binary() {
        System.out.println(this.filename);
        File file = new File(this.filename);
        try {
            BufferedImage bufferedImage = ImageIO.read(file);
            if (bufferedImage == null) {
                // ImageIO.read returns null when no registered reader understands the file.
                System.err.println("unsupported image format: " + this.filename);
                return;
            }
            int startX = bufferedImage.getMinX();
            int startY = bufferedImage.getMinY();
            int width = bufferedImage.getWidth();
            int height = bufferedImage.getHeight();
            System.out.println("x = " + startX + " y = " + startY + " width = " + width + " height = " + height);
            if (userInputDigit.length != width * height) {
                userInputDigit = new short[width * height];
            }
            // Rows (y) in the outer loop so the output is row-major and getRGB receives (x, y)
            // in the right order; swapping them breaks on any non-square image.
            for (int y = 0; y < height; y++) {
                for (int x = 0; x < width; x++) {
                    int pixel = bufferedImage.getRGB(startX + x, startY + y);
                    int r = (pixel >> 16) & 0xff;
                    int g = (pixel >> 8) & 0xff;
                    int b = pixel & 0xff;
                    // Weighted RGB -> gray conversion; the binary value comes from the threshold below.
                    float gray = r * 0.3f + g * 0.59f + b * 0.11f;
                    short bit = gray > 128 ? (short) 0 : (short) 1;
                    System.out.print(bit);
                    userInputDigit[y * width + x] = bit;
                }
                System.out.println();
            }
        } catch (IOException e) {
            e.printStackTrace();
        }
    }
}
功能5:对图片中的手写数字使用KNN算法进行识别,也可以在测试集上计算算法的准确性。
/**
 * k-nearest-neighbour classifier for 32x32 binary digit bitmaps stored as text files
 * (32 rows of 32 digit characters each).
 */
class Knn {
    // Characters per row, and rows per file, of a digit text file.
    private static final int ROW_LENGTH = 32;
    private int featureSize = ROW_LENGTH * ROW_LENGTH;
    String trainingSetDir = "./trainingDigits";
    String testSetDir = "./testDigits1";
    private int trainingSetSize;
    private int testSetSize;
    private short[][] trainingData = null;
    private short[] trainingSetLabel = null;
    private short[][] testData = null;
    private short[] testSetLabel = null;

    public Knn() {
    }

    /**
     * Loads the training set from {@link #trainingSetDir}.
     *
     * <p>Every regular file in the directory is one sample, and its label is the part of the
     * file name before the first '_'. Files that cannot be read or parsed are reported on
     * stderr and skipped, so {@link #getTrainingSetSize()} counts only the samples actually
     * loaded.
     *
     * @throws IllegalStateException if the directory cannot be listed
     */
    public void readTrainingSet() {
        File[] files = listDigitFiles(trainingSetDir);
        System.out.println("total file number: " + files.length);
        short[][] data = new short[files.length][];
        short[] labels = new short[files.length];
        int loaded = loadDigits(files, data, labels);
        this.trainingData = Arrays.copyOf(data, loaded);
        this.trainingSetLabel = Arrays.copyOf(labels, loaded);
        this.trainingSetSize = loaded;
    }

    /**
     * Loads the test set from {@link #testSetDir}, using the same file format and skipping
     * rules as {@link #readTrainingSet()}.
     *
     * @throws IllegalStateException if the directory cannot be listed
     */
    public void readTestSet() {
        File[] files = listDigitFiles(testSetDir);
        System.out.println("total number of test file" + files.length);
        short[][] data = new short[files.length][];
        short[] labels = new short[files.length];
        int loaded = loadDigits(files, data, labels);
        this.testData = Arrays.copyOf(data, loaded);
        this.testSetLabel = Arrays.copyOf(labels, loaded);
        this.testSetSize = loaded;
    }

    /**
     * Classifies a feature vector by majority vote among its k nearest training samples.
     *
     * <p>If k exceeds the training-set size, every sample votes. Neighbours are scanned from
     * nearest to farthest, and a vote tie goes to the label that reached the winning count
     * first.
     *
     * @param feature feature vector to classify; must have the training samples' length
     * @param k       number of neighbours that vote (the "k" of k-NN)
     * @return the predicted label
     * @throws IllegalArgumentException if k < 1 or the feature length does not match
     * @throws IllegalStateException    if no training samples have been loaded
     */
    public int knn(short[] feature, int k) {
        if (k < 1) {
            throw new IllegalArgumentException("k must be at least 1, got " + k);
        }
        if (trainingSetSize == 0) {
            throw new IllegalStateException("training set is empty; call readTrainingSet() first");
        }
        double[] distances = new double[this.trainingSetSize];
        for (int i = 0; i < trainingSetSize; i++) {
            distances[i] = calculateDistance(feature, trainingData[i]);
        }
        int[] nearest = this.arg_sort(distances);
        int neighbours = Math.min(k, trainingSetSize);
        Map<Short, Integer> votes = new HashMap<>();
        int result = 0;
        int maxVote = 0;
        for (int rank = 0; rank < neighbours; rank++) {
            short label = trainingSetLabel[nearest[rank]];
            Integer previous = votes.get(label);
            int count = previous == null ? 1 : previous + 1;
            votes.put(label, count);
            // Strict '>' keeps the earlier (closer) label on ties.
            if (count > maxVote) {
                maxVote = count;
                result = label;
            }
        }
        return result;
    }

    /**
     * Reloads both data sets and returns the fraction of test samples classified correctly
     * with k = 3. The result is NaN when the test set is empty.
     */
    public double knnPrecise() {
        return knnPrecise(3);
    }

    /**
     * Reloads both data sets and returns the fraction of test samples classified correctly.
     *
     * @param k number of voting neighbours
     * @return accuracy in [0, 1], or NaN when the test set is empty
     */
    public double knnPrecise(int k) {
        System.out.println("reading trainingSet...");
        this.readTrainingSet();
        System.out.println("reading trainingSet over");
        System.out.println("reading testSet...");
        this.readTestSet();
        System.out.println("reading testSet end");
        int success = 0;
        for (int i = 0; i < testSetSize; i++) {
            if (testSetLabel[i] == knn(testData[i], k)) {
                success++;
            }
        }
        return (double) success / testSetSize;
    }

    /**
     * Returns the Euclidean distance between two equally long feature vectors.
     *
     * @throws IllegalArgumentException if the vectors differ in length
     */
    public double calculateDistance(short[] sequence1, short[] sequence2) {
        if (sequence1.length != sequence2.length) {
            throw new IllegalArgumentException(
                    "feature length mismatch: " + sequence1.length + " vs " + sequence2.length);
        }
        long sum = 0; // long so large short differences cannot overflow
        for (int i = 0; i < sequence1.length; i++) {
            long diff = sequence1[i] - sequence2[i];
            sum += diff * diff;
        }
        return Math.sqrt(sum);
    }

    /**
     * Returns the indices of {@code sequence} ordered by ascending value. The sort is stable,
     * so equal values keep their original index order. Runs in O(n log n).
     */
    public int[] arg_sort(final double[] sequence) {
        Integer[] order = new Integer[sequence.length];
        for (int i = 0; i < order.length; i++) {
            order[i] = i;
        }
        // Arrays.sort on objects is a stable merge sort.
        Arrays.sort(order, new Comparator<Integer>() {
            @Override
            public int compare(Integer left, Integer right) {
                return Double.compare(sequence[left], sequence[right]);
            }
        });
        int[] result = new int[order.length];
        for (int i = 0; i < order.length; i++) {
            result[i] = order[i];
        }
        return result;
    }

    public int getTrainingSetSize() {
        return trainingSetSize;
    }

    public int getTestSetSize() {
        return testSetSize;
    }

    /** Lists the regular files of {@code dirName}, sorted by path for a reproducible order. */
    private static File[] listDigitFiles(String dirName) {
        File[] entries = new File(dirName).listFiles();
        if (entries == null) {
            throw new IllegalStateException("cannot list digit directory: " + dirName);
        }
        File[] files = new File[entries.length];
        int count = 0;
        for (File entry : entries) {
            if (entry.isFile()) {
                files[count++] = entry;
            }
        }
        files = Arrays.copyOf(files, count);
        // listFiles() order is unspecified; sorting keeps tie-breaking stable across runs.
        Arrays.sort(files);
        return files;
    }

    /** Fills data/labels with every file that parses, skipping bad ones; returns the count loaded. */
    private int loadDigits(File[] files, short[][] data, short[] labels) {
        int loaded = 0;
        for (File file : files) {
            try {
                short label = parseLabel(file);
                data[loaded] = readDigitFile(file);
                labels[loaded] = label;
                loaded++;
            } catch (IOException | NumberFormatException e) {
                System.err.println("skipping " + file + ": " + e);
            }
        }
        return loaded;
    }

    /** Parses the label: the part of the file name before the first '_'. */
    private static short parseLabel(File file) {
        return Short.parseShort(file.getName().split("_")[0]);
    }

    /**
     * Reads ROW_LENGTH non-empty lines of digit characters into a row-major feature vector.
     * Works with both LF and CRLF line endings and closes the file even on failure.
     */
    private short[] readDigitFile(File file) throws IOException {
        short[] pixels = new short[featureSize];
        int rows = 0;
        try (BufferedReader reader = new BufferedReader(new FileReader(file))) {
            String line;
            while (rows < ROW_LENGTH && (line = reader.readLine()) != null) {
                if (line.isEmpty()) {
                    continue;
                }
                if (line.length() < ROW_LENGTH) {
                    throw new IOException("row " + (rows + 1) + " is shorter than " + ROW_LENGTH + " characters");
                }
                for (int col = 0; col < ROW_LENGTH; col++) {
                    int digit = Character.digit(line.charAt(col), 10);
                    if (digit < 0) {
                        throw new IOException("non-digit character '" + line.charAt(col) + "' in row " + (rows + 1));
                    }
                    pixels[rows * ROW_LENGTH + col] = (short) digit;
                }
                rows++;
            }
        }
        if (rows < ROW_LENGTH) {
            throw new IOException("expected " + ROW_LENGTH + " rows but found " + rows);
        }
        return pixels;
    }
}
3.结果:
在测试集上的准确性很高,但是实际应用中却远没有那么高。
完整代码
import java.awt.AWTException;
import java.awt.Color;
import java.awt.Font;
import java.awt.Graphics;
import java.awt.Rectangle;
import java.awt.Robot;
import java.awt.event.ActionEvent;
import java.awt.event.ActionListener;
import java.awt.event.MouseEvent;
import java.awt.event.MouseMotionListener;
import java.awt.image.BufferedImage;
import java.io.BufferedReader;
import java.io.File;
import java.io.FileNotFoundException;
import java.io.FileReader;
import java.io.IOException;
import java.util.Arrays;
import java.util.Comparator;
import java.util.HashMap;
import java.util.Map;
import javax.imageio.ImageIO;
import javax.swing.JButton;
import javax.swing.JFrame;
import javax.swing.JLabel;
import javax.swing.JPanel;
import javax.swing.JTextArea;
import javax.swing.SwingUtilities;
/**
 * Main window: a handwriting board plus buttons to clear it, recognise the drawn digit with
 * k-NN, and measure accuracy on the test set.
 */
public class Recognition extends JFrame implements ActionListener {
    final private int windowWidth = 493;
    final private int windowHeight = 380;
    final private int windowX = 100;
    final private int windowY = 100;
    Board board = null;
    JButton reWriteButton = null;
    JButton recognitionButton = null;
    JButton testButton = null;
    JTextArea showResult = null;
    private int contentPaneX;
    private int contentPaneY;
    // Classifier trained on first use and then reused, so each recognition does not re-read
    // the whole training directory from disk.
    private Knn recognizer = null;

    /** Lays out the board, the three buttons and the (initially hidden) result area. */
    public Recognition() {
        board = new Board();
        this.setLayout(null);
        this.add(board);
        board.setBounds(8, 8, 332, 332);
        board.addMouseMotionListener(board);
        reWriteButton = new JButton("Rewrite");
        this.add(reWriteButton);
        reWriteButton.setBounds(340, 10, 130, 30);
        reWriteButton.addActionListener(this);
        recognitionButton = new JButton("Recognition");
        this.add(recognitionButton);
        recognitionButton.setBounds(340, 40, 130, 30);
        recognitionButton.addActionListener(this);
        testButton = new JButton("testPrecise");
        this.add(testButton);
        testButton.setBounds(340, 80, 130, 30);
        testButton.addActionListener(this);
        showResult = new JTextArea();
        showResult.setOpaque(true);
        showResult.setBackground(Color.CYAN);
        showResult.setForeground(Color.BLACK);
        showResult.setFont(new Font("微软雅黑", Font.BOLD, 12));
        showResult.setLineWrap(true);
        this.add(showResult);
        showResult.setBounds(340, 180, 130, 150);
        showResult.setVisible(false);
        this.setTitle("HandWriting Recognition");
        this.setSize(windowWidth, windowHeight);
        this.setLocation(windowX, windowY);
        // Configure closing behaviour before the window becomes visible.
        this.setDefaultCloseOperation(JFrame.EXIT_ON_CLOSE);
        this.setVisible(true);
    }

    /** Starts the application; Swing components are created on the event-dispatch thread. */
    public static void main(String[] args) {
        SwingUtilities.invokeLater(new Runnable() {
            @Override
            public void run() {
                new Recognition();
            }
        });
    }

    /** Dispatches button clicks to clear, recognise or accuracy-test actions. */
    @Override
    public void actionPerformed(ActionEvent e) {
        Object source = e.getSource();
        if (source == reWriteButton) {
            repaint();
        } else if (source == recognitionButton) {
            recognizeBoard();
        } else if (source == testButton) {
            showAccuracy();
        }
    }

    /**
     * Screenshots the board, shrinks it, binarises it and shows the k-NN prediction (k = 3).
     * NOTE(review): this runs on the event-dispatch thread, so the UI is unresponsive while
     * the first call loads the training set.
     */
    private void recognizeBoard() {
        this.contentPaneX = (int) this.getContentPane().getLocationOnScreen().getX();
        this.contentPaneY = (int) this.getContentPane().getLocationOnScreen().getY();
        // The board sits at (8, 8) in the content pane and its white area starts 1px further in.
        ScreenShot screenShot = new ScreenShot(contentPaneX + 9, contentPaneY + 9, 320, 320, "maggie");
        screenShot.capture();
        // 320px * 0.1 = 32px, matching the 32x32 training bitmaps.
        ZoomImage zoomImage = new ZoomImage("./maggie.png", 0.1f);
        zoomImage.zoom();
        RGB2binary rgb2binary = new RGB2binary("./zoominMaggie.png");
        rgb2binary.rgb2binary();
        short[] userInput = rgb2binary.getUserInputDigit();
        int recognitionResult = trainedRecognizer().knn(userInput, 3);
        System.out.println("recognitionResult:" + recognitionResult);
        showResult.setText("Your input is \r\n" + String.valueOf(recognitionResult));
        showResult.setVisible(true);
    }

    /** Returns the cached classifier, loading the training set on first use. */
    private Knn trainedRecognizer() {
        if (recognizer == null) {
            Knn knn = new Knn();
            System.out.println("reading trainingSet...");
            knn.readTrainingSet();
            System.out.println("reading trainingSet over");
            recognizer = knn;
        }
        return recognizer;
    }

    /** Evaluates k-NN on the test set and shows both set sizes and the accuracy. */
    private void showAccuracy() {
        Knn knn = new Knn();
        double precise = knn.knnPrecise();
        // knnPrecise() loaded the training set, so reuse this instance for later recognitions.
        recognizer = knn;
        String string = "Training Set Size is :\r\n" + knn.getTrainingSetSize()
                + "\r\nTest Set Size is :\r\n" + knn.getTestSetSize()
                + "\r\nAccuracy is \r\n" + String.valueOf(precise);
        showResult.setText(string);
        showResult.setVisible(true);
    }
}
/**
 * The 320x320 handwriting surface. Dragging the mouse stamps filled circles of diameter
 * {@code pencilWidth} directly onto the component's graphics; painting the component again
 * fills the writing area with white, which erases them.
 */
class Board extends JPanel implements MouseMotionListener {
    private final int boardWidth = 320;
    private final int boardHeight = 320;
    private final int boardX = 1;
    private final int boardY = 1;
    private int pencilWidth = 40;

    /** Paints the empty board: a black 3D frame around a white writing area. */
    @Override
    public void paint(Graphics graphics) {
        super.paint(graphics);
        graphics.setColor(Color.BLACK);
        graphics.draw3DRect(this.boardX - 1, this.boardY - 1, this.boardWidth + 1, this.boardHeight + 1, true);
        graphics.setColor(Color.WHITE);
        graphics.fill3DRect(this.boardX, this.boardY, this.boardWidth, this.boardHeight, true);
    }

    /**
     * Stamps one pen dot at the drag position.
     *
     * <p>The dot's bounding box starts at the cursor, so positions closer than
     * {@code pencilWidth} to the right or bottom edge are ignored to keep the whole dot
     * inside the writing area.
     */
    @Override
    public void mouseDragged(MouseEvent e) {
        int x = e.getX();
        int y = e.getY();
        boolean insideBoard = x > 1 && x < boardWidth - pencilWidth
                && y > 1 && y < boardHeight - pencilWidth;
        if (!insideBoard) {
            return;
        }
        // getGraphics() returns null while the component is not displayable, and each
        // context it hands out is a new one that must be released with dispose().
        Graphics graphics = this.getGraphics();
        if (graphics == null) {
            return;
        }
        try {
            graphics.fillOval(x, y, pencilWidth, pencilWidth);
        } finally {
            graphics.dispose();
        }
    }

    /** Moving the mouse without a button held does not draw anything. */
    @Override
    public void mouseMoved(MouseEvent e) {
    }
}
/** Captures a rectangle of the screen and saves it as a PNG file in the working directory. */
class ScreenShot {
    private int startX;
    private int startY;
    private int width;
    private int height;
    private String saveTo;

    /**
     * Describes the screen area to capture.
     *
     * @param startX   x coordinate (screen pixels) of the capture's top-left corner
     * @param startY   y coordinate (screen pixels) of the capture's top-left corner
     * @param width    capture width in pixels
     * @param height   capture height in pixels
     * @param filename base name of the output file; ".png" is appended
     * @throws IllegalArgumentException if width or height is not positive
     */
    public ScreenShot(int startX, int startY, int width, int height, String filename) {
        if (width <= 0 || height <= 0) {
            throw new IllegalArgumentException("capture size must be positive: " + width + "x" + height);
        }
        this.startX = startX;
        this.startY = startY;
        this.width = width;
        this.height = height;
        // Use the platform separator: a hard-coded "\\" is an ordinary file-name character on
        // non-Windows systems and would not produce "./<filename>.png".
        this.saveTo = "." + File.separator + filename + ".png";
    }

    /**
     * Grabs the configured screen rectangle and writes it as PNG. Failures (no screen access,
     * write errors) are reported on stderr and leave any previous file untouched.
     */
    public void capture() {
        File file = new File(saveTo);
        try {
            BufferedImage bufferedImage = new Robot()
                    .createScreenCapture(new Rectangle(startX, startY, width, height));
            ImageIO.write(bufferedImage, "png", file);
            System.out.println("capture image has finish...");
        } catch (AWTException | IOException e) {
            e.printStackTrace();
        }
    }
}
/** Rescales an image file by a constant factor and writes the result as PNG. */
class ZoomImage {
    private String filename;
    private float scaling;
    private String outputFilename;

    /**
     * Scales {@code filename} by {@code scaling} and writes the result to
     * {@code ./zoominMaggie.png}.
     *
     * @param filename path of the source image
     * @param scaling  scale factor applied to both dimensions (e.g. 0.1f shrinks to 10%)
     * @throws IllegalArgumentException if scaling is not positive
     */
    public ZoomImage(String filename, float scaling) {
        this(filename, scaling, "." + File.separator + "zoominMaggie.png");
    }

    /**
     * Scales {@code filename} by {@code scaling} and writes the result to {@code outputFilename}.
     *
     * @param filename       path of the source image
     * @param scaling        scale factor applied to both dimensions
     * @param outputFilename path of the PNG file to write
     * @throws IllegalArgumentException if scaling is not positive (or is NaN)
     */
    public ZoomImage(String filename, float scaling, String outputFilename) {
        if (!(scaling > 0)) {
            throw new IllegalArgumentException("scaling must be positive: " + scaling);
        }
        this.filename = filename;
        this.scaling = scaling;
        this.outputFilename = outputFilename;
    }

    /**
     * Reads the source image, draws it into a {@code scaling}-sized canvas and saves it as PNG.
     * Each target dimension is truncated to an int but never drops below 1 pixel. Read and
     * write failures are reported on stderr.
     */
    public void zoom() {
        try {
            BufferedImage source = ImageIO.read(new File(filename));
            if (source == null) {
                // ImageIO.read returns null when no registered reader understands the file.
                System.err.println("unsupported image format: " + filename);
                return;
            }
            int targetWidth = Math.max(1, (int) (this.scaling * source.getWidth()));
            int targetHeight = Math.max(1, (int) (this.scaling * source.getHeight()));
            BufferedImage target = new BufferedImage(targetWidth, targetHeight, BufferedImage.TYPE_INT_BGR);
            Graphics graphics = target.createGraphics();
            try {
                graphics.drawImage(source, 0, 0, targetWidth, targetHeight, null);
            } finally {
                graphics.dispose();
            }
            // Platform separator instead of a hard-coded "\\" so the file lands in the working
            // directory on every OS.
            ImageIO.write(target, "png", new File(outputFilename));
            System.out.println("image has been zoomed...");
        } catch (IOException e) {
            e.printStackTrace();
        }
    }
}
/** Converts an image file into a row-major 0/1 bitmap (1 = dark "ink", 0 = light background). */
class RGB2binary {
    private String filename;
    // Sized for the 32x32 digits the classifier expects; rgb2binary() resizes it when the
    // image has a different number of pixels.
    private short[] userInputDigit = new short[32 * 32];

    /** Returns the bitmap produced by the last rgb2binary() call (all zeros before the first call). */
    public short[] getUserInputDigit() {
        return this.userInputDigit;
    }

    /** @param filename path of the image to binarize */
    public RGB2binary(String filename) {
        this.filename = filename;
    }

    /**
     * Reads the image, converts each pixel to a weighted gray level and thresholds it at 128.
     * Pixels with gray > 128 become 0 and all others become 1. The result is stored row by row
     * (index = row * width + column) and also printed to stdout. I/O failures are reported on
     * stderr and leave the previous bitmap unchanged.
     */
    public void rgb2binary() {
        System.out.println(this.filename);
        File file = new File(this.filename);
        try {
            BufferedImage bufferedImage = ImageIO.read(file);
            if (bufferedImage == null) {
                // ImageIO.read returns null when no registered reader understands the file.
                System.err.println("unsupported image format: " + this.filename);
                return;
            }
            int startX = bufferedImage.getMinX();
            int startY = bufferedImage.getMinY();
            int width = bufferedImage.getWidth();
            int height = bufferedImage.getHeight();
            System.out.println("x = " + startX + " y = " + startY + " width = " + width + " height = " + height);
            if (userInputDigit.length != width * height) {
                userInputDigit = new short[width * height];
            }
            // Rows (y) in the outer loop so the output is row-major and getRGB receives (x, y)
            // in the right order; swapping them breaks on any non-square image.
            for (int y = 0; y < height; y++) {
                for (int x = 0; x < width; x++) {
                    int pixel = bufferedImage.getRGB(startX + x, startY + y);
                    int r = (pixel >> 16) & 0xff;
                    int g = (pixel >> 8) & 0xff;
                    int b = pixel & 0xff;
                    // Weighted RGB -> gray conversion; the binary value comes from the threshold below.
                    float gray = r * 0.3f + g * 0.59f + b * 0.11f;
                    short bit = gray > 128 ? (short) 0 : (short) 1;
                    System.out.print(bit);
                    userInputDigit[y * width + x] = bit;
                }
                System.out.println();
            }
        } catch (IOException e) {
            e.printStackTrace();
        }
    }
}
/**
 * k-nearest-neighbour classifier for 32x32 binary digit bitmaps stored as text files
 * (32 rows of 32 digit characters each).
 */
class Knn {
    // Characters per row, and rows per file, of a digit text file.
    private static final int ROW_LENGTH = 32;
    private int featureSize = ROW_LENGTH * ROW_LENGTH;
    String trainingSetDir = "./trainingDigits";
    String testSetDir = "./testDigits1";
    private int trainingSetSize;
    private int testSetSize;
    private short[][] trainingData = null;
    private short[] trainingSetLabel = null;
    private short[][] testData = null;
    private short[] testSetLabel = null;

    public Knn() {
    }

    /**
     * Loads the training set from {@link #trainingSetDir}.
     *
     * <p>Every regular file in the directory is one sample, and its label is the part of the
     * file name before the first '_'. Files that cannot be read or parsed are reported on
     * stderr and skipped, so {@link #getTrainingSetSize()} counts only the samples actually
     * loaded.
     *
     * @throws IllegalStateException if the directory cannot be listed
     */
    public void readTrainingSet() {
        File[] files = listDigitFiles(trainingSetDir);
        System.out.println("total file number: " + files.length);
        short[][] data = new short[files.length][];
        short[] labels = new short[files.length];
        int loaded = loadDigits(files, data, labels);
        this.trainingData = Arrays.copyOf(data, loaded);
        this.trainingSetLabel = Arrays.copyOf(labels, loaded);
        this.trainingSetSize = loaded;
    }

    /**
     * Loads the test set from {@link #testSetDir}, using the same file format and skipping
     * rules as {@link #readTrainingSet()}.
     *
     * @throws IllegalStateException if the directory cannot be listed
     */
    public void readTestSet() {
        File[] files = listDigitFiles(testSetDir);
        System.out.println("total number of test file" + files.length);
        short[][] data = new short[files.length][];
        short[] labels = new short[files.length];
        int loaded = loadDigits(files, data, labels);
        this.testData = Arrays.copyOf(data, loaded);
        this.testSetLabel = Arrays.copyOf(labels, loaded);
        this.testSetSize = loaded;
    }

    /**
     * Classifies a feature vector by majority vote among its k nearest training samples.
     *
     * <p>If k exceeds the training-set size, every sample votes. Neighbours are scanned from
     * nearest to farthest, and a vote tie goes to the label that reached the winning count
     * first.
     *
     * @param feature feature vector to classify; must have the training samples' length
     * @param k       number of neighbours that vote (the "k" of k-NN)
     * @return the predicted label
     * @throws IllegalArgumentException if k < 1 or the feature length does not match
     * @throws IllegalStateException    if no training samples have been loaded
     */
    public int knn(short[] feature, int k) {
        if (k < 1) {
            throw new IllegalArgumentException("k must be at least 1, got " + k);
        }
        if (trainingSetSize == 0) {
            throw new IllegalStateException("training set is empty; call readTrainingSet() first");
        }
        double[] distances = new double[this.trainingSetSize];
        for (int i = 0; i < trainingSetSize; i++) {
            distances[i] = calculateDistance(feature, trainingData[i]);
        }
        int[] nearest = this.arg_sort(distances);
        int neighbours = Math.min(k, trainingSetSize);
        Map<Short, Integer> votes = new HashMap<>();
        int result = 0;
        int maxVote = 0;
        for (int rank = 0; rank < neighbours; rank++) {
            short label = trainingSetLabel[nearest[rank]];
            Integer previous = votes.get(label);
            int count = previous == null ? 1 : previous + 1;
            votes.put(label, count);
            // Strict '>' keeps the earlier (closer) label on ties.
            if (count > maxVote) {
                maxVote = count;
                result = label;
            }
        }
        return result;
    }

    /**
     * Reloads both data sets and returns the fraction of test samples classified correctly
     * with k = 3. The result is NaN when the test set is empty.
     */
    public double knnPrecise() {
        return knnPrecise(3);
    }

    /**
     * Reloads both data sets and returns the fraction of test samples classified correctly.
     *
     * @param k number of voting neighbours
     * @return accuracy in [0, 1], or NaN when the test set is empty
     */
    public double knnPrecise(int k) {
        System.out.println("reading trainingSet...");
        this.readTrainingSet();
        System.out.println("reading trainingSet over");
        System.out.println("reading testSet...");
        this.readTestSet();
        System.out.println("reading testSet end");
        int success = 0;
        for (int i = 0; i < testSetSize; i++) {
            if (testSetLabel[i] == knn(testData[i], k)) {
                success++;
            }
        }
        return (double) success / testSetSize;
    }

    /**
     * Returns the Euclidean distance between two equally long feature vectors.
     *
     * @throws IllegalArgumentException if the vectors differ in length
     */
    public double calculateDistance(short[] sequence1, short[] sequence2) {
        if (sequence1.length != sequence2.length) {
            throw new IllegalArgumentException(
                    "feature length mismatch: " + sequence1.length + " vs " + sequence2.length);
        }
        long sum = 0; // long so large short differences cannot overflow
        for (int i = 0; i < sequence1.length; i++) {
            long diff = sequence1[i] - sequence2[i];
            sum += diff * diff;
        }
        return Math.sqrt(sum);
    }

    /**
     * Returns the indices of {@code sequence} ordered by ascending value. The sort is stable,
     * so equal values keep their original index order. Runs in O(n log n).
     */
    public int[] arg_sort(final double[] sequence) {
        Integer[] order = new Integer[sequence.length];
        for (int i = 0; i < order.length; i++) {
            order[i] = i;
        }
        // Arrays.sort on objects is a stable merge sort.
        Arrays.sort(order, new Comparator<Integer>() {
            @Override
            public int compare(Integer left, Integer right) {
                return Double.compare(sequence[left], sequence[right]);
            }
        });
        int[] result = new int[order.length];
        for (int i = 0; i < order.length; i++) {
            result[i] = order[i];
        }
        return result;
    }

    public int getTrainingSetSize() {
        return trainingSetSize;
    }

    public int getTestSetSize() {
        return testSetSize;
    }

    /** Lists the regular files of {@code dirName}, sorted by path for a reproducible order. */
    private static File[] listDigitFiles(String dirName) {
        File[] entries = new File(dirName).listFiles();
        if (entries == null) {
            throw new IllegalStateException("cannot list digit directory: " + dirName);
        }
        File[] files = new File[entries.length];
        int count = 0;
        for (File entry : entries) {
            if (entry.isFile()) {
                files[count++] = entry;
            }
        }
        files = Arrays.copyOf(files, count);
        // listFiles() order is unspecified; sorting keeps tie-breaking stable across runs.
        Arrays.sort(files);
        return files;
    }

    /** Fills data/labels with every file that parses, skipping bad ones; returns the count loaded. */
    private int loadDigits(File[] files, short[][] data, short[] labels) {
        int loaded = 0;
        for (File file : files) {
            try {
                short label = parseLabel(file);
                data[loaded] = readDigitFile(file);
                labels[loaded] = label;
                loaded++;
            } catch (IOException | NumberFormatException e) {
                System.err.println("skipping " + file + ": " + e);
            }
        }
        return loaded;
    }

    /** Parses the label: the part of the file name before the first '_'. */
    private static short parseLabel(File file) {
        return Short.parseShort(file.getName().split("_")[0]);
    }

    /**
     * Reads ROW_LENGTH non-empty lines of digit characters into a row-major feature vector.
     * Works with both LF and CRLF line endings and closes the file even on failure.
     */
    private short[] readDigitFile(File file) throws IOException {
        short[] pixels = new short[featureSize];
        int rows = 0;
        try (BufferedReader reader = new BufferedReader(new FileReader(file))) {
            String line;
            while (rows < ROW_LENGTH && (line = reader.readLine()) != null) {
                if (line.isEmpty()) {
                    continue;
                }
                if (line.length() < ROW_LENGTH) {
                    throw new IOException("row " + (rows + 1) + " is shorter than " + ROW_LENGTH + " characters");
                }
                for (int col = 0; col < ROW_LENGTH; col++) {
                    int digit = Character.digit(line.charAt(col), 10);
                    if (digit < 0) {
                        throw new IOException("non-digit character '" + line.charAt(col) + "' in row " + (rows + 1));
                    }
                    pixels[rows * ROW_LENGTH + col] = (short) digit;
                }
                rows++;
            }
        }
        if (rows < ROW_LENGTH) {
            throw new IOException("expected " + ROW_LENGTH + " rows but found " + rows);
        }
        return pixels;
    }
}