基于多层感知机的手写数字识别:迭代10次后,在训练集上的正确率约为97%。
Main函数:在画板上绘制完数字后,需要先点击“确定”按钮,再点击“识别结果”按钮进行识别;“重绘”按钮用于清空画板重新绘图。
训练自己的网络会覆盖之前训练好的网络参数,目前没有实现保存或另存新网络模型的功能。网络在训练集上表现很好,但对手绘数字的识别结果仍不是很理想。
package main;
import java.awt.Color;
import java.awt.Container;
import java.awt.Graphics;
import java.awt.Image;
import java.awt.event.ActionEvent;
import java.awt.event.ActionListener;
import java.awt.image.BufferedImage;
import java.io.File;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.util.Scanner;
import javax.imageio.ImageIO;
import javax.swing.ImageIcon;
import javax.swing.JButton;
import javax.swing.JFileChooser;
import javax.swing.JFrame;
import javax.swing.JLabel;
import javax.swing.JPanel;
import javax.swing.JTextField;
import javax.swing.JTextPane;
import javax.swing.text.Style;
import javax.swing.text.StyleConstants;
import javax.swing.text.StyleContext;
import imageprocess.getimage;
import network.NetWork;
import network.traindata;
/**
 * Main window of the handwritten-digit recogniser.
 *
 * <p>The left side hosts a drawing canvas ({@code mypanel}) with confirm / redraw / train
 * buttons; the right side shows the 28x28 thumbnail, its 0/1 matrix as text, and the
 * recognition result. The current input is always held in {@link #getmatrix}.
 */
public class GUI extends JFrame {
    private JFrame jFrame;
    private BufferedImage img; // image loaded through the "open picture" button
    private JButton sure; // confirms the hand-drawn input
    private JButton cancel; // clears the canvas
    private JButton recognition; // runs recognition on the current matrix
    private JButton train; // opens the training window
    private JButton open;
    private JTextField result;
    private int[][] getmatrix = new int[28][28];
    private JTextPane imgtextarea;
    private JLabel imglabel;
    private static NetWork neunet;
    private JFileChooser choose; // file picker for the "open picture" button

    /** Builds the whole UI, shows the frame and wires the button handlers. */
    public GUI() {
        neunet = new NetWork(10, 0.01, 50, 0.2, 0.5);
        neunet.initNodes();
        jFrame = new JFrame("数字识别");
        jFrame.setBounds(0, 0, 765, 800);
        jFrame.setLayout(null);
        recognition = new JButton("识别结果");
        train = new JButton("训练");
        open = new JButton("打开图片");
        sure = new JButton("确定");
        cancel = new JButton("重绘");
        JPanel resultpanel = new JPanel();
        final mypanel panel = new mypanel(); // drawing canvas
        Container contentPane = getContentPane();
        contentPane.setBounds(0, 0, 350, 350);
        contentPane.add(panel);
        jFrame.add(contentPane);
        JPanel draw = new JPanel(); // backdrop behind the canvas; also hosts the buttons
        draw.setBounds(0, 0, 380, 420);
        draw.setLayout(null);
        draw.setBackground(Color.lightGray);
        jFrame.add(draw);
        draw.add(sure);
        sure.setBounds(10, 370, 60, 30);
        draw.add(cancel);
        cancel.setBounds(130, 370, 60, 30);
        draw.add(train);
        train.setBounds(250, 370, 60, 30);
        open.setBounds(420, 320, 90, 30);
        recognition.setBounds(560, 320, 90, 30);
        imgtextarea = new JTextPane();
        Style style = new StyleContext().new NamedStyle();
        StyleConstants.setLineSpacing(style, -0.1f);
        StyleConstants.setFontSize(style, 7);
        StyleConstants.setBold(style, true);
        imgtextarea.setLogicalStyle(style);
        imglabel = new JLabel();
        imgtextarea.setSize(50, 70);
        imgtextarea.setBounds(300, 0, 200, 60);
        imglabel.setBounds(200, 10, 100, 100);
        imgtextarea.setEditable(false);
        choose = new JFileChooser();
        choose.setCurrentDirectory(new File("."));
        resultpanel.add(imglabel);
        resultpanel.add(imgtextarea);
        result = new JTextField();
        result.setBounds(560, 360, 140, 50);
        result.setVisible(true);
        resultpanel.setBounds(381, 0, 350, 310);
        resultpanel.setBackground(Color.gray);
        jFrame.add(resultpanel);
        jFrame.add(result);
        jFrame.add(recognition);
        jFrame.add(open);
        jFrame.setSize(761, 450);
        jFrame.setVisible(true);

        sure.addActionListener(new ActionListener() {
            @Override
            public void actionPerformed(ActionEvent actionevent) {
                captureDrawing(panel);
            }
        });
        recognition.addActionListener(new ActionListener() {
            @Override
            public void actionPerformed(ActionEvent actionevent) {
                recogniseCurrentMatrix();
            }
        });
        cancel.addActionListener(new ActionListener() {
            @Override
            public void actionPerformed(ActionEvent actionevent) {
                panel.cleanAll();
            }
        });
        open.addActionListener(new ActionListener() {
            @Override
            public void actionPerformed(ActionEvent actionevent) {
                openImageFile();
            }
        });
        train.addActionListener(new ActionListener() {
            @Override
            public void actionPerformed(ActionEvent actionevent) {
                new trainwin();
            }
        });
    }

    /**
     * Renders the canvas into an image, saves a copy to {@code ./save.jpg}, converts it to the
     * 28x28 input matrix and refreshes the thumbnail and matrix text.
     */
    private void captureDrawing(mypanel panel) {
        BufferedImage image =
                new BufferedImage(panel.getWidth(), panel.getHeight(), BufferedImage.TYPE_INT_RGB);
        Graphics gs = image.getGraphics();
        try {
            panel.paintAll(gs);
        } finally {
            gs.dispose();
        }
        try {
            // The file is named .jpg, so encode it as JPEG (it used to contain PNG bytes).
            ImageIO.write(image, "jpg", new File("./save.jpg"));
        } catch (IOException e) {
            e.printStackTrace();
        }
        try {
            getmatrix = getimage.getMatirx(image);
            if (image.getHeight() > 28 || image.getWidth() > 28) {
                image = getimage.scale(image, 28, 28);
            }
            imglabel.setIcon(new ImageIcon((Image) image)); // show the thumbnail as an icon
            transposeInPlace(getmatrix);
            imgtextarea.setText(matrixText(""));
        } catch (IOException e1) {
            e1.printStackTrace();
        }
    }

    /** Swaps {@code m[i][j]} with {@code m[j][i]} for every pair above the diagonal. */
    private static void transposeInPlace(int[][] m) {
        for (int i = 0; i < m.length - 1; i++) {
            for (int j = i + 1; j < m.length; j++) {
                int temp = m[i][j];
                m[i][j] = m[j][i];
                m[j][i] = temp;
            }
        }
    }

    /**
     * Formats {@link #getmatrix} as comma-separated lines, reading it column by column
     * ({@code getmatrix[j][i]}), and prepends {@code prefix}. Built once with a
     * {@link StringBuilder} instead of re-setting the text pane for every cell.
     */
    private String matrixText(String prefix) {
        StringBuilder text = new StringBuilder(prefix);
        for (int i = 0; i < getmatrix[0].length; i++) {
            for (int j = 0; j < getmatrix.length; j++) {
                text.append(getmatrix[j][i]);
                text.append(j == getmatrix[0].length - 1 ? "\n" : ",");
            }
        }
        return text.toString();
    }

    /**
     * Flattens {@link #getmatrix}, rebuilds a network from the hyper-parameters stored in
     * {@code ./parameter.txt} and shows the recognised digit in the result field.
     */
    private void recogniseCurrentMatrix() {
        int[] input = new int[784];
        for (int i = 0; i < getmatrix[0].length; i++) {
            for (int j = 0; j < getmatrix.length; j++) {
                input[i * getmatrix.length + j] = getmatrix[i][j];
            }
        }

        int iter;
        double size;
        int hiddenSize;
        double emnue;
        // A missing or malformed parameter file used to end in a NullPointerException on the
        // event thread; report it in the result field instead and always close the scanner.
        try (Scanner in = new Scanner(new File("./parameter.txt"))) {
            iter = in.nextInt();
            size = in.nextDouble();
            hiddenSize = in.nextInt();
            emnue = in.nextDouble();
        } catch (FileNotFoundException e1) {
            e1.printStackTrace();
            result.setText("未找到参数文件,请先训练");
            return;
        } catch (java.util.NoSuchElementException e1) {
            e1.printStackTrace();
            result.setText("参数文件格式错误");
            return;
        }

        NetWork net = new NetWork(iter, size, hiddenSize, 0.05, emnue);
        net.initNodes();
        try {
            result.setText(net.recognize(input));
        } catch (FileNotFoundException e) {
            e.printStackTrace();
            result.setText("未找到权值文件,请先训练");
        }
    }

    /**
     * Lets the user pick an image file, scales it down to at most 28x28, binarises it via
     * {@code traindata.SampleMatirx} and loads the result into {@link #getmatrix}.
     */
    private void openImageFile() {
        int choice = choose.showOpenDialog(null);
        if (choice != JFileChooser.APPROVE_OPTION) {
            return;
        }
        String name = choose.getSelectedFile().getPath();
        try {
            BufferedImage loaded = ImageIO.read(new File(name));
            if (loaded == null) {
                // ImageIO.read returns null when no registered reader understands the file.
                result.setText("无法读取图片");
                return;
            }
            img = loaded;
            if (img.getWidth() > 28 || img.getHeight() > 28) {
                img = getimage.scale(img, 28, 28); // shrink the picture
            }
            imglabel.setIcon(new ImageIcon((Image) img)); // show the picture as an icon
            int[] a = traindata.SampleMatirx(img);
            // NOTE(review): the i/28, i%28 mapping assumes a 28-pixel-wide image; smaller
            // images leave stale cells in getmatrix — confirm whether that is intended.
            int limit = Math.min(a.length, 28 * 28);
            for (int i = 0; i < limit; i++) {
                getmatrix[i / 28][i % 28] = a[i];
            }
            imgtextarea.setText(matrixText(" "));
        } catch (IOException e1) {
            e1.printStackTrace();
        }
    }

    public static void main(String[] args) throws IOException {
        // traindata td=new traindata(); // used to extract the training set
        // td.imagetomatrix();
        new GUI();
    }
}
获取数据集图像的二值化矩阵并保存到文件,用于训练;共使用40000张图像作为训练集:
package network;
import java.awt.image.BufferedImage;
import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.PrintStream;
import javax.imageio.ImageIO;
import imageprocess.binaryimage;
import imageprocess.trainbrainimage;
/**
 * Converts the training images into binarised pixel vectors and stores them in
 * {@code ./allInput.txt}, one comma-separated sample per line.
 */
public class traindata {
    private static int sampleNumber = 40000;

    /**
     * Reads {@code ./mnist/mnist_data/<digit>.<index>.jpg} for digits 0-9 and indices 0-3999,
     * binarises every image and writes all samples to {@code ./allInput.txt}.
     *
     * <p>The output file is now written whether or not it already exists (previously a
     * missing file was only created empty and nothing was written on the first run), the
     * stream is always closed, and failures are propagated instead of being printed and
     * ignored.
     *
     * @throws IOException if an image cannot be read or decoded, or the output cannot be written
     */
    public static void imagetomatrix() throws IOException {
        int[][] input = new int[sampleNumber][784];
        for (int i = 0; i <= 9; i++) {
            for (int j = 0; j <= 3999; j++) {
                File source = new File("./mnist/mnist_data/" + i + "." + j + ".jpg");
                BufferedImage image = ImageIO.read(source);
                if (image == null) {
                    // ImageIO.read returns null when no reader can decode the file.
                    throw new IOException("Unsupported or corrupt image: " + source);
                }
                input[j + i * 4000] = SampleMatirx(image);
            }
        }

        // FileOutputStream creates the file when it is missing, so no exists() check is needed.
        try (PrintStream s = new PrintStream(new FileOutputStream("./allInput.txt"))) {
            for (int i = 0; i < sampleNumber; i++) {
                for (int j = 0; j < 784; j++) {
                    s.print(input[i][j] + ",");
                }
                s.println();
            }
            // PrintStream never throws on write errors; surface them explicitly.
            if (s.checkError()) {
                throw new IOException("Failed while writing ./allInput.txt");
            }
        }
    }

    /**
     * Binarises {@code image} and flattens the result into a vector of length
     * {@code height * width}.
     *
     * <p>The binary matrix is read as {@code brimage[i][j]} with {@code i} bounded by the image
     * height and {@code j} by its width.
     * NOTE(review): {@code trainbrainimage} allocates {@code brimage} as {@code [width][height]},
     * so this looks safe only for square images and yields a transposed layout — confirm
     * before changing, since trained weights depend on the current ordering.
     *
     * @param image the image to binarise
     * @return the 0/1 values of the binarised image
     * @throws IOException declared by {@code trainbrainimage.brmatrix}
     */
    public static int[] SampleMatirx(BufferedImage image) throws IOException {
        // Holder for the binarised picture and its matrix.
        trainbrainimage bimg = new trainbrainimage();
        bimg.brmatrix(image); // binarise; matrix and binary image are kept on the holder
        int high = image.getHeight();
        int weigh = image.getWidth();
        int[] a = new int[high * weigh];
        for (int i = 0; i < high; i++) {
            for (int j = 0; j < weigh; j++) {
                a[i * weigh + j] = bimg.brimage[i][j];
            }
        }
        return a;
    }
}
训练图像的灰度化与二值化处理:
package imageprocess;
import java.awt.Color;
import java.awt.image.BufferedImage;
import java.io.IOException;
/**
 * Greyscale conversion and binarisation of a training image.
 *
 * <p>After {@link #brmatrix(BufferedImage)} the 0/1 matrix is available in {@link #brimage}
 * (indexed {@code [x][y]}) and the black/white picture in {@link #image}.
 */
public class trainbrainimage {
    private int gray[][] = null; // grey value of every pixel, indexed [x][y]
    public int brimage[][] = null; // binarised value (0 or 1) of every pixel, indexed [x][y]
    public BufferedImage image;

    /**
     * Binarises {@code bi}: a pixel becomes 1 (white) when the mean grey value of its 3x3
     * neighbourhood exceeds 125, otherwise 0 (black).
     *
     * @param bi the source image
     * @throws IOException never thrown by this implementation; kept for callers
     */
    public void brmatrix(BufferedImage bi) throws IOException {
        int h = bi.getHeight();
        int w = bi.getWidth();
        gray = new int[w][h];
        brimage = new int[w][h];
        for (int x = 0; x < w; x++) {
            for (int y = 0; y < h; y++) {
                gray[x][y] = getGray(bi.getRGB(x, y));
            }
        }

        BufferedImage nbi = new BufferedImage(w, h, BufferedImage.TYPE_BYTE_BINARY);
        final int threshold = 125;
        final int white = new Color(255, 255, 255).getRGB();
        final int black = new Color(0, 0, 0).getRGB();
        for (int x = 0; x < w; x++) {
            for (int y = 0; y < h; y++) {
                if (getAverageColor(gray, x, y, w, h) > threshold) {
                    nbi.setRGB(x, y, white);
                    brimage[x][y] = 1;
                } else {
                    nbi.setRGB(x, y, black);
                    brimage[x][y] = 0;
                }
            }
        }
        this.image = nbi;
    }

    /**
     * Returns the mean of the red, green and blue channels of {@code rgb}.
     *
     * <p>The channels are taken from {@link Color} only. The previous hex-string parsing was
     * dead code whose {@code substring} calls threw for pixels with an alpha below 0x10
     * (e.g. fully transparent ones), because the hex string is then shorter than 8 characters.
     */
    private int getGray(int rgb) {
        Color c = new Color(rgb);
        return (c.getRed() + c.getGreen() + c.getBlue()) / 3;
    }

    /**
     * Mean grey value of the 3x3 neighbourhood around {@code (x, y)}; neighbours outside the
     * image count as 0 and the sum is always divided by 9.
     *
     * <p>The upper-right neighbour used to contribute 20 instead of 0 when it fell outside the
     * image, which is inconsistent with every other out-of-range neighbour.
     */
    private int getAverageColor(int[][] gray, int x, int y, int w, int h) {
        int rs = 0;
        for (int dx = -1; dx <= 1; dx++) {
            for (int dy = -1; dy <= 1; dy++) {
                int nx = x + dx;
                int ny = y + dy;
                if (nx >= 0 && nx < w && ny >= 0 && ny < h) {
                    rs += gray[nx][ny];
                }
            }
        }
        return rs / 9;
    }
}
图像数据按数字分类排列,每4000张为一类(第1–4000张为数字0,第4001–8000张为数字1,依此类推)。由于同一类样本在数据集中是集中存放的,训练时需要先打乱样本顺序。网络结构代码:
package network;
import java.io.BufferedReader;
import java.io.File;
import java.io.FileNotFoundException;
import java.io.FileReader;
import java.io.FileWriter;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Scanner;
/**
 * A multilayer perceptron with one hidden layer (784 inputs, configurable hidden size,
 * 10 sigmoid outputs) trained by back-propagation with momentum.
 *
 * <p>Training reads {@code ./allInput.txt} and stores weights and biases in
 * {@code ./height.txt}, {@code ./outweigh.txt}, {@code ./heightbais.txt} and
 * {@code ./outweighbais.txt}; {@link #recognize(int[])} loads them back.
 */
public class NetWork {
    /** Number of training samples expected in {@code ./allInput.txt}. */
    private static final int SAMPLE_COUNT = 40000;
    /** Samples are grouped by digit: sample {@code k} has label {@code k / 4000}. */
    private static final int SAMPLES_PER_DIGIT = 4000;

    private int iteration; // remaining training epochs
    public static int[][] allInput; // pixel vectors of all training samples
    private double stepsize; // learning rate
    private double weighRange; // scale applied to the random initial weights
    private double momentum; // momentum factor
    private int inputsize = 784; // number of input values
    private int hinddensize = 50; // number of hidden nodes
    private int outputsize = 10; // number of output nodes
    private int[] inputnode;
    private node[] hiddennode;
    private node[] outputnode;
    // Deltas, weights and the previous weights needed for the momentum term.
    private double[] hinddenDelta;
    private double[] outputDelta;
    private double[][] inputweight;
    private double[][] oldInputeight;
    private double[][] outputweight;
    private double[][] oldoutputweight;
    private int success = 0;
    private double error = 0;
    private double errorrate;
    private double successrate;
    private char[] type; // output types of test data; never assigned in this class

    /**
     * Allocates the network. Call {@link #initNetwork()} before {@link #train()}, or
     * {@link #initNodes()} before {@link #recognize(int[])}.
     *
     * @param iteration number of training epochs
     * @param stepsize learning rate
     * @param hinddensize number of hidden nodes
     * @param weighRange scale of the random initial weights
     * @param momentum momentum factor
     */
    public NetWork(int iteration, double stepsize, int hinddensize, double weighRange, double momentum) {
        this.iteration = iteration;
        this.stepsize = stepsize;
        this.hinddensize = hinddensize;
        this.weighRange = weighRange;
        this.momentum = momentum;
        inputnode = new int[inputsize];
        hiddennode = new node[hinddensize];
        outputnode = new node[outputsize];
        hinddenDelta = new double[hinddensize];
        outputDelta = new double[outputsize];
        inputweight = new double[inputsize][hinddensize];
        oldInputeight = new double[inputsize][hinddensize];
        outputweight = new double[hinddensize][outputsize];
        oldoutputweight = new double[hinddensize][outputsize];
    }

    /** @return fraction of samples classified correctly during the last training epoch */
    public double getsuccess() {
        return successrate;
    }

    /** @return mean squared output error per sample measured during the last training epoch */
    public double geterror() {
        return errorrate;
    }

    public char[] gettype() {
        return type;
    }

    /** Creates fresh nodes and random weights. */
    public void initNetwork() {
        initNodes();
        initWeights(weighRange);
    }

    /** Fills both weight matrices with values in {@code [-range/2, range/2)}; old = current. */
    private void initWeights(double range) {
        for (int i = 0; i < inputsize; i++) {
            for (int j = 0; j < hinddensize; j++) {
                inputweight[i][j] = randomBais() * range;
                oldInputeight[i][j] = inputweight[i][j];
            }
        }
        for (int i = 0; i < hinddensize; i++) {
            for (int j = 0; j < outputsize; j++) {
                outputweight[i][j] = randomBais() * range;
                oldoutputweight[i][j] = outputweight[i][j];
            }
        }
    }

    /** @return a random value in {@code [-0.5, 0.5)} */
    private double randomBais() {
        return Math.random() - 0.5;
    }

    /** Creates all hidden and output nodes with a random activation and a random bias. */
    public void initNodes() {
        for (int i = 0; i < hinddensize; i++) {
            hiddennode[i] = new node(Math.random(), randomBais());
        }
        for (int i = 0; i < outputsize; i++) {
            outputnode[i] = new node(Math.random(), randomBais());
        }
    }

    /**
     * Trains the network on the samples of {@code ./allInput.txt} in a shuffled order for the
     * configured number of epochs, then writes weights and biases to disk.
     *
     * <p>During the last epoch the success rate and the mean squared error are collected.
     * The error is now accumulated as the sum over all outputs of
     * {@code (desired - actual)^2}; it used to be overwritten for every sample
     * ({@code error=+}) with the squared difference between the winning activation and the
     * class index.
     *
     * @throws IOException if the training data cannot be read or the weights cannot be written
     */
    public void train() throws IOException {
        ArrayList<Integer> order = new ArrayList<>();
        for (int i = 0; i < SAMPLE_COUNT; i++) {
            order.add(i);
        }
        Collections.shuffle(order);
        NetWork.getAllInput();

        while (iteration > 0) {
            for (int i = 0; i < order.size(); i++) {
                int sample = order.get(i);
                int label = sample / SAMPLES_PER_DIGIT;
                double[] desiredOutput = new double[outputsize];
                desiredOutput[label] = 1;
                inputnode = allInput[sample];
                Forward();
                backPropagation(desiredOutput);
                if (iteration == 1) {
                    double[] b = new double[outputsize];
                    for (int p = 0; p < outputsize; p++) {
                        b[p] = outputnode[p].getActivation();
                        double diff = desiredOutput[p] - b[p];
                        error += diff * diff;
                    }
                    if (Max(b) == label) {
                        success++;
                    }
                }
            }
            iteration--;
            errorrate = error / SAMPLE_COUNT;
            successrate = success / (double) SAMPLE_COUNT;
        }

        // try-with-resources closes every writer exactly once, even when a write fails.
        try (FileWriter f = new FileWriter(new File("./height.txt"));
                FileWriter f1 = new FileWriter(new File("./outweigh.txt"));
                FileWriter f2 = new FileWriter(new File("./heightbais.txt"));
                FileWriter f3 = new FileWriter(new File("./outweighbais.txt"))) {
            for (int m = 0; m < inputsize; m++) {
                for (int j = 0; j < hinddensize; j++) {
                    f.write(String.format("%f\n", inputweight[m][j]));
                }
            }
            for (int j = 0; j < hinddensize; j++) {
                for (int k = 0; k < outputsize; k++) {
                    f1.write(String.format("%f\n", outputweight[j][k]));
                }
            }
            for (int j = 0; j < hinddensize; j++) {
                f2.write(String.format("%f\n", hiddennode[j].getBias()));
            }
            for (int k = 0; k < outputsize; k++) {
                f3.write(String.format("%f\n", outputnode[k].getBias()));
            }
        }
    }

    /** One back-propagation step with momentum for the sample currently in {@code inputnode}. */
    private void backPropagation(double[] desiredOutput) {
        double temp;
        // Output and hidden deltas (sigmoid derivative times error).
        for (int i = 0; i < outputsize; i++) {
            double out = outputnode[i].getActivation();
            outputDelta[i] = (out * (1 - out)) * (desiredOutput[i] - out);
        }
        for (int i = 0; i < hinddensize; i++) {
            temp = 0.0;
            for (int j = 0; j < outputsize; j++) {
                temp += outputDelta[j] * outputweight[i][j];
            }
            double hid = hiddennode[i].getActivation();
            hinddenDelta[i] = hid * (1.0 - hid) * temp;
        }
        // Update the weights.
        for (int i = 0; i < inputsize; i++) {
            for (int j = 0; j < hinddensize; j++) {
                temp = inputweight[i][j] + (stepsize * hinddenDelta[j] * inputnode[i])
                        + (momentum * (inputweight[i][j] - oldInputeight[i][j]));
                oldInputeight[i][j] = inputweight[i][j];
                inputweight[i][j] = temp;
            }
        }
        for (int i = 0; i < hinddensize; i++) {
            for (int j = 0; j < outputsize; j++) {
                temp = outputweight[i][j] + (stepsize * outputDelta[j] * hiddennode[i].getActivation())
                        + (momentum * (outputweight[i][j] - oldoutputweight[i][j]));
                oldoutputweight[i][j] = outputweight[i][j];
                outputweight[i][j] = temp;
            }
        }
        // Update the biases.
        for (int i = 0; i < hinddensize; i++) {
            temp = hiddennode[i].getBias() + (stepsize * hinddenDelta[i])
                    + (momentum * (hiddennode[i].getBias() - hiddennode[i].getOldbais()));
            hiddennode[i].setOldbais(hiddennode[i].getBias());
            hiddennode[i].setBias(temp);
        }
        for (int i = 0; i < outputsize; i++) {
            temp = outputnode[i].getBias() + (stepsize * outputDelta[i])
                    + (momentum * (outputnode[i].getBias() - outputnode[i].getOldbais()));
            outputnode[i].setOldbais(outputnode[i].getBias());
            outputnode[i].setBias(temp);
        }
    }

    /**
     * Forward pass for the sample currently in {@code inputnode}.
     *
     * <p>The former {@code catch (NullPointerException)} blocks are gone: every caller in this
     * class touches the nodes without such a guard right before or after the forward pass, so
     * uninitialised nodes failed regardless and the catches only hid the first stack trace.
     */
    private void Forward() {
        double temp;
        for (int i = 0; i < hinddensize; i++) {
            temp = 0.0;
            for (int j = 0; j < inputsize; j++) {
                temp += inputnode[j] * inputweight[j][i];
            }
            hiddennode[i].setActivation(sigmoid(temp + hiddennode[i].getBias()));
        }
        for (int i = 0; i < outputsize; i++) {
            temp = 0.0;
            for (int j = 0; j < hinddensize; j++) {
                temp += hiddennode[j].getActivation() * outputweight[j][i];
            }
            outputnode[i].setActivation(sigmoid(temp + outputnode[i].getBias()));
        }
    }

    /** Logistic activation function. */
    private double sigmoid(double d) {
        return 1 / (1 + Math.exp(-1 * d));
    }

    /**
     * Loads the 40000 x 784 training matrix from {@code ./allInput.txt} into {@link #allInput}.
     *
     * @throws IOException if the file cannot be read or contains fewer than 40000 lines
     */
    public static void getAllInput() throws IOException {
        int[][] loaded = new int[SAMPLE_COUNT][784];
        try (BufferedReader bwo = new BufferedReader(new FileReader("./allInput.txt"))) {
            for (int i = 0; i < SAMPLE_COUNT; i++) {
                String line = bwo.readLine();
                if (line == null) {
                    throw new IOException("./allInput.txt ends after " + i + " samples");
                }
                String[] s = line.split(",");
                for (int j = 0; j < 784; j++) {
                    loaded[i][j] = Integer.parseInt(s[j]);
                }
            }
        }
        NetWork.allInput = loaded;
    }

    /** @return index of the largest element of {@code b}; the first one wins on ties */
    private int Max(double[] b) {
        int maxIndex = 0;
        for (int i = 0; i < b.length; i++) {
            if (b[i] > b[maxIndex]) {
                maxIndex = i;
            }
        }
        return maxIndex;
    }

    /**
     * Loads the stored weights and biases and classifies the pixel vector {@code a}.
     * The nodes must have been created with {@link #initNodes()} or {@link #initNetwork()}.
     *
     * @param a the input pixels; at most 784 values, missing trailing values count as 0
     * @return the result text, {@code "数字是 :"} followed by the recognised digit
     * @throws FileNotFoundException if one of the weight or bias files is missing
     * @throws IllegalArgumentException if {@code a} holds more values than the input layer
     */
    public String recognize(int[] a) throws FileNotFoundException {
        if (a.length > inputsize) {
            throw new IllegalArgumentException(
                    "Expected at most " + inputsize + " inputs but got " + a.length);
        }
        // Copy into a fresh array: after train() inputnode aliases a row of allInput, and
        // writing through it would corrupt the training data.
        int[] fresh = new int[inputsize];
        System.arraycopy(a, 0, fresh, 0, a.length);
        inputnode = fresh;

        try (Scanner in = new Scanner(new File("./height.txt"))) {
            for (int i = 0; i < inputweight.length; i++) {
                for (int j = 0; j < inputweight[0].length; j++) {
                    inputweight[i][j] = in.nextDouble();
                }
            }
        }
        try (Scanner in2 = new Scanner(new File("./heightbais.txt"))) {
            for (int j = 0; j < hinddensize; j++) {
                hiddennode[j].setBias(in2.nextDouble());
            }
        }
        try (Scanner in3 = new Scanner(new File("./outweighbais.txt"))) {
            for (int j = 0; j < outputsize; j++) {
                outputnode[j].setBias(in3.nextDouble());
            }
        }
        try (Scanner in4 = new Scanner(new File("./outweigh.txt"))) {
            for (int i = 0; i < hinddensize; i++) {
                for (int j = 0; j < outputsize; j++) {
                    outputweight[i][j] = in4.nextDouble();
                }
            }
        }

        Forward();
        double[] b = new double[outputsize];
        for (int p = 0; p < outputsize; p++) {
            b[p] = outputnode[p].getActivation();
        }
        int index = Max(b);
        return "数字是 :" + index;
    }
}
网络节点类:
package network;
/**
 * A single neuron: its current activation, its bias and the bias of the previous update
 * step (needed for the momentum term).
 */
public class node {
    private double activation;
    private double bias;
    private double oldbais;

    /**
     * @param a initial activation
     * @param b initial bias; also used as the initial "previous bias", so the momentum term
     *     {@code bias - oldbais} of the first update is zero instead of the whole bias
     */
    public node(double a, double b) {
        this.activation = a;
        this.bias = b;
        this.oldbais = b;
    }

    public double getActivation() {
        return this.activation;
    }

    public void setActivation(double activation) {
        this.activation = activation;
    }

    public double getBias() {
        return this.bias;
    }

    public void setBias(double bias) {
        this.bias = bias;
    }

    /** @return the bias recorded before the most recent update */
    public double getOldbais() {
        return this.oldbais;
    }

    public void setOldbais(double oldbais) {
        this.oldbais = oldbais;
    }

    /** @return activation and bias separated by a single space */
    @Override
    public String toString() {
        return this.activation + " " + this.bias;
    }
}
用于手绘待识别数字的画板:
package main;
import java.awt.*;
import java.awt.event.*;
import java.util.Vector;
import javax.swing.JPanel;
/**
 * Free-hand drawing canvas: white strokes, 16 pixels wide, on a black background.
 * Every press starts a new stroke and every drag extends the latest one.
 */
public class mypanel extends JPanel {
    private static final long serialVersionUID = 1L;
    /** All strokes drawn so far; each stroke is the list of points the mouse passed through. */
    private Vector<Vector<Point>> FreedomDatas = new Vector<Vector<Point>>();
    private Color lineColor = Color.white;
    private int lineWidth = 16;

    /** Installs the mouse handlers that record strokes. */
    public mypanel() {
        addMouseListener(new MouseAdapter() {
            @Override
            public void mousePressed(MouseEvent e) {
                Vector<Point> newLine = new Vector<Point>();
                newLine.add(new Point(e.getX(), e.getY()));
                FreedomDatas.add(newLine);
            }

            @Override
            public void mouseReleased(MouseEvent e) {
                repaint();
            }
        });
        addMouseMotionListener(new MouseMotionAdapter() {
            @Override
            public void mouseDragged(MouseEvent e) {
                if (FreedomDatas.isEmpty()) {
                    // No open stroke to extend, e.g. the canvas was cleared during a drag.
                    return;
                }
                Vector<Point> lastLine = FreedomDatas.get(FreedomDatas.size() - 1);
                lastLine.add(new Point(e.getX(), e.getY()));
                // Repaint while dragging so the stroke is visible as it is drawn, not only
                // after the mouse button is released.
                repaint();
            }
        });
    }

    /** Removes every stroke and repaints the empty canvas. */
    public void cleanAll() {
        FreedomDatas.clear();
        repaint();
    }

    /**
     * Paints the black background and all recorded strokes.
     *
     * <p>The background colour is set explicitly; it used to be whatever colour the incoming
     * {@code Graphics} happened to carry. Drawing happens on a private copy of {@code g} so
     * the caller's colour and stroke are left untouched.
     */
    @Override
    public void paint(Graphics g) {
        Graphics2D g2 = (Graphics2D) g.create();
        try {
            g2.setColor(Color.black);
            g2.fillRect(0, 0, getWidth(), getHeight());
            g2.setColor(lineColor);
            g2.setStroke(new BasicStroke(lineWidth, BasicStroke.CAP_ROUND, BasicStroke.JOIN_ROUND));
            for (int i = 0; i < FreedomDatas.size(); i++) {
                Vector<Point> line = FreedomDatas.get(i);
                for (int j = 0; j < line.size() - 1; j++) {
                    Point from = line.get(j);
                    Point to = line.get(j + 1);
                    g2.drawLine(from.x, from.y, to.x, to.y);
                }
            }
        } finally {
            g2.dispose();
        }
    }
}
抓取画板上绘制的数字图像,并进行裁剪、缩放和二值化处理:
package imageprocess;
import java.awt.Color;
import java.awt.Graphics;
import java.awt.GraphicsConfiguration;
import java.awt.GraphicsDevice;
import java.awt.GraphicsEnvironment;
import java.awt.HeadlessException;
import java.awt.Image;
import java.awt.Toolkit;
import java.awt.Transparency;
import java.awt.geom.AffineTransform;
import java.awt.image.AffineTransformOp;
import java.awt.image.BufferedImage;
import java.awt.image.CropImageFilter;
import java.awt.image.FilteredImageSource;
import java.awt.image.ImageFilter;
import java.io.File;
import java.io.IOException;
import javax.imageio.ImageIO;
import javax.swing.ImageIcon;
/**
 * Image helpers that turn the drawn canvas into the 28x28 input matrix: binarisation,
 * bounding-box detection, cropping, scaling and tiling.
 */
public class getimage {
    /**
     * Converts a drawn image into a 28x28 matrix.
     *
     * <p>Steps: scale down images larger than 800 pixels, binarise, locate the bounding box of
     * the columns/rows whose sum of set pixels is at least 2, pad the box so its sides are
     * derived from multiples of 28, crop it and finally reduce it to 28x28 tiles with
     * {@link #cut2(BufferedImage, int, int)}.
     *
     * @param bi the drawn image
     * @return a 28x28 matrix as produced by {@code cut2}
     * @throws IOException if binarisation or writing the cropped image fails
     */
    public final static int[][] getMatirx(BufferedImage bi) throws IOException {
        binaryimage biimg = new binaryimage(); // holds the binarised picture and its matrix
        int h = bi.getHeight();
        int w = bi.getWidth();
        if (h > 800 || w > 800) {
            bi = scale(bi, 800, 800);
            h = bi.getHeight();
            w = bi.getWidth();
        }
        biimg.brmatrix(bi);
        int[][] bi_matrix = biimg.brimage; // binary matrix, indexed [x][y]

        // Number of set pixels per column (row[]) and per row (line[]).
        int[] row = new int[w];
        for (int i = 0; i < w; i++) {
            int s = 0;
            for (int j = 0; j < h; j++) {
                s = s + bi_matrix[i][j];
            }
            row[i] = s;
        }
        int[] line = new int[h];
        for (int j = 0; j < h; j++) {
            int s = 0;
            for (int i = 0; i < w; i++) {
                s = s + bi_matrix[i][j];
            }
            line[j] = s;
        }

        // Bounding box; 0 doubles as the "not found yet" marker, so index 0 is skipped.
        int left = 0, right = 0, top = 0, below = 0;
        for (int i = 1; i < w; i++) {
            if (row[i] >= 2) {
                if (left == 0) {
                    left = i;
                }
                if (right < i) {
                    right = i;
                }
            }
        }
        for (int i = 1; i < h; i++) {
            if (line[i] >= 2) {
                if (top == 0) {
                    top = i;
                }
                if (below < i) {
                    below = i;
                }
            }
        }

        // Grow the box to the next multiple of 28 and centre the content in it.
        int new_h = (28 - (below - top) % 28) - top + below;
        int new_w = (28 - (right - left) % 28) - left + right;
        if ((new_h / new_w) >= 2) {
            // Very tall, narrow shapes: widen the crop to about half the height.
            new_w = (new_h / 2 % 28) + new_h / 2;
            top = top - (28 - (below - top) % 28) / 2;
            left = left - (new_w - right + left) / 2;
        } else {
            top = top - (28 - (below - top) % 28) / 2;
            left = left - (28 - (right - left) % 28) / 2;
        }
        biimg.image = cut(biimg.image, left, top, w, h, new_w, new_h);
        return cut2(biimg.image, 28, 28);
    }

    /**
     * Returns a copy of {@code bi} that fits into {@code width} x {@code height}.
     *
     * <p>An image exceeding the target in either dimension is scaled uniformly by the ratio of
     * its longer side; any other image is resized to exactly {@code width} x {@code height}.
     *
     * @param bi the source image
     * @param height target height
     * @param width target width
     * @return the scaled image
     */
    public final static BufferedImage scale(BufferedImage bi, int height, int width) {
        Image temp;
        if ((bi.getHeight() > height) || (bi.getWidth() > width)) {
            double ratio;
            if (bi.getHeight() > bi.getWidth()) {
                ratio = (double) height / bi.getHeight();
            } else {
                ratio = (double) width / bi.getWidth();
            }
            AffineTransformOp op =
                    new AffineTransformOp(AffineTransform.getScaleInstance(ratio, ratio), null);
            temp = op.filter(bi, null);
        } else {
            temp = bi.getScaledInstance(width, height, Image.SCALE_SMOOTH);
        }
        return toBufferedImage(temp);
    }

    /**
     * Converts any {@link Image} into an opaque {@link BufferedImage}; a
     * {@code BufferedImage} argument is returned as is.
     */
    public final static BufferedImage toBufferedImage(Image image) {
        if (image instanceof BufferedImage) {
            return (BufferedImage) image;
        }
        image = new ImageIcon(image).getImage(); // forces the image to be fully loaded
        BufferedImage bimage = null;
        GraphicsEnvironment ge = GraphicsEnvironment.getLocalGraphicsEnvironment();
        try {
            GraphicsDevice gs = ge.getDefaultScreenDevice();
            GraphicsConfiguration gc = gs.getDefaultConfiguration();
            bimage = gc.createCompatibleImage(
                    image.getWidth(null), image.getHeight(null), Transparency.OPAQUE);
        } catch (HeadlessException ignored) {
            // No screen available: fall back to a plain RGB image below.
        }
        if (bimage == null) {
            bimage = new BufferedImage(
                    image.getWidth(null), image.getHeight(null), BufferedImage.TYPE_INT_RGB);
        }
        Graphics g = bimage.createGraphics();
        g.drawImage(image, 0, 0, null);
        g.dispose();
        return bimage;
    }

    /**
     * Crops the rectangle {@code (x, y, new_w, new_h)} out of {@code bi} (first resized to
     * {@code w} x {@code h}) and also writes the result to {@code ./after.jpg}.
     *
     * @throws IOException if {@code ./after.jpg} cannot be written
     */
    public final static BufferedImage cut(BufferedImage bi, int x, int y, int w, int h, int new_w, int new_h)
            throws IOException {
        Image image = bi.getScaledInstance(w, h, Image.SCALE_DEFAULT);
        ImageFilter cropFilter = new CropImageFilter(x, y, new_w, new_h);
        Image img = Toolkit.getDefaultToolkit()
                .createImage(new FilteredImageSource(image.getSource(), cropFilter));
        BufferedImage tag = new BufferedImage(new_w, new_h, BufferedImage.TYPE_INT_RGB);
        Graphics g = tag.getGraphics();
        g.drawImage(img, 0, 0, new_w, new_h, null); // draw the cropped picture
        g.dispose();
        ImageIO.write(tag, "jpg", new File("./after.jpg"));
        return tag;
    }

    /**
     * Splits {@code bi} into {@code rows} x {@code cols} tiles and marks a tile with 1 when it
     * contains no pure-black pixel, otherwise 0.
     *
     * <p>Width and height are no longer swapped, and the tile height is derived from the image
     * height; it used to fall back to the width whenever the height was not a multiple of
     * {@code rows}.
     *
     * @return the {@code rows} x {@code cols} matrix; all zeros if processing fails
     */
    public final static int[][] cut2(BufferedImage bi, int rows, int cols) {
        int[][] InputMatrix = new int[rows][cols];
        try {
            int srcWidth = bi.getWidth();
            int srcHeight = bi.getHeight();
            if (srcWidth > 0 && srcHeight > 0) {
                Image image = bi.getScaledInstance(srcWidth, srcHeight, Image.SCALE_DEFAULT);
                // Tile size, rounded up so the tiles cover the whole picture.
                int destWidth = (srcWidth % cols == 0) ? srcWidth / cols : srcWidth / cols + 1;
                int destHeight = (srcHeight % rows == 0) ? srcHeight / rows : srcHeight / rows + 1;
                for (int i = 0; i < rows; i++) {
                    for (int j = 0; j < cols; j++) {
                        ImageFilter cropFilter =
                                new CropImageFilter(j * destWidth, i * destHeight, destWidth, destHeight);
                        Image img = Toolkit.getDefaultToolkit()
                                .createImage(new FilteredImageSource(image.getSource(), cropFilter));
                        BufferedImage tag =
                                new BufferedImage(destWidth, destHeight, BufferedImage.TYPE_INT_RGB);
                        Graphics g = tag.getGraphics();
                        g.drawImage(img, 0, 0, null); // draw the tile
                        g.dispose();
                        InputMatrix[i][j] = IsBlank(tag, destWidth, destHeight) ? 1 : 0;
                    }
                }
            }
        } catch (Exception e) {
            e.printStackTrace();
        }
        return InputMatrix;
    }

    /**
     * @return {@code true} when no pixel inside the first {@code destWidth} x
     *     {@code destHeight} pixels of {@code tag} has a grey value of 0
     */
    public final static boolean IsBlank(BufferedImage tag, int destWidth, int destHeight) {
        for (int x = 0; x < destWidth; x++) {
            for (int y = 0; y < destHeight; y++) {
                if (getGray(tag.getRGB(x, y)) == 0) {
                    return false;
                }
            }
        }
        return true;
    }

    /**
     * Returns the mean of the red, green and blue channels of {@code rgb}.
     *
     * <p>The channels come from {@link Color} only. The previous hex-string parsing was dead
     * code whose {@code substring} calls threw for values with an alpha below 0x10.
     */
    public final static int getGray(int rgb) {
        Color c = new Color(rgb);
        return (c.getRed() + c.getGreen() + c.getBlue()) / 3;
    }
}
package imageprocess;
import java.awt.Color;
import java.awt.image.BufferedImage;
import java.io.File;
import java.io.IOException;
import javax.imageio.ImageIO;
/**
 * Greyscale conversion and binarisation of a drawn image using a brightened 5x5 mean filter.
 *
 * <p>After {@link #brmatrix(BufferedImage)} the 0/1 matrix is available in {@link #brimage}
 * (indexed {@code [x][y]}) and the black/white picture in {@link #image}.
 */
public class binaryimage {
    private int gray[][] = null; // grey value of every pixel, indexed [x][y]
    public int brimage[][] = null; // binarised value (0 or 1) of every pixel, indexed [x][y]
    private int gra[][] = null; // grey values with a 2-pixel zero border, used for de-noising
    public BufferedImage image;

    /**
     * Binarises {@code bi}: grey values are brightened by 25 %, padded with a 2-pixel border of
     * zeros, and a pixel becomes 1 (white) when the mean of the 5x5 window centred on it
     * exceeds 125, otherwise 0 (black).
     *
     * <p>The padding now copies the whole image. The old bounds ({@code i < w-1},
     * {@code j < h-1}) were written against the unpadded size and silently zeroed the last
     * three columns and rows.
     *
     * @param bi the source image
     * @throws IOException never thrown by this implementation; kept for callers
     */
    public void brmatrix(BufferedImage bi) throws IOException {
        int h = bi.getHeight();
        int w = bi.getWidth();
        gray = new int[w][h];
        brimage = new int[w][h];
        for (int x = 0; x < w; x++) {
            for (int y = 0; y < h; y++) {
                gray[x][y] = getGray(bi.getRGB(x, y));
            }
        }
        Brighter(gray, w, h);

        // New int arrays are zero-filled, so only the interior needs to be copied.
        gra = new int[w + 4][h + 4];
        for (int x = 0; x < w; x++) {
            for (int y = 0; y < h; y++) {
                gra[x + 2][y + 2] = gray[x][y];
            }
        }

        BufferedImage nbi = new BufferedImage(w, h, BufferedImage.TYPE_BYTE_BINARY);
        final int threshold = 125;
        final int white = new Color(255, 255, 255).getRGB();
        final int black = new Color(0, 0, 0).getRGB();
        for (int x = 0; x < w; x++) {
            for (int y = 0; y < h; y++) {
                if (getAverageColor(gra, x, y, w, h) > threshold) {
                    nbi.setRGB(x, y, white);
                    brimage[x][y] = 1;
                } else {
                    nbi.setRGB(x, y, black);
                    brimage[x][y] = 0;
                }
            }
        }
        this.image = nbi;
    }

    /**
     * Returns the mean of the red, green and blue channels of {@code rgb}.
     *
     * <p>The channels come from {@link Color} only. The previous hex-string parsing was dead
     * code whose {@code substring} calls threw for pixels with an alpha below 0x10.
     */
    private int getGray(int rgb) {
        Color c = new Color(rgb);
        return (c.getRed() + c.getGreen() + c.getBlue()) / 3;
    }

    /**
     * Mean of the 5x5 window of the padded matrix whose top-left corner is {@code (x, y)},
     * i.e. the window centred on image pixel {@code (x, y)}.
     */
    private int getAverageColor(int[][] gray, int x, int y, int w, int h) {
        int rs = 0;
        for (int i = 0; i < 5; i++) {
            for (int j = 0; j < 5; j++) {
                rs += gray[x + i][y + j];
            }
        }
        return rs / 25;
    }

    /** Multiplies every grey value by 1.25 (rounded down) and caps it at 255, in place. */
    public static void Brighter(int[][] gray, int w, int h) {
        for (int x = 0; x < w; x++) {
            for (int y = 0; y < h; y++) {
                gray[x][y] = Math.min(255, (int) Math.floor(gray[x][y] * 1.25));
            }
        }
    }
}
训练界面:用于设定迭代次数、学习率(移动步长)、隐层节点数量和动量调节因子:
package main;
import java.awt.Color;
import java.awt.Cursor;
import java.awt.Dimension;
import java.awt.event.ActionEvent;
import java.awt.event.ActionListener;
import java.awt.event.WindowEvent;
import java.awt.event.WindowListener;
import java.io.File;
import java.io.FileWriter;
import javax.swing.JButton;
import javax.swing.JFrame;
import javax.swing.JLabel;
import javax.swing.JTextField;
import network.NetWork;
/**
 * Training window: lets the user enter the number of epochs, the learning rate, the hidden
 * layer size and the momentum, trains a new network in the background and shows the success
 * rate and the mean squared error afterwards.
 */
public class trainwin extends JFrame implements WindowListener {
    private static final long serialVersionUID = 1L;
    private JTextField Limit;
    private JTextField LearningRate;
    private JTextField hLNeurons;
    private JButton txttrain;
    private JLabel lerror;
    private JLabel lters;
    private JLabel lsuccess;
    private JLabel MSE;
    private JLabel emnue;
    private JTextField emnuefile;
    private NetWork nn;

    /** Builds and shows the window. */
    public trainwin() {
        setTitle("自我训练");
        setSize(new Dimension(400, 440));
        setLocationRelativeTo(null);
        // Dispose only this window; EXIT_ON_CLOSE also terminated the main recogniser window
        // that opened it.
        setDefaultCloseOperation(DISPOSE_ON_CLOSE);
        setResizable(false);
        setLayout(null);
        createRightSide();
        addWindowListener(this);
        setVisible(true);
    }

    /** Creates all labels, input fields and the train button and adds them to the window. */
    private void createRightSide() {
        JLabel lblIterLimit = new JLabel("迭代次数");
        lblIterLimit.setBounds(70, 10, 100, 50);
        JLabel lblLearningRate = new JLabel("学习率");
        lblLearningRate.setBounds(70, 40, 100, 50);
        JLabel lblHLNeurons = new JLabel("隐层数量");
        lblHLNeurons.setBounds(70, 70, 100, 50);
        emnue = new JLabel("动能调节");
        emnue.setBounds(70, 100, 100, 50);
        emnue.setCursor(new Cursor(Cursor.TEXT_CURSOR));
        JLabel lblItersTxt = new JLabel("迭代结束:");
        lblItersTxt.setBounds(70, 230, 150, 50);
        JLabel lblSuccessTxt = new JLabel("成功率:");
        lblSuccessTxt.setBounds(70, 270, 100, 50);
        JLabel lblMSETxt = new JLabel("均方误差:");
        lblMSETxt.setBounds(70, 310, 150, 50);

        Limit = new JTextField();
        Limit.setBounds(165, 25, 100, 20);
        Limit.setCursor(new Cursor(Cursor.TEXT_CURSOR));
        LearningRate = new JTextField();
        LearningRate.setBounds(165, 55, 100, 20);
        LearningRate.setCursor(new Cursor(Cursor.TEXT_CURSOR));
        hLNeurons = new JTextField();
        hLNeurons.setBounds(165, 85, 100, 20);
        hLNeurons.setCursor(new Cursor(Cursor.TEXT_CURSOR));
        emnuefile = new JTextField();
        emnuefile.setBounds(165, 115, 100, 20);
        emnuefile.setCursor(new Cursor(Cursor.TEXT_CURSOR));

        txttrain = new JButton("训练开始");
        txttrain.setBounds(100, 150, 100, 30);
        txttrain.setFocusPainted(false);
        txttrain.setCursor(new Cursor(Cursor.HAND_CURSOR));
        txttrain.addActionListener(new TrainListener());

        lerror = new JLabel("");
        lerror.setBounds(70, 190, 300, 30);
        lerror.setForeground(Color.RED);
        lters = new JLabel("");
        lters.setBounds(160, 230, 100, 50);
        lters.setForeground(Color.BLUE);
        lsuccess = new JLabel("");
        lsuccess.setBounds(160, 270, 100, 50);
        lsuccess.setForeground(Color.BLUE);
        MSE = new JLabel("");
        MSE.setBounds(160, 310, 100, 50);
        MSE.setForeground(Color.BLUE);

        getContentPane().add(lblIterLimit);
        getContentPane().add(lblLearningRate);
        getContentPane().add(lblHLNeurons);
        getContentPane().add(Limit);
        getContentPane().add(LearningRate);
        getContentPane().add(hLNeurons);
        getContentPane().add(txttrain);
        getContentPane().add(lerror);
        getContentPane().add(lblItersTxt);
        getContentPane().add(lblSuccessTxt);
        getContentPane().add(lblMSETxt);
        getContentPane().add(lters);
        getContentPane().add(lsuccess);
        getContentPane().add(MSE);
        getContentPane().add(emnue);
        getContentPane().add(emnuefile);
    }

    /**
     * Validates the inputs on the event thread, runs the training on a worker thread and
     * hands every Swing update back to the event thread. Swing components used to be
     * modified directly from the worker thread.
     */
    private class TrainListener implements ActionListener {
        public void actionPerformed(ActionEvent e) {
            // Clear results and any stale error message from a previous attempt.
            lerror.setText("");
            lters.setText("");
            lsuccess.setText("");
            MSE.setText("");

            String problem = findInputProblem();
            if (problem != null) {
                lerror.setText(problem);
                return;
            }

            final int iter = Integer.parseInt(Limit.getText());
            final double rate = Double.parseDouble(LearningRate.getText());
            final int hide = Integer.parseInt(hLNeurons.getText());
            final double momentumValue = Double.parseDouble(emnuefile.getText());

            txttrain.setEnabled(false);
            txttrain.setText("训练中...");

            new Thread() {
                public void run() {
                    final NetWork net = new NetWork(iter, rate, hide, 0.05, momentumValue);
                    nn = net;
                    boolean ok = false;
                    String message = "";
                    try {
                        net.initNetwork();
                        net.train();
                        ok = true;
                        // Remember the hyper-parameters so the recogniser can rebuild the net.
                        try (FileWriter pf = new FileWriter(new File("./parameter.txt"))) {
                            pf.write(String.format("%d\n", iter));
                            pf.write(String.format("%f\n", rate));
                            pf.write(String.format("%d\n", hide));
                            pf.write(String.format("%f\n", momentumValue));
                        }
                    } catch (Exception e1) {
                        e1.printStackTrace();
                        message = "训练失败: " + e1.getMessage();
                    }
                    final boolean trained = ok;
                    final String failure = message;
                    java.awt.EventQueue.invokeLater(new Runnable() {
                        public void run() {
                            if (trained) {
                                lters.setText(String.valueOf(iter));
                                lsuccess.setText(String.format("%f", net.getsuccess() * 100));
                                MSE.setText(String.valueOf(net.geterror()));
                            }
                            lerror.setText(failure);
                            txttrain.setEnabled(true);
                            txttrain.setText("训练");
                        }
                    });
                }
            }.start();
        }

        /** @return the message for the first invalid input field, or {@code null} if all are valid */
        private String findInputProblem() {
            if (!Limit.getText().matches("[1-9][0-9]*")) {
                return "请输入迭代次数.";
            }
            if (!LearningRate.getText().matches("[0-9]*\\.[0-9]+")) {
                return "请输入可用的学习率.";
            }
            if (!hLNeurons.getText().matches("[1-9][0-9]*")) {
                return "请输入有隐层神经元数量 .";
            }
            if (!emnuefile.getText().matches("[0-9]*\\.[0-9]+")) {
                return "请输入有效动量调节值 .";
            }
            return null;
        }
    }

    public void windowClosed(WindowEvent e) {}

    public void windowOpened(WindowEvent e) {}

    public void windowIconified(WindowEvent e) {}

    public void windowDeiconified(WindowEvent e) {}

    public void windowActivated(WindowEvent e) {}

    public void windowDeactivated(WindowEvent e) {}

    public static void main(String[] args) {
        new trainwin();
    }

    @Override
    public void windowClosing(WindowEvent windowevent) {
        // Nothing to do; closing is handled by the default close operation.
    }
}
训练数据,以及迭代10次、学习率0.3、隐层节点数50、动量0.5时训练得到的各项参数数据见: