基于多层感知机的手写数字识别,迭代10次后在训练集上的正确率约为97%

Main函数:绘制完数字后,要先点击“确定”按钮,再点击“识别结果”按钮进行识别;“重绘”按钮用于清空画板重新绘制。

Java 手写重试机制 java手写数字识别_java

 

Java 手写重试机制 java手写数字识别_java_02

 训练自己的网络结构会替换之前训练好的网络结构,目前没有实现保存或另存新网络模型的功能。结果在训练集上表现很好,但对手绘数字的识别结果仍不是很理想。

package main;

import java.awt.Color;
import java.awt.Container;
import java.awt.Graphics;
import java.awt.Image;
import java.awt.event.ActionEvent;
import java.awt.event.ActionListener;
import java.awt.image.BufferedImage;
import java.io.File;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.util.Scanner;

import javax.imageio.ImageIO;
import javax.swing.ImageIcon;
import javax.swing.JButton;
import javax.swing.JFileChooser;
import javax.swing.JFrame;
import javax.swing.JLabel;
import javax.swing.JPanel;
import javax.swing.JTextField;
import javax.swing.JTextPane;
import javax.swing.text.Style;
import javax.swing.text.StyleConstants;
import javax.swing.text.StyleContext;

import imageprocess.getimage;

import network.NetWork;
import network.traindata;

/**
 * Main window of the digit recogniser.
 *
 * <p>The left side hosts a {@link mypanel} drawing board with the "confirm",
 * "redraw" and "train" buttons; the right side shows a thumbnail of the current
 * input, a textual dump of the 28x28 input matrix, and the recognition result.
 * Input comes either from the drawing board (confirm button) or from an image
 * file (open button); the recognise button feeds the current matrix to a
 * {@link NetWork} configured from {@code ./parameter.txt}.
 */
public class GUI extends JFrame {
	private JFrame jFrame;
	private BufferedImage img;// image loaded through the "open" button
	private JButton sure;// confirms the hand-drawn input
	private JButton cancel;// clears the drawing board
	private JButton recognition;// runs the recognition
	private JButton train;// opens the training window
	private JButton open;
	private JTextField result;
	private int[][] getmatrix=new int[28][28];
	private JTextPane imgtextarea;
	private JLabel imglabel;
	private static NetWork neunet;
	private JFileChooser choose;// file chooser for the "open" button

	/** Builds the whole window, wires the button handlers and shows it. */
	public GUI() {
		neunet=new NetWork(10,0.01,50,0.2,0.5);
		neunet.initNodes();
		jFrame=new JFrame("数字识别");
		jFrame.setBounds(0, 0, 765, 800);
		jFrame.setLayout(null);
		recognition=new JButton("识别结果");
		train=new JButton("训练");
		open=new JButton("打开图片");
		sure=new JButton("确定");
		cancel=new JButton("重绘");
		JPanel resultpanel = new JPanel();
		final mypanel panel = new mypanel();// drawing board
		Container contentPane = getContentPane();
		contentPane.setBounds(0, 0,350,350);
		contentPane.add(panel);
		jFrame.add(contentPane);
		JPanel draw=new JPanel();// backdrop behind the drawing board
		draw.setBounds(0, 0, 380,420);
		draw.setLayout(null);
		draw.setBackground(Color.lightGray);
		jFrame.add(draw);

		draw.add(sure);
		sure.setBounds(10, 370, 60, 30);
		draw.add(cancel);
		cancel.setBounds(130, 370, 60, 30);
		draw.add(train);
		train.setBounds(250, 370, 60, 30);
		open.setBounds(420, 320, 90, 30);

		recognition.setBounds(560, 320, 90, 30);
		imgtextarea=new JTextPane();
		Style style=new StyleContext().new NamedStyle();
		StyleConstants.setLineSpacing(style,-0.1f);
		StyleConstants.setFontSize(style, 7);
		StyleConstants.setBold(style, true);
		imgtextarea.setLogicalStyle(style);
		imglabel=new JLabel();
		imgtextarea.setSize(50, 70);
		imgtextarea.setBounds(300, 0, 200, 60);
		imglabel.setBounds(200,10,100,100);
		imgtextarea.setEditable(false);

		choose = new JFileChooser();
		choose.setCurrentDirectory(new File("."));

		resultpanel.add(imglabel);
		resultpanel.add(imgtextarea);
		result=new JTextField();
		result.setBounds(560, 360, 140, 50);
		result.setVisible(true);
		resultpanel.setBounds(381, 0, 350, 310);
		resultpanel.setBackground(Color.gray);
		jFrame.add(resultpanel);
		jFrame.add(result);
		jFrame.add(recognition);
		jFrame.add(open);

		jFrame.setSize(761, 450);
		jFrame.setVisible(true);

		// Confirm: snapshot the drawing board and turn it into the input matrix.
		sure.addActionListener(new ActionListener() {
			@Override
			public void actionPerformed(ActionEvent actionevent) {
				BufferedImage image=new BufferedImage(panel.getWidth(), panel.getHeight(), BufferedImage.TYPE_INT_RGB);
				Graphics gs=image.getGraphics();
				try {
					panel.paintAll(gs);
				} finally {
					gs.dispose();
				}
				try {
					// NOTE(review): the data is PNG-encoded although the file is named .jpg;
					// kept as-is so the on-disk artefact does not change.
					ImageIO.write(image, "png", new File("./save.jpg"));
				} catch (IOException e) {
					e.printStackTrace();
				}
				try {
					getmatrix=getimage.getMatirx(image);
					if(image.getHeight()>28||image.getWidth()>28){
						image=getimage.scale(image, 28, 28);
					}
					imglabel.setIcon(new ImageIcon((Image)image));// show the thumbnail
					transposeInPlace(getmatrix);
					imgtextarea.setText(matrixToText(getmatrix, ""));
				} catch (IOException e1) {
					e1.printStackTrace();
				}
			}
		});

		// Recognise: flatten the current matrix and run it through the stored network.
		recognition.addActionListener(new ActionListener() {
			@Override
			public void actionPerformed(ActionEvent actionevent) {
				int input[]=new int [784];
				for(int i=0;i<getmatrix[0].length;i++) {
					for(int j=0;j<getmatrix.length;j++) {
						input[i*getmatrix.length+j]=getmatrix[i][j];
					}
				}
				NetWork net=loadConfiguredNetwork();
				if(net==null) {
					// Missing or malformed parameter file: report instead of crashing.
					result.setText("无法读取 parameter.txt");
					return;
				}
				try {
					result.setText(net.recognize(input));
				} catch (FileNotFoundException e) {
					e.printStackTrace();
					result.setText("未找到权值文件");
				}
			}
		});

		// Redraw: wipe the drawing board.
		cancel.addActionListener(new ActionListener() {
			@Override
			public void actionPerformed(ActionEvent actionevent) {
				panel.cleanAll();
			}
		});

		// Open: load an image file as the input instead of the drawing board.
		open.addActionListener(new ActionListener() {
			@Override
			public void actionPerformed(ActionEvent actionevent) {
				int choice = choose.showOpenDialog(null);
				if(choice != JFileChooser.APPROVE_OPTION){
					return;
				}
				String name = choose.getSelectedFile().getPath();
				try {
					img=ImageIO.read(new File(name));
					if(img==null) {
						// ImageIO.read yields null when no reader understands the file.
						result.setText("无法读取图片");
						return;
					}
					if(img.getWidth()>28||img.getHeight()>28) {
						img=getimage.scale(img, 28, 28);// shrink to fit 28x28
					}
					imglabel.setIcon(new ImageIcon((Image)img));// show the thumbnail
					int a[]=traindata.SampleMatirx(img);
					for (int i = 0; i < a.length; i++) {
						getmatrix[i/28][i%28]=a[i];
					}
					imgtextarea.setText(matrixToText(getmatrix, " "));
				} catch (IOException e1) {
					e1.printStackTrace();
				}
			}
		});

		// Train: open the training window.
		train.addActionListener(new ActionListener() {
			@Override
			public void actionPerformed(ActionEvent actionevent) {
				new trainwin();
			}
		});
	}

	/**
	 * Builds a network from {@code ./parameter.txt}, which is read as four
	 * whitespace-separated values: iteration count, step size, the value passed
	 * as the hidden-layer size, and the momentum factor.
	 *
	 * @return the initialised network, or {@code null} when the file is missing
	 *         or does not contain the four expected numbers
	 */
	private static NetWork loadConfiguredNetwork() {
		try (Scanner in = new Scanner(new File("./parameter.txt"))) {
			int iter = in.nextInt();
			double size = in.nextDouble();
			int inputSize = in.nextInt();
			double emnue = in.nextDouble();
			NetWork net = new NetWork(iter, size, inputSize, 0.05, emnue);
			net.initNodes();
			return net;
		} catch (FileNotFoundException | java.util.NoSuchElementException e) {
			e.printStackTrace();
			return null;
		}
	}

	/** Swaps {@code m[i][j]} with {@code m[j][i]} for every pair; expects a square matrix. */
	private static void transposeInPlace(int[][] m) {
		for (int i = 0; i < m.length-1; i++) {
			for (int j = i+1; j < m.length; j++) {
				int temp=m[i][j];
				m[i][j]=m[j][i];
				m[j][i]=temp;
			}
		}
	}

	/**
	 * Renders the matrix as comma-separated text, one output line per second
	 * index (each line lists {@code matrix[0][i], matrix[1][i], ...}).
	 *
	 * @param matrix matrix to print
	 * @param prefix text placed before the first value
	 * @return the complete text, built once instead of re-setting the pane per cell
	 */
	private static String matrixToText(int[][] matrix, String prefix) {
		StringBuilder text=new StringBuilder(prefix);
		for(int i=0;i<matrix[0].length;i++) {
			for(int j=0;j<matrix.length;j++) {
				text.append(matrix[j][i]);
				text.append(j==matrix[0].length-1 ? "\n" : ",");
			}
		}
		return text.toString();
	}

	public static void main(String[] args) throws IOException {
		//traindata.imagetomatrix() can be run once beforehand to extract the training set.
		new GUI();
	}
}

获取数据集二值化矩阵保存用于训练,使用了40000张图像数据作为训练集

package network;

import java.awt.image.BufferedImage;
import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.PrintStream;

import javax.imageio.ImageIO;

import imageprocess.binaryimage;
import imageprocess.trainbrainimage;

/**
 * Converts the training images into flattened binary vectors and stores them
 * in {@code ./allInput.txt}, one sample per line.
 */
public class traindata {
	private static int sampleNumber=40000;

	/**
	 * Reads {@code ./mnist/mnist_data/<digit>.<index>.jpg} for digits 0-9 and
	 * indices 0-3999, binarises each image and writes all samples to
	 * {@code ./allInput.txt} as comma-terminated values, one sample per line.
	 *
	 * <p>The output file is now written on every run; previously a missing file
	 * was only created empty and the data was written on the second run. Any
	 * failure is reported on standard output, as before.
	 *
	 * @throws IOException declared for callers; failures are currently caught and printed
	 */
	public static void imagetomatrix() throws IOException{
		final int perDigit=4000;
		int input[][]=new int[sampleNumber][784];
		try {
			for(int digit=0;digit<=9;digit++) {
				for(int index=0;index<perDigit;index++) {
					BufferedImage image=ImageIO.read(new File("./mnist/mnist_data/"+digit+"."+index+".jpg"));
					input[index+digit*perDigit]=SampleMatirx(image);
				}
			}
			// FileOutputStream creates the file when needed; try-with-resources flushes and closes it.
			try (PrintStream out=new PrintStream(new FileOutputStream("./allInput.txt"))) {
				for(int i=0;i<sampleNumber;i++){
					for(int j=0;j<784;j++){
						out.print(input[i][j]+",");
					}
					out.println();
				}
			}
		}catch(Exception e) {
			System.out.println(e);
		}
	}

	/**
	 * Binarises the image and flattens the result into a vector of
	 * {@code width*height} zeros and ones.
	 *
	 * <p>The binary matrix produced by {@code trainbrainimage} is indexed
	 * {@code [x][y]}; it is flattened with x as the slow index. For square images
	 * this is exactly the previous layout; non-square images no longer run past
	 * the matrix bounds.
	 *
	 * @param image image to sample
	 * @return flattened binary values
	 * @throws IOException if binarisation fails
	 */
	public static int[] SampleMatirx(BufferedImage image) throws IOException {
		trainbrainimage bimg=new trainbrainimage();// holds the binarised image and matrix
		bimg.brmatrix(image);
		int high=image.getHeight();
		int weigh=image.getWidth();
		int a[]=new int[high*weigh];
		for (int x=0;x<weigh;x++) {
			for(int y=0;y<high;y++) {
				a[x*high+y]=bimg.brimage[x][y];
			}
		}
		return a;
	}
}

 图像灰度二值化处理:

package imageprocess;

import java.awt.Color;
import java.awt.image.BufferedImage;
import java.io.IOException;
 
/**
 * Grey-scales and binarises an image using a 3x3 neighbourhood average.
 * After {@link #brmatrix(BufferedImage)} the result is available both as a
 * 0/1 matrix ({@link #brimage}, indexed {@code [x][y]}) and as a binary image
 * ({@link #image}).
 */
public class trainbrainimage {

	private int gray[][]=null;// grey value per pixel, indexed [x][y]
	public int brimage[][]=null;// binarised value (0 or 1) per pixel, indexed [x][y]
	public BufferedImage image;// binarised image

	/**
	 * Binarises {@code bi}: a pixel becomes 1 (white) when the average of its
	 * 3x3 neighbourhood exceeds 125, otherwise 0 (black).
	 *
	 * @param bi source image
	 * @throws IOException declared for callers; not thrown by this implementation
	 */
	public void brmatrix(BufferedImage bi) throws IOException {
		int h=bi.getHeight();
		int w=bi.getWidth();
		gray=new int[w][h];
		brimage=new int[w][h];
		for (int x = 0; x < w; x++) {
			for (int y = 0; y < h; y++) {
				gray[x][y]=getGray(bi.getRGB(x, y));
			}
		}
		BufferedImage nbi=new BufferedImage(w,h,BufferedImage.TYPE_BYTE_BINARY);
		final int threshold=125;
		final int white=Color.WHITE.getRGB();
		final int black=Color.BLACK.getRGB();
		for (int x = 0; x < w; x++) {
			for (int y = 0; y<h; y++) {
				boolean bright=getAverageColor(gray, x, y, w, h)>threshold;
				nbi.setRGB(x, y, bright ? white : black);
				brimage[x][y]=bright ? 1 : 0;
			}
		}
		this.image=nbi;
	}

	/**
	 * Returns the mean of the red, green and blue channels of a packed ARGB value.
	 * Channels are extracted with bit operations, so any alpha value is accepted
	 * (the former hex-string parsing failed when the alpha byte was below 0x10).
	 */
	private int getGray(int rgb){
		int r=(rgb>>16)&0xFF;
		int g=(rgb>>8)&0xFF;
		int b=rgb&0xFF;
		return (r+g+b)/3;
	}

	/**
	 * Sums the 3x3 neighbourhood around (x, y) and divides by 9. Neighbours
	 * outside the image contribute 0, so border pixels are biased towards black.
	 * (One out-of-range term used to contribute a stray 20; it is 0 now, like
	 * all the others.)
	 */
	private int getAverageColor(int[][] gray, int x, int y, int w, int h) {
		int rs=0;
		for (int dx=-1; dx<=1; dx++) {
			for (int dy=-1; dy<=1; dy++) {
				int nx=x+dx;
				int ny=y+dy;
				if (nx>=0 && nx<w && ny>=0 && ny<h) {
					rs+=gray[nx][ny];
				}
			}
		}
		return rs / 9;
	}
}

数据集中的图像按数字类别顺序排列,每4000张为一类(第1–4000张是数字0,第4001–8000张是数字1,依此类推),所以训练时要打乱样本顺序。网络结构代码:

package network;

import java.io.BufferedReader;
import java.io.File;
import java.io.FileNotFoundException;
import java.io.FileReader;
import java.io.FileWriter;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Scanner;

/**
 * A fully connected network with one hidden layer (784 inputs, a configurable
 * number of hidden nodes, 10 outputs) trained by back-propagation with a
 * momentum term. Trained weights and biases are stored in the text files
 * {@code ./height.txt}, {@code ./outweigh.txt}, {@code ./heightbais.txt} and
 * {@code ./outweighbais.txt}, which {@link #recognize(int[])} reads back.
 */
public class NetWork {
	private static final int SAMPLE_COUNT=40000;     // number of training samples
	private static final int SAMPLES_PER_DIGIT=4000; // samples are stored grouped by digit

	private int  iteration;    // remaining training epochs
	public static int[][] allInput;// all training samples, one flattened image per row
	private double stepsize;   // learning rate
	private double weighRange; // scale applied to the random initial weights
	private double momentum;   // momentum factor
	private int inputsize=784;   // number of input values
	private int hinddensize=50;// number of hidden nodes
	private int outputsize=10;  // number of output nodes
	private int[] inputnode;
	private node[] hiddennode;
	private node[] outputnode;
	// deltas, weights, and the previous weights used by the momentum term
	private double [] hinddenDelta;
	private double [] outputDelta;
	private double [][]inputweight;
	private double [][]oldInputeight;
	private double [][]outputweight;
	private double [][]oldoutputweight;
	private int success=0;
	private double error=0;
	private double errorrate;
	private double successrate;
	private char []type;  // never assigned in this class; gettype() returns null

	/**
	 * @param iteration   number of training epochs
	 * @param stepsize    learning rate
	 * @param hinddensize number of hidden nodes
	 * @param weighRange  scale of the random initial weights
	 * @param momentum    momentum factor
	 */
	public NetWork(int  iteration,double stepsize,int hinddensize,double weighRange,double momentum) {
		this.iteration=iteration;
		this.stepsize=stepsize;
		this.hinddensize=hinddensize;
		this.weighRange=weighRange;
		this.momentum=momentum;
		inputnode=new int[inputsize];
		hiddennode=new node[hinddensize];
		outputnode=new node[outputsize];
		hinddenDelta=new  double [hinddensize];
		outputDelta=new double[outputsize];
		inputweight=new double[inputsize][hinddensize];
		oldInputeight=new double[inputsize][hinddensize];
		outputweight=new double[hinddensize][outputsize];
		oldoutputweight=new double[hinddensize][outputsize];
	}

	/** @return fraction of training samples classified correctly during the last epoch */
	public double getsuccess() {
		return successrate;
	}

	/** @return accumulated squared error of the last epoch divided by the sample count */
	public double geterror() {
		return errorrate;
	}

	public char[] gettype() {
		return type;
	}

	/** Creates the nodes and assigns random initial weights. */
	public void initNetwork() {
		initNodes();
		initWeights(weighRange);
	}

	/** Fills both weight matrices with random values in [-range/2, range/2). */
	private void initWeights(double range) {
		for(int i=0;i<inputsize;i++) {
			for(int j=0;j<hinddensize;j++) {
				inputweight[i][j]=randomBais()*range;
				oldInputeight[i][j]=inputweight[i][j];
			}
		}
		for(int i=0;i<hinddensize;i++) {
			for(int j=0;j<outputsize;j++) {
				outputweight[i][j]=randomBais()*range;
				oldoutputweight[i][j]=outputweight[i][j];
			}
		}
	}

	/** @return a random value in [-0.5, 0.5) */
	private double randomBais() {
		return Math.random()-0.5;
	}

	/** (Re)creates every hidden and output node with a random activation and bias. */
	public void initNodes() {
		for(int i=0;i<hinddensize;i++) {
			hiddennode[i]=new node(Math.random(), randomBais());
		}
		for(int i=0; i<outputsize; i++) {
			outputnode[i] = new node(Math.random(), randomBais());
		}
	}

	/** Creates the nodes if the caller never did, so that no node is null when used. */
	private void ensureNodes() {
		if(hasNull(hiddennode)||hasNull(outputnode)) {
			initNodes();
		}
	}

	private static boolean hasNull(node[] nodes) {
		for(node n:nodes) {
			if(n==null) {
				return true;
			}
		}
		return false;
	}

	/**
	 * Trains on the samples of {@code ./allInput.txt} for the configured number
	 * of epochs, visiting the samples in one shuffled order, then writes weights
	 * and biases to the four parameter files.
	 *
	 * <p>During the final epoch the squared error against the one-hot target is
	 * accumulated over all outputs and correct classifications are counted
	 * (the accumulation previously used {@code =+}, which overwrote the sum, and
	 * compared an activation with the class index).
	 *
	 * @throws IOException if the training data cannot be read or a parameter file cannot be written
	 */
	public void train() throws IOException {
		ArrayList<Integer> order=new ArrayList<>();
		for(int i=0;i<SAMPLE_COUNT;i++) {
			order.add(i);
		}
		// Samples are grouped by digit, so they must not be presented in file order.
		Collections.shuffle(order);
		NetWork.getAllInput();
		success=0;
		error=0;
		while(iteration>0) {
			for(int i=0;i<order.size();i++) {
				int m=order.get(i);
				int label=m/SAMPLES_PER_DIGIT;
				double[] desiredOutput=new double[outputsize];
				desiredOutput[label] = 1;
				inputnode=allInput[m];
				Forward();
				backPropagation(desiredOutput);
				if(iteration==1) {
					double b[]=new double[outputsize];
					for(int p=0;p<outputsize;p++) {
						b[p]= outputnode[p].getActivation();
						double diff=desiredOutput[p]-b[p];
						error+=diff*diff;
					}
					if(Max(b)==label) {
						success++;
					}
				}
			}
			iteration--;
			errorrate = error/SAMPLE_COUNT;
			successrate=success/(double)SAMPLE_COUNT;
		}

		writeMatrix("./height.txt", inputweight);
		writeMatrix("./outweigh.txt", outputweight);
		writeBiases("./heightbais.txt", hiddennode);
		writeBiases("./outweighbais.txt", outputnode);
	}

	/** Writes every matrix entry on its own line, row by row. */
	private static void writeMatrix(String path, double[][] matrix) throws IOException {
		try (FileWriter out=new FileWriter(new File(path))) {
			for(double[] row:matrix) {
				for(double value:row) {
					out.write(String.format("%f\n", value));
				}
			}
		}
	}

	/** Writes the bias of every node on its own line. */
	private static void writeBiases(String path, node[] nodes) throws IOException {
		try (FileWriter out=new FileWriter(new File(path))) {
			for(node n:nodes) {
				out.write(String.format("%f\n", n.getBias()));
			}
		}
	}

	/** Back-propagates the error for the current activations and updates weights and biases. */
	private void backPropagation(double[] desiredOutput) {
		double temp;
		// output deltas: sigmoid derivative times the output error
		for(int i=0; i<outputsize; i++) {
			double act=outputnode[i].getActivation();
			outputDelta[i] = (act * (1 - act)) * (desiredOutput[i] - act);
		}
		// hidden deltas: sigmoid derivative times the weighted output deltas
		for(int i=0; i<hinddensize; i++) {
			temp = 0.0;
			for(int j=0; j<outputsize; j++) {
				temp += outputDelta[j] * outputweight[i][j];
			}
			double act=hiddennode[i].getActivation();
			hinddenDelta[i] = act * (1.0 - act) * temp;
		}
		// input -> hidden weights: gradient step plus momentum on the previous change
		for(int i=0; i<inputsize; i++) {
			for(int j=0; j<hinddensize; j++) {
				temp = inputweight[i][j] + (stepsize * hinddenDelta[j] * inputnode[i]) +
						(momentum * (inputweight[i][j] - oldInputeight[i][j]));
				oldInputeight[i][j] = inputweight[i][j];
				inputweight[i][j] = temp;
			}
		}
		// hidden -> output weights
		for(int i=0; i<hinddensize; i++) {
			for(int j=0; j<outputsize; j++) {
				temp = outputweight[i][j] + (stepsize * outputDelta[j] * hiddennode[i].getActivation()) +
						(momentum * (outputweight[i][j] - oldoutputweight[i][j]));
				oldoutputweight[i][j] = outputweight[i][j];
				outputweight[i][j] = temp;
			}
		}
		// biases
		for(int i=0; i<hinddensize; i++) {
			temp = hiddennode[i].getBias() + (stepsize * hinddenDelta[i]) +
					(momentum * (hiddennode[i].getBias() - hiddennode[i].getOldbais()));
			hiddennode[i].setOldbais(hiddennode[i].getBias());
			hiddennode[i].setBias(temp);
		}
		for(int i=0; i<outputsize; i++) {
			temp = outputnode[i].getBias() + (stepsize * outputDelta[i]) +
					(momentum * (outputnode[i].getBias() - outputnode[i].getOldbais()));
			outputnode[i].setOldbais(outputnode[i].getBias());
			outputnode[i].setBias(temp);
		}
	}

	/** Forward pass: computes hidden and output activations for the current input. */
	private void Forward() {
		// Guarantees nodes exist, replacing the former catch-and-print of NullPointerException.
		ensureNodes();
		for(int i=0;i<hinddensize;i++) {
			double sum=0.0;
			for(int j=0;j<inputsize;j++) {
				sum+=inputnode[j]*inputweight[j][i];
			}
			hiddennode[i].setActivation(sigmoid(sum+hiddennode[i].getBias()));
		}
		for(int i=0;i<outputsize;i++) {
			double sum=0.0;
			for(int j=0;j<hinddensize;j++) {
				sum+=hiddennode[j].getActivation()*outputweight[j][i];
			}
			outputnode[i].setActivation(sigmoid(sum+outputnode[i].getBias()));
		}
	}

	/** Logistic activation function. */
	private double sigmoid(double d) {
		return 1/(1 + Math.exp(-1 * d));
	}

	/**
	 * Loads the 40000 training samples (784 comma-separated integers per line)
	 * from {@code ./allInput.txt} into {@link #allInput}.
	 *
	 * @throws IOException if the file is missing, too short, or a line has fewer than 784 values
	 */
	public static void getAllInput() throws IOException{
		int[][] loaded=new int[SAMPLE_COUNT][784];
		try (BufferedReader reader=new BufferedReader(new FileReader("./allInput.txt"))) {
			for(int i=0;i<SAMPLE_COUNT;i++){
				String line=reader.readLine();
				if(line==null) {
					throw new IOException("allInput.txt ends after "+i+" samples");
				}
				String s[]=line.split(",");
				if(s.length<784) {
					throw new IOException("sample "+i+" has only "+s.length+" values");
				}
				for(int j=0;j<784;j++){
					loaded[i][j]=Integer.parseInt(s[j]);
				}
			}
		}
		NetWork.allInput=loaded;
	}

	/** @return index of the first largest element of {@code b} */
	private  int Max(double[] b) {
		int maxIndex = 0;
		for(int i=0; i<b.length; i++){
			if(b[i] > b[maxIndex]){
				maxIndex = i;
			}
		}
		return maxIndex;
	}

	/**
	 * Classifies one flattened input image using the weights and biases stored
	 * in the four parameter files.
	 *
	 * @param a input values; at most 784 are accepted, missing trailing values count as 0
	 * @return the text {@code "数字是 :"} followed by the index of the strongest output
	 * @throws FileNotFoundException if a parameter file does not exist
	 */
	public  String recognize(int[] a) throws FileNotFoundException  {
		// Work on a private copy: inputnode may alias a row of allInput after train().
		int[] input=new int[inputsize];
		System.arraycopy(a, 0, input, 0, a.length);
		inputnode=input;
		ensureNodes();

		readMatrix("./height.txt", inputweight);
		readBiases("./heightbais.txt", hiddennode);
		readBiases("./outweighbais.txt", outputnode);
		readMatrix("./outweigh.txt", outputweight);

		Forward();
		double b[]=new double[outputsize];
		for(int p=0;p<outputsize;p++) {
			b[p]=  outputnode[p].getActivation();
		}
		int index=Max(b);
		return "数字是 :"+index;
	}

	/** Fills the matrix row by row from the numbers in the file. */
	private static void readMatrix(String path, double[][] matrix) throws FileNotFoundException {
		try (Scanner in=new Scanner(new File(path))) {
			for(int i=0;i<matrix.length;i++) {
				for(int j=0;j<matrix[i].length;j++) {
					matrix[i][j]=in.nextDouble();
				}
			}
		}
	}

	/** Assigns one bias per node from the numbers in the file. */
	private static void readBiases(String path, node[] nodes) throws FileNotFoundException {
		try (Scanner in=new Scanner(new File(path))) {
			for(node n:nodes) {
				n.setBias(in.nextDouble());
			}
		}
	}
}

节点:

package network;

/**
 * A single network node: its current activation, its bias, and the bias it
 * had before the most recent update.
 */
public class node {
	private double activation;
	private double bias;
	private double oldbais;

	/**
	 * @param a initial activation
	 * @param b initial bias
	 */
	public node(double a,double b) {
		activation=a;
		bias=b;
	}

	/** @return the current activation */
	public double getActivation() {
		return activation;
	}

	/** @param activation the new activation */
	public void setActivation(double activation) {
		this.activation = activation;
	}

	/** @return the current bias */
	public double getBias() {
		return bias;
	}

	/** @param bias the new bias */
	public void setBias(double bias) {
		this.bias = bias;
	}

	/** @return the bias stored via {@link #setOldbais(double)} (0 until first set) */
	public double getOldbais() {
		return oldbais;
	}

	/** @param oldbais the bias value to remember */
	public void setOldbais(double oldbais) {
		this.oldbais = oldbais;
	}

	/** @return the activation and the bias separated by a single space */
	@Override
	public String toString() {
		StringBuilder text=new StringBuilder();
		text.append(activation).append(" ").append(bias);
		return text.toString();
	}
}

绘制待识别数字界面:

package main;
import java.awt.*; 
import java.awt.event.*; 
import java.util.Vector;
import javax.swing.JPanel;

/**
 * Free-hand drawing board: every press-drag-release gesture is recorded as a
 * polyline and drawn with a thick white round stroke on a black background.
 */
public class mypanel extends JPanel {

	private static final long serialVersionUID = 1L;
	/** Recorded strokes; each inner vector holds the points of one gesture. */
	private Vector<Vector<Point>> FreedomDatas = new Vector<Vector<Point>>();
	private Color lineColor = Color.white;
	private int lineWidth = 16;

	/** Installs the mouse listeners that record strokes. */
	public mypanel() {
		addMouseListener(new MouseAdapter() {
			@Override
			public void mousePressed(MouseEvent e) {
				// A press starts a new stroke.
				Vector<Point> newLine = new Vector<Point>();
				newLine.add(new Point(e.getX(), e.getY()));
				FreedomDatas.add(newLine);
			}

			@Override
			public void mouseReleased(MouseEvent e) {
				repaint();
			}
		});

		addMouseMotionListener(new MouseMotionAdapter() {
			@Override
			public void mouseDragged(MouseEvent e) {
				// Nothing to extend if no stroke exists (e.g. cleared mid-gesture).
				if (FreedomDatas.isEmpty()) {
					return;
				}
				FreedomDatas.lastElement().add(new Point(e.getX(), e.getY()));
				// Repaint while dragging so the stroke is visible as it is drawn.
				repaint();
			}
		});
	}

	/** Removes every recorded stroke and repaints the empty board. */
	public void cleanAll() {
		FreedomDatas.clear();
		repaint();
	}

	/**
	 * Paints the black background and all recorded strokes. The background
	 * color is set explicitly instead of relying on whatever color the incoming
	 * {@code Graphics} happens to carry, and painting is done on a private copy
	 * so the stroke and color settings do not leak to the caller.
	 */
	@Override
	protected void paintComponent(Graphics g) {
		super.paintComponent(g);
		Graphics2D g_2D = (Graphics2D) g.create();
		try {
			g_2D.setColor(Color.black);
			g_2D.fillRect(0, 0, getWidth(), getHeight());

			g_2D.setColor(lineColor);
			g_2D.setStroke(new BasicStroke(lineWidth, BasicStroke.CAP_ROUND, BasicStroke.JOIN_ROUND));

			int strokes = FreedomDatas.size();
			for (int i = 0; i < strokes; i++) {
				Vector<Point> v = FreedomDatas.get(i);
				for (int j = 0; j < v.size() - 1; j++) {
					Point s = v.get(j);
					Point e = v.get(j + 1);
					g_2D.drawLine(s.x, s.y, e.x, e.y);
				}
			}
		} finally {
			g_2D.dispose();
		}
	}
}

抓取绘图界面的数字图像并处理二值化

package imageprocess;

import java.awt.Color;
import java.awt.Graphics;
import java.awt.GraphicsConfiguration;
import java.awt.GraphicsDevice;
import java.awt.GraphicsEnvironment;
import java.awt.HeadlessException;
import java.awt.Image;
import java.awt.Toolkit;
import java.awt.Transparency;
import java.awt.geom.AffineTransform;
import java.awt.image.AffineTransformOp;
import java.awt.image.BufferedImage;
import java.awt.image.CropImageFilter;
import java.awt.image.FilteredImageSource;
import java.awt.image.ImageFilter;
import java.io.File;
import java.io.IOException;
import javax.imageio.ImageIO;
import javax.swing.ImageIcon;

/**
 * Static helpers that turn a snapshot of the drawing board into the 28x28
 * 0/1 input matrix: binarisation, bounding-box detection, cropping, scaling
 * and grid sampling.
 */
public class getimage {
	/**
	 * Produces the 28x28 input matrix for an image.
	 *
	 * <p>The image is shrunk to fit 800x800 if necessary, binarised with
	 * {@code binaryimage}, the bounding box of columns/rows containing at least
	 * two set pixels is located, the box is grown to the next multiple of 28 in
	 * each direction (and widened when it is at least twice as tall as wide),
	 * the region is cropped and finally sampled on a 28x28 grid.
	 *
	 * @param bi source image
	 * @return the 28x28 matrix produced by {@link #cut2(BufferedImage, int, int)}
	 * @throws IOException if binarisation or writing the cropped image fails
	 */
	public final static int[][] getMatirx(BufferedImage bi) throws IOException{
		binaryimage biimg=new binaryimage();// holds the binarised image and matrix
		int h=bi.getHeight();
		int w=bi.getWidth();
		if(h>800||w>800){
			bi=scale(bi, 800, 800);
			h=bi.getHeight();
			w=bi.getWidth();
		}
		biimg.brmatrix(bi);

		int bi_matrix[][]=biimg.brimage;// binary matrix, indexed [x][y]
		int row[]=new int[w];// number of set pixels per column
		for(int i=0;i<w;i++){
			int s=0;
			for(int j=0;j<h;j++){
				s=s+bi_matrix[i][j];
			}
			row[i]=s;
		}
		int line[]=new int[h];// number of set pixels per row
		for(int j=0;j<h;j++){
			int s=0;
			for(int i=0;i<w;i++){
				s=s+bi_matrix[i][j];
			}
			line[j]=s;
		}
		// Bounding box; 0 doubles as the "not found yet" marker, so index 0 is never used.
		int left=0,right=0,top=0,below=0;
		for(int i=1;i<w;i++){
			if(row[i]>=2){
				if(left==0){
					left=i;
				}
				if(right<i){
					right=i;
				}
			}
		}
		for(int i=1;i<h;i++){
			if(line[i]>=2){
				if(top==0){
					top=i;
				}
				if(below<i){
					below=i;
				}
			}
		}
		// Grow each extent to the next multiple of 28.
		int new_h=(28-(below-top)%28)-top+below;
		int new_w=(28-(right-left)%28)-left+right;

		if((new_h/new_w)>=2){
			// Tall, narrow shape: widen the box and shift it left by half the added width.
			new_w=(new_h/2%28)+new_h/2;
			top=top-(28-(below-top)%28)/2;
			left=left-(new_w-right+left)/2;
		}
		else{
			// Otherwise centre the box by half the padding in each direction.
			top=top-(28-(below-top)%28)/2;
			left=left-(28-(right-left)%28)/2;
		}
		biimg.image=cut(biimg.image, left, top, w, h, new_w, new_h);
		return cut2(biimg.image, 28, 28);
	}

	/**
	 * Scales an image towards the given size. If the image exceeds the target in
	 * either dimension it is shrunk uniformly, using the height ratio when it is
	 * taller than wide and the width ratio otherwise; if not, the result of
	 * {@code getScaledInstance(width, height)} is returned.
	 *
	 * @param bi     source image
	 * @param height target height
	 * @param width  target width
	 * @return the scaled image
	 */
	public final static BufferedImage scale(BufferedImage bi, int height, int width) {
		Image temp = bi.getScaledInstance(width, height, Image.SCALE_SMOOTH);
		if ((bi.getHeight() > height) || (bi.getWidth() > width)) {
			double ratio;
			if (bi.getHeight() > bi.getWidth()) {
				ratio = (double) height / bi.getHeight();
			} else {
				ratio = (double) width / bi.getWidth();
			}
			AffineTransformOp op = new AffineTransformOp(AffineTransform
					.getScaleInstance(ratio, ratio), null);
			temp = op.filter(bi, null);
		}
		return toBufferedImage(temp);
	}

	/**
	 * Returns the image as a {@code BufferedImage}, drawing it into a new opaque
	 * buffer when it is not one already.
	 */
	public final static BufferedImage toBufferedImage(Image image) {
		if (image instanceof BufferedImage) {
			return (BufferedImage) image;
		}
		image = new ImageIcon(image).getImage();// ImageIcon waits until the image is loaded
		BufferedImage bimage = null;
		GraphicsEnvironment ge = GraphicsEnvironment
				.getLocalGraphicsEnvironment();
		try {
			GraphicsDevice gs = ge.getDefaultScreenDevice();
			GraphicsConfiguration gc = gs.getDefaultConfiguration();
			bimage = gc.createCompatibleImage(image.getWidth(null),
					image.getHeight(null), Transparency.OPAQUE);
		} catch (HeadlessException ignored) {
			// No screen available: fall back to a plain RGB buffer below.
		}
		if (bimage == null) {
			bimage = new BufferedImage(image.getWidth(null),
					image.getHeight(null), BufferedImage.TYPE_INT_RGB);
		}
		Graphics g = bimage.createGraphics();
		g.drawImage(image, 0, 0, null);
		g.dispose();
		return bimage;
	}

	/**
	 * Crops the rectangle (x, y, new_w, new_h) out of the image after scaling it
	 * to w x h, and also writes the result to {@code ./after.jpg}.
	 *
	 * @return the cropped image
	 * @throws IOException if {@code ./after.jpg} cannot be written
	 */
	public final static BufferedImage cut(BufferedImage bi,int x, int y,int w,int h,int new_w,int new_h) throws IOException {
		Image image = bi.getScaledInstance(w, h,
				Image.SCALE_DEFAULT);
		ImageFilter cropFilter = new CropImageFilter(x, y, new_w, new_h);
		Image img = Toolkit.getDefaultToolkit().createImage(
				new FilteredImageSource(image.getSource(),
						cropFilter));
		BufferedImage tag = new BufferedImage(new_w, new_h, BufferedImage.TYPE_INT_RGB);
		Graphics g = tag.getGraphics();
		g.drawImage(img, 0, 0, new_w, new_h, null); // draw the cropped region
		g.dispose();
		ImageIO.write(tag, "jpg", new File("./after.jpg"));
		return tag;
	}

	/**
	 * Samples the image on a rows x cols grid. A cell becomes 1 when
	 * {@link #IsBlank} reports that it contains no pure-black pixel, else 0.
	 * Any exception is printed and the matrix filled so far is returned.
	 *
	 * <p>NOTE(review): {@code srcWidth} is taken from the image height and
	 * {@code srcHeight} from the width, and the image is rescaled to those
	 * swapped dimensions before tiling — kept unchanged; confirm it is intended.
	 */
	public final static int[][] cut2( BufferedImage bi,int rows, int cols) {
		int InputMatrix[][]=new int[rows][cols];
		try {
			int srcWidth = bi.getHeight();
			int srcHeight = bi.getWidth();
			if (srcWidth > 0 && srcHeight > 0) {
				Image image = bi.getScaledInstance(srcWidth, srcHeight, Image.SCALE_DEFAULT);
				// Cell size: exact quotient, or rounded up when it does not divide evenly.
				int destWidth = (srcWidth % cols == 0) ? srcWidth / cols : srcWidth / cols + 1;
				// Uses srcHeight in both cases (the rounded-up case used srcWidth by mistake).
				int destHeight = (srcHeight % rows == 0) ? srcHeight / rows : srcHeight / rows + 1;
				for (int i = 0; i < rows; i++) {
					for (int j = 0; j < cols; j++) {
						ImageFilter cropFilter = new CropImageFilter(j * destWidth, i * destHeight,
								destWidth, destHeight);
						Image img = Toolkit.getDefaultToolkit().createImage(
								new FilteredImageSource(image.getSource(),
										cropFilter));
						BufferedImage tag = new BufferedImage(destWidth,
								destHeight, BufferedImage.TYPE_INT_RGB);
						Graphics g = tag.getGraphics();
						g.drawImage(img, 0, 0, null); // draw this grid cell
						g.dispose();
						InputMatrix[i][j]=IsBlank(tag,destWidth,destHeight) ? 1 : 0;
					}
				}
			}
		} catch (Exception e) {
			e.printStackTrace();
		}
		return InputMatrix;
	}

	/**
	 * @return {@code true} when none of the destWidth x destHeight pixels of
	 *         {@code tag} has a grey value of 0
	 */
	public final static boolean IsBlank(BufferedImage tag,int destWidth,int destHeight){
		for (int x = 0; x < destWidth; x++) {
			for (int y = 0; y < destHeight; y++) {
				if(getGray(tag.getRGB(x, y))==0){
					return false;
				}
			}
		}
		return true;
	}

	/**
	 * Returns the mean of the red, green and blue channels of a packed ARGB value.
	 * Channels are extracted with bit operations, so any alpha value is accepted
	 * (the former hex-string parsing failed when the alpha byte was below 0x10).
	 */
	public final static int getGray(int rgb){
		int r=(rgb>>16)&0xFF;
		int g=(rgb>>8)&0xFF;
		int b=rgb&0xFF;
		return (r+g+b)/3;
	}
}
package imageprocess;
import java.awt.Color;
import java.awt.image.BufferedImage;
import java.io.File;
import java.io.IOException;
 
import javax.imageio.ImageIO;
 
/**
 * Grey-scales, brightens and binarises an image using a 5x5 neighbourhood
 * average over a zero-padded copy. After {@link #brmatrix(BufferedImage)} the
 * result is available as a 0/1 matrix ({@link #brimage}, indexed
 * {@code [x][y]}) and as a binary image ({@link #image}).
 */
public class binaryimage {
	private int gray[][]=null;// grey value per pixel, indexed [x][y]
	public int brimage[][]=null;// binarised value (0 or 1) per pixel, indexed [x][y]
	private	int gra[][]=null;// grey values surrounded by a 2-pixel frame of zeros
	public BufferedImage image;// binarised image

	/**
	 * Binarises {@code bi}: grey values are multiplied by 1.25 (capped at 255),
	 * and a pixel becomes 1 (white) when the average of the 5x5 window centred
	 * on it exceeds 125, otherwise 0 (black). Pixels outside the image count as 0.
	 *
	 * @param bi source image
	 * @throws IOException declared for callers; not thrown by this implementation
	 */
	public void brmatrix(BufferedImage bi) throws IOException {
		int h=bi.getHeight();
		int w=bi.getWidth();
		gray=new int[w][h];
		brimage=new int[w][h];
		for (int x = 0; x < w; x++) {
			for (int y = 0; y < h; y++) {
				gray[x][y]=getGray(bi.getRGB(x, y));
			}
		}
		Brighter(gray,w,h);
		// Copy the image into the middle of a (w+4) x (h+4) array. The bounds are
		// w+2 / h+2 so the whole image is copied; the former w-1 / h-1 bounds
		// left the last three columns and rows at zero.
		gra=new int[w+4][h+4];
		for(int i=2;i<w+2;i++){
			for(int j=2; j<h+2; j++){
				gra[i][j]=gray[i-2][j-2];
			}
		}
		BufferedImage nbi=new BufferedImage(w,h,BufferedImage.TYPE_BYTE_BINARY);
		final int threshold=125;
		final int white=Color.WHITE.getRGB();
		final int black=Color.BLACK.getRGB();
		for (int x = 0; x < w; x++) {
			for (int y = 0; y<h; y++) {
				boolean bright=getAverageColor(gra, x, y, w, h)>threshold;
				nbi.setRGB(x, y, bright ? white : black);
				brimage[x][y]=bright ? 1 : 0;
			}
		}
		this.image=nbi;
	}

	/**
	 * Returns the mean of the red, green and blue channels of a packed ARGB value.
	 * Channels are extracted with bit operations, so any alpha value is accepted
	 * (the former hex-string parsing failed when the alpha byte was below 0x10).
	 */
	private int getGray(int rgb){
		int r=(rgb>>16)&0xFF;
		int g=(rgb>>8)&0xFF;
		int b=rgb&0xFF;
		return (r+g+b)/3;
	}

	/**
	 * Averages the 5x5 block of the padded array whose top-left corner is (x, y);
	 * with the 2-pixel padding this block is centred on image pixel (x, y).
	 */
	private int  getAverageColor(int[][] gray, int x, int y, int w, int h) {
		int rs=0;
		for(int i=0;i<5;i++){
			for(int j=0;j<5;j++){
				rs+=gray[x+i][y+j];
			}
		}
		return rs / 25;
	}

	/** Multiplies every grey value by 1.25 in place, rounding down and capping at 255. */
	public static void Brighter(int[][]gray,int w,int h){
		for(int x=0;x<w;x++){
			for(int y=0;y<h;y++){
				gray[x][y]=Math.min(255, (int) Math.floor(gray[x][y]*1.25));
			}
		}
	}
}

训练界面:设定迭代次数,移动步长,隐层数量,动量调节数值

package main;
import java.awt.Color;
import java.awt.Cursor;
import java.awt.Dimension;
import java.awt.event.ActionEvent;
import java.awt.event.ActionListener;
import java.awt.event.WindowEvent;
import java.awt.event.WindowListener;
import java.io.File;
import java.io.FileWriter;
import javax.swing.JButton;
import javax.swing.JFrame;
import javax.swing.JLabel;
import javax.swing.JTextField;
import network.NetWork;

/**
 * Training window: collects an iteration count, learning rate, hidden-neuron
 * count and momentum value, trains a new {@link NetWork} with them on a
 * background thread, shows the resulting success rate and error, and writes
 * the four parameters to {@code ./parameter.txt} (one per line).
 *
 * All Swing components are read and updated on the event dispatch thread;
 * only the training itself and the file write run on the worker thread.
 */
public class trainwin extends JFrame implements WindowListener {

	private static final long serialVersionUID = 1L;

	/** Positive integer without leading zeros (iteration and neuron counts). */
	private static final String POSITIVE_INT = "[1-9][0-9]*";
	/** Unsigned decimal that must contain a fractional part (rate, momentum). */
	private static final String DECIMAL = "[0-9]*\\.[0-9]+";

	private JTextField Limit;
	private JTextField LearningRate;
	private JTextField hLNeurons;
	private JButton txttrain;
	private JLabel lerror;
	private JLabel lters;
	private JLabel lsuccess;
	private JLabel MSE;
	private JLabel emnue;
	private JTextField emnuefile;
	private NetWork nn;

	/** Builds the fixed-size, absolutely positioned window and shows it. */
	public trainwin() {
		setTitle("自我训练");
		setSize(new Dimension(400, 440));
		setLocationRelativeTo(null);
		setDefaultCloseOperation(EXIT_ON_CLOSE);
		setResizable(false);
		setLayout(null);
		createRightSide();
		addWindowListener(this);
		setVisible(true);
	}

	/** Creates every label, input field and the train button and adds them to the content pane. */
	private void createRightSide() {
		JLabel lblIterLimit = new JLabel("迭代次数");
		lblIterLimit.setBounds(70, 10, 100, 50);
		JLabel lblLearningRate = new JLabel("学习率");
		lblLearningRate.setBounds(70, 40, 100, 50);
		JLabel lblHLNeurons = new JLabel("隐层数量");
		lblHLNeurons.setBounds(70, 70, 100, 50);
		emnue = new JLabel("动能调节");
		emnue.setBounds(70, 100, 100, 50);
		emnue.setCursor(new Cursor(Cursor.TEXT_CURSOR));
		JLabel lblItersTxt = new JLabel("迭代结束:");
		lblItersTxt.setBounds(70, 230, 150, 50);
		JLabel lblSuccessTxt = new JLabel("成功率:");
		lblSuccessTxt.setBounds(70, 270, 100, 50);
		JLabel lblMSETxt = new JLabel("均方误差:");
		lblMSETxt.setBounds(70, 310, 150, 50);

		Limit = newInputField(25);
		LearningRate = newInputField(55);
		hLNeurons = newInputField(85);
		emnuefile = newInputField(115);

		txttrain = new JButton("训练开始");
		txttrain.setBounds(100, 150, 100, 30);
		txttrain.setFocusPainted(false);
		txttrain.setCursor(new Cursor(Cursor.HAND_CURSOR));
		txttrain.addActionListener(new TrainListener());

		lerror = new JLabel("");
		lerror.setBounds(70, 190, 300, 30);
		lerror.setForeground(Color.RED);

		lters = newResultLabel(230);
		lsuccess = newResultLabel(270);
		MSE = newResultLabel(310);

		// Same insertion order as before so component ordering is unchanged.
		getContentPane().add(lblIterLimit);
		getContentPane().add(lblLearningRate);
		getContentPane().add(lblHLNeurons);
		getContentPane().add(Limit);
		getContentPane().add(LearningRate);
		getContentPane().add(hLNeurons);
		getContentPane().add(txttrain);
		getContentPane().add(lerror);
		getContentPane().add(lblItersTxt);
		getContentPane().add(lblSuccessTxt);
		getContentPane().add(lblMSETxt);
		getContentPane().add(lters);
		getContentPane().add(lsuccess);
		getContentPane().add(MSE);
		getContentPane().add(emnue);
		getContentPane().add(emnuefile);
	}

	/** Creates a 100x20 text field at x=165 and the given y, with a text cursor. */
	private static JTextField newInputField(int y) {
		JTextField field = new JTextField();
		field.setBounds(165, y, 100, 20);
		field.setCursor(new Cursor(Cursor.TEXT_CURSOR));
		return field;
	}

	/** Creates an empty blue 100x50 label at x=160 and the given y. */
	private static JLabel newResultLabel(int y) {
		JLabel label = new JLabel("");
		label.setBounds(160, y, 100, 50);
		label.setForeground(Color.BLUE);
		return label;
	}

	/**
	 * Parses a positive integer.
	 *
	 * @return the parsed value, or -1 when the text is not a positive integer
	 *         or does not fit in an {@code int}
	 */
	private static int parsePositiveInt(String text) {
		if (!text.matches(POSITIVE_INT)) {
			return -1;
		}
		try {
			return Integer.parseInt(text);
		} catch (NumberFormatException tooLarge) {
			// Digits only, but beyond Integer.MAX_VALUE: treat as invalid input.
			return -1;
		}
	}

	/**
	 * Trains a new network and reports the outcome. Runs on the worker thread;
	 * every component update is posted to the event dispatch thread.
	 *
	 * @param iter     iteration count (first NetWork constructor argument)
	 * @param rate     learning rate (second argument)
	 * @param hide     hidden-neuron count (third argument)
	 * @param momentum momentum value (fifth argument)
	 */
	private void trainAndReport(final int iter, final double rate, final int hide, final double momentum) {
		try {
			// NOTE(review): the meaning of the fixed fourth argument 0.05 is defined by NetWork - confirm there.
			nn = new NetWork(iter, rate, hide, 0.05, momentum);
			nn.initNetwork();
			nn.train();

			final String successText = String.format("%f", nn.getsuccess() * 100);
			final String errorText = String.valueOf(nn.geterror());
			javax.swing.SwingUtilities.invokeLater(new Runnable() {
				@Override
				public void run() {
					lters.setText(String.valueOf(iter));
					lsuccess.setText(successText);
					MSE.setText(errorText);
				}
			});

			// try-with-resources closes the file even when a write fails.
			try (FileWriter pf = new FileWriter(new File("./parameter.txt"))) {
				pf.write(String.format("%d\n", iter));
				pf.write(String.format("%f\n", rate));
				pf.write(String.format("%d\n", hide));
				pf.write(String.format("%f\n", momentum));
			}
		} catch (final Exception e1) {
			e1.printStackTrace();
			// Surface the failure in the window as well as on stderr.
			javax.swing.SwingUtilities.invokeLater(new Runnable() {
				@Override
				public void run() {
					lerror.setText("训练失败: " + e1);
				}
			});
		} finally {
			// Always re-enable the button, whatever happened above.
			javax.swing.SwingUtilities.invokeLater(new Runnable() {
				@Override
				public void run() {
					txttrain.setEnabled(true);
					txttrain.setText("训练");
				}
			});
		}
	}

	/** Validates the four inputs and, when they are valid, starts training on a worker thread. */
	private class TrainListener implements ActionListener {

		@Override
		public void actionPerformed(ActionEvent e) {
			// Invoked on the event dispatch thread, so components may be touched directly.
			lerror.setText("");
			lters.setText("");
			lsuccess.setText("");
			MSE.setText("");

			final int iter = parsePositiveInt(Limit.getText());
			final int hide = parsePositiveInt(hLNeurons.getText());
			final String rateText = LearningRate.getText();
			final String momentumText = emnuefile.getText();

			// Checks run in the same order as the fields appear in the window.
			String problem = null;
			if (iter < 0) {
				problem = "请输入迭代次数.";
			} else if (!rateText.matches(DECIMAL)) {
				problem = "请输入可用的学习率.";
			} else if (hide < 0) {
				problem = "请输入有隐层神经元数量 .";
			} else if (!momentumText.matches(DECIMAL)) {
				problem = "请输入有效动量调节值 .";
			}
			if (problem != null) {
				lerror.setText(problem);
				txttrain.setEnabled(true);
				txttrain.setText("训练");
				return;
			}

			final double rate = Double.parseDouble(rateText);
			final double momentum = Double.parseDouble(momentumText);

			txttrain.setEnabled(false);
			txttrain.setText("训练中...");
			new Thread(new Runnable() {
				@Override
				public void run() {
					trainAndReport(iter, rate, hide, momentum);
				}
			}, "trainwin-training").start();
		}

	}

	// WindowListener callbacks: none of them needs to do anything for this window.
	@Override
	public void windowClosed(WindowEvent e) {}
	@Override
	public void windowOpened(WindowEvent e) {}
	@Override
	public void windowIconified(WindowEvent e) {}
	@Override
	public void windowDeiconified(WindowEvent e) {}
	@Override
	public void windowActivated(WindowEvent e) {}
	@Override
	public void windowDeactivated(WindowEvent e) {}
	@Override
	public void windowClosing(WindowEvent windowevent) {}

	/** Opens the training window as a stand-alone application. */
	public static void main(String[] args) {
		new trainwin();
	}
}

训练数据,以及训练10次、学习率0.3、隐层神经元50个、动量0.5时的各参数数据见: