/**
 * Command-line driver for the BP classifier: trains a {@code BPNN} on the training
 * file, classifies the test file and prints how many samples were labelled correctly.
 */
public class Test {

    /** Path of the training samples. */
    private static final String TRAIN_FILE = "src/bp/train.txt";

    /** Path of the samples used to measure accuracy. */
    private static final String TEST_FILE = "src/bp/test.txt";

    /** Number of class labels handed to both {@code DataUtil} instances. */
    private static final int TYPE_NUM = 3;

    /**
     * Runs one train / evaluate cycle and prints the statistics to standard output.
     *
     * @param args ignored
     * @throws Exception if reading either data file or training fails
     */
    public static void main(String[] args) throws Exception {
        // Load the training set; flag 0 makes ReadFile append a one-hot label to each row.
        DataUtil dataUtil = new DataUtil();
        dataUtil.NormalizeData(TRAIN_FILE);
        dataUtil.SetTypeNum(TYPE_NUM);
        dataUtil.ReadFile(TRAIN_FILE, ",", 0);
        int in_num = dataUtil.GetInNum();
        int out_num = dataUtil.GetOutNum();
        ArrayList<ArrayList<Double>> alllist = dataUtil.GetList();
        ArrayList<String> outlist = dataUtil.GetOutList();

        System.out.print("分类的类型:");
        for (String label : outlist) {
            System.out.print(label + " ");
        }
        System.out.println();
        System.out.println("训练集的数量:" + alllist.size());

        BPNN bpnn = new BPNN();
        System.out.println("Train Start!");
        System.out.println(".............");
        bpnn.Train(in_num, out_num, alllist);
        System.out.println("Train End!");

        // Load the test set; flag 1 keeps the expected labels in a separate check list.
        DataUtil testUtil = new DataUtil();
        testUtil.NormalizeData(TEST_FILE);
        testUtil.SetTypeNum(TYPE_NUM);
        testUtil.ReadFile(TEST_FILE, ",", 1);
        ArrayList<ArrayList<Double>> testList = testUtil.GetList();
        ArrayList<String> normallist = testUtil.GetCheckList();

        ArrayList<ArrayList<Double>> resultList = bpnn.ForeCast(testList);
        int all_num = resultList.size();
        int right = 0;
        for (int i = 0; i < all_num; i++) {
            String predicted = predictedLabel(resultList.get(i), outlist);
            if (predicted.equals(normallist.get(i))) {
                right++;
            }
        }

        System.out.println("测试集的数量:" + all_num);
        System.out.println("分类正确的数量:" + right);
        // Floating-point division on purpose: an empty test set prints NaN instead of throwing.
        System.out.println("算法的分类正确率为:" + (double) right / all_num);
        // NOTE(review): nothing in this method writes that file; the message is kept only
        // because it is existing program output. Confirm whether a WriteFile call is missing.
        System.out.println("分类结果存储在:E:\\BP_data\\result.txt");
    }

    /**
     * Maps one network output row to a class label.
     *
     * @param outputs thresholded network outputs for one sample
     * @param labels  class labels in the order of the output neurons
     * @return the label of the last output that equals 1.0, or {@code "unknow"} if none does
     */
    private static String predictedLabel(ArrayList<Double> outputs, ArrayList<String> labels) {
        String label = "unknow";
        for (int j = 0; j < labels.size(); j++) {
            if (outputs.get(j) == 1.0) {
                label = labels.get(j);
            }
        }
        return label;
    }
}
package bp;
import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileWriter;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.Random;
import java.util.ArrayList;

/**
 * Reads delimiter-separated sample files and converts them into normalised feature rows.
 *
 * <p>Typical use: call {@link #NormalizeData(String)} to collect the per-column value
 * range, {@link #SetTypeNum(int)} to fix the number of class labels, then
 * {@link #ReadFile(String, String, int)} to build the rows. Files are decoded as GBK.
 * A path that does not name an existing regular file is silently ignored by every reader.
 */
class DataUtil {

    /** Character set used to decode every input file. */
    private static final String ENCODING = "GBK";

    /** One entry per data line: normalised features, plus a one-hot label when read with flag 0. */
    private ArrayList<ArrayList<Double>> alllist = new ArrayList<>();
    /** Distinct labels in order of first appearance (filled when reading with flag 0). */
    private ArrayList<String> outlist = new ArrayList<>();
    /** Labels in file order (filled when reading with flag 1). */
    private ArrayList<String> checklist = new ArrayList<>();
    private int in_num = 0;
    private int out_num = 0;
    private int type_num = 0;
    /** Per feature column: {@code [i][0]} is the maximum and {@code [i][1]} the minimum seen. */
    private double[][] nom_data;
    /** Column count of the first data line minus one. */
    private int in_data_num = 0;

    /** @return the configured number of class labels */
    public int GetTypeNum() {
        return type_num;
    }

    /** @param type_num number of one-hot slots appended for each label when reading with flag 0 */
    public void SetTypeNum(int type_num) {
        this.type_num = type_num;
    }

    /** @return the number of feature values parsed from the most recently read line */
    public int GetInNum() {
        return in_num;
    }

    /** @return the type number that was in effect while the last line was read */
    public int GetOutNum() {
        return out_num;
    }

    /** @return the live list of rows built by {@link #ReadFile(String, String, int)} */
    public ArrayList<ArrayList<Double>> GetList() {
        return alllist;
    }

    /** @return the live list of distinct labels collected with flag 0 */
    public ArrayList<String> GetOutList() {
        return outlist;
    }

    /** @return the live list of labels collected with flag 1 */
    public ArrayList<String> GetCheckList() {
        return checklist;
    }

    /** @return the per-column {max, min} table, or {@code null} before {@link #NormalizeData(String)} */
    public double[][] GetMaxMin() {
        return nom_data;
    }

    /**
     * Parses a sample file into rows and appends them to the row list.
     *
     * <p>A token counts as a feature when it parses as a number and its column has a
     * recorded value range; it is normalised and appended to the row. Every other token is
     * treated as a label: with {@code flag == 0} it is registered in the label list and a
     * one-hot group of {@code type_num} values is appended to the row; with
     * {@code flag == 1} it is appended to the check list. Any other flag yields empty rows.
     * Blank lines are skipped.
     *
     * @param filepath file to read
     * @param sep      regular expression that separates the tokens of a line
     * @param flag     0 to encode labels into the rows, 1 to collect them in the check list
     * @throws Exception if the file cannot be read
     * @throws IndexOutOfBoundsException if more distinct labels occur than {@code type_num} allows
     */
    public void ReadFile(String filepath, String sep, int flag) throws Exception {
        File file = new File(filepath);
        if (!file.isFile()) {
            return;
        }
        try (FileInputStream in = new FileInputStream(file);
             BufferedReader reader = new BufferedReader(new InputStreamReader(in, ENCODING))) {
            String line;
            while ((line = reader.readLine()) != null) {
                if (line.trim().isEmpty()) {
                    continue; // a blank line is not a sample
                }
                String[] tokens = line.split(sep);
                ArrayList<Double> row = new ArrayList<>();
                int featureCount = 0;
                if (flag == 0 || flag == 1) {
                    for (int i = 0; i < tokens.length; i++) {
                        Double feature = normalizedFeature(tokens[i], i);
                        if (feature != null) {
                            row.add(feature);
                            featureCount++;
                        } else if (flag == 0) {
                            appendOneHotLabel(row, featureCount, tokens[i]);
                        } else {
                            checklist.add(tokens[i]);
                        }
                    }
                }
                alllist.add(row);
                in_num = featureCount;
                out_num = type_num;
            }
        }
    }

    /**
     * Normalises a token against the recorded range of its column.
     *
     * @return the normalised value, or {@code null} when the token is not a feature (no
     *         range is known for the column, or the token is not numeric)
     */
    private Double normalizedFeature(String token, int column) {
        if (nom_data == null || column >= nom_data.length) {
            return null;
        }
        try {
            return Normalize(Double.parseDouble(token), nom_data[column][0], nom_data[column][1]);
        } catch (NumberFormatException notNumeric) {
            return null;
        }
    }

    /** Registers {@code label} if it is new and appends its one-hot encoding to {@code row}. */
    private void appendOneHotLabel(ArrayList<Double> row, int featureCount, String label) {
        if (!outlist.contains(label)) {
            outlist.add(label);
        }
        for (int k = 0; k < type_num; k++) {
            row.add(0.0);
        }
        row.set(featureCount + outlist.indexOf(label), 1.0);
    }

    /**
     * Writes the first {@code in_number} values of each row followed by its result label,
     * one comma-separated line per row, and prints the size of {@code resultlist}.
     *
     * @param filepath   file to create or overwrite
     * @param list       rows whose leading values are written
     * @param in_number  number of values written per row
     * @param resultlist label written at the end of each line
     * @throws IOException if the file cannot be opened or written
     */
    public void WriteFile(String filepath, ArrayList<ArrayList<Double>> list, int in_number,
            ArrayList<String> resultlist) throws IOException {
        try (FileWriter fw = new FileWriter(new File(filepath));
             BufferedWriter writer = new BufferedWriter(fw)) {
            System.out.println(resultlist.size());
            // NOTE(review): the "- 1" bound leaves the last entry unwritten; it is preserved
            // from the existing code because no caller is visible to confirm the intent.
            for (int i = 0; i < resultlist.size() - 1; i++) {
                ArrayList<Double> row = list.get(i);
                for (int j = 0; j < in_number; j++) {
                    writer.write(row.get(j) + ",");
                }
                writer.write(resultlist.get(i));
                writer.newLine();
            }
            writer.flush();
        }
    }

    /**
     * Scans a comma-separated file and records the maximum and minimum of every column
     * except the last one. Blank lines are skipped.
     *
     * @param filepath file to scan
     * @throws IOException if the file cannot be read
     * @throws NumberFormatException if a scanned column holds a non-numeric token
     */
    public void NormalizeData(String filepath) throws IOException {
        GetBeforIn(filepath);
        nom_data = new double[in_data_num][2];
        File file = new File(filepath);
        if (!file.isFile()) {
            return;
        }
        try (FileInputStream in = new FileInputStream(file);
             BufferedReader reader = new BufferedReader(new InputStreamReader(in, ENCODING))) {
            boolean firstLine = true;
            String line;
            while ((line = reader.readLine()) != null) {
                if (line.trim().isEmpty()) {
                    continue;
                }
                String[] tokens = line.split(",");
                for (int i = 0; i < tokens.length - 1; i++) {
                    double value = Double.parseDouble(tokens[i]);
                    // The first data line seeds both bounds; later lines only widen them.
                    if (firstLine || value > nom_data[i][0]) {
                        nom_data[i][0] = value;
                    }
                    if (firstLine || value < nom_data[i][1]) {
                        nom_data[i][1] = value;
                    }
                }
                firstLine = false;
            }
        }
    }

    /**
     * Determines the number of feature columns from the first non-blank line of the file
     * (token count minus one). A file without any data line yields zero.
     *
     * @param filepath file to inspect
     * @throws IOException if the file cannot be read
     */
    public void GetBeforIn(String filepath) throws IOException {
        File file = new File(filepath);
        if (!file.isFile()) {
            return;
        }
        try (FileInputStream in = new FileInputStream(file);
             BufferedReader reader = new BufferedReader(new InputStreamReader(in, ENCODING))) {
            int columns = 0;
            String line;
            while ((line = reader.readLine()) != null) {
                if (!line.trim().isEmpty()) {
                    columns = Math.max(0, line.split(",").length - 1);
                    break;
                }
            }
            in_data_num = columns;
        }
    }

    /**
     * Linearly maps {@code x} from {@code [min, max]} to {@code [0.1, 0.9]}.
     *
     * @param x   value to normalise
     * @param max upper end of the source range
     * @param min lower end of the source range
     * @return the mapped value; 0.5 (the middle of the target range) when {@code max == min},
     *         because the linear map would otherwise divide by zero
     */
    public double Normalize(double x, double max, double min) {
        double range = max - min;
        if (range == 0) {
            return 0.5;
        }
        return 0.1 + 0.8 * (x - min) / range;
    }
}
/**
 * Feed-forward network with one hidden layer and sigmoid activations, trained by
 * back-propagation with a weight update after every sample.
 *
 * <p>Each training row holds {@code in_num} input values followed by {@code out_num}
 * target values. Row 0, 1 and 2 of the {@code out} buffer hold the activations of the
 * input, hidden and output layer respectively; {@code delta} uses the same layout.
 */
class BPNN {
    /** Upper bound on the hidden-layer size. */
    private static final int MAX_HIDDEN_NODES = 10;
    /** Constant added to sqrt(in + out) when sizing the hidden layer. */
    private static final int ADJUST = 5;
    /** Maximum number of passes over the training set. */
    private static final int MaxTrain = 2000;
    /** Training stops once the mean error per sample drops below this value. */
    private static final double ACCU = 0.015;

    /** Single generator shared by all weight draws. */
    private final Random random = new Random();

    private double ETA_W = 0.5;
    private double ETA_T = 0.5;
    private double accu;
    private int in_num;
    private int hd_num;
    private int out_num;
    private ArrayList<ArrayList<Double>> list = new ArrayList<>();
    private double[][] in_hd_weight;
    private double[][] hd_out_weight;
    private double[] in_hd_th;
    private double[] hd_out_th;
    private double[][] out;
    private double[][] delta;

    /** @return the size of the widest layer */
    public int GetMaxNum() {
        return Math.max(Math.max(in_num, hd_num), out_num);
    }

    /** Resets the weight learning rate to 0.5. */
    public void SetEtaW() {
        ETA_W = 0.5;
    }

    /** Resets the threshold learning rate to 0.5. */
    public void SetEtaT() {
        ETA_T = 0.5;
    }

    /**
     * Initialises the network and trains it on {@code arraylist}, printing the mean error
     * after every pass. Stops after {@code MaxTrain} passes or once the error is below
     * {@code ACCU}.
     *
     * @param in_number  number of input values per row
     * @param out_number number of target values per row
     * @param arraylist  training rows (inputs followed by targets); kept by reference
     * @throws IOException declared for callers; the visible code does not raise it
     */
    public void Train(int in_number, int out_number, ArrayList<ArrayList<Double>> arraylist)
            throws IOException {
        list = arraylist;
        GetNums(in_number, out_number);
        InitNetWork();
        int width = GetMaxNum();
        out = new double[3][width];
        delta = new double[3][width];
        int sampleCount = list.size();
        for (int iter = 0; iter < MaxTrain; iter++) {
            for (int cnd = 0; cnd < sampleCount; cnd++) {
                loadInputs(list.get(cnd));
                Forward();
                Backward(cnd);
            }
            System.out.println("This is the " + (iter + 1) + " th trainning NetWork !");
            accu = GetAccu();
            System.out.println("All Samples Accuracy is " + accu);
            if (accu < ACCU) {
                break;
            }
        }
    }

    /**
     * Stores the layer sizes; the hidden layer gets {@code (int) sqrt(in + out) + ADJUST}
     * nodes, capped at {@code MAX_HIDDEN_NODES}.
     */
    public void GetNums(int in_number, int out_number) {
        in_num = in_number;
        out_num = out_number;
        hd_num = (int) Math.sqrt(in_num + out_num) + ADJUST;
        if (hd_num > MAX_HIDDEN_NODES) {
            hd_num = MAX_HIDDEN_NODES;
        }
    }

    /** Draws every weight uniformly from (-0.5, 0.5) and sets every threshold to zero. */
    public void InitNetWork() {
        in_hd_weight = new double[in_num][hd_num];
        for (int i = 0; i < in_num; i++) {
            for (int j = 0; j < hd_num; j++) {
                in_hd_weight[i][j] = randomWeight();
            }
        }
        hd_out_weight = new double[hd_num][out_num];
        for (int i = 0; i < hd_num; i++) {
            for (int j = 0; j < out_num; j++) {
                hd_out_weight[i][j] = randomWeight();
            }
        }
        // Freshly allocated arrays are already zero-filled.
        in_hd_th = new double[hd_num];
        hd_out_th = new double[out_num];
    }

    /** Returns a random magnitude in [0, 0.5) with a random sign. */
    private double randomWeight() {
        int sign = random.nextInt(2) == 1 ? 1 : -1;
        return (random.nextDouble() / 2) * sign;
    }

    /** Copies the first {@code in_num} values of {@code sample} into the input layer. */
    private void loadInputs(ArrayList<Double> sample) {
        for (int i = 0; i < in_num; i++) {
            out[0][i] = sample.get(i);
        }
    }

    /**
     * @param cnd index of a training row
     * @return half the squared distance between the current outputs and that row's targets
     */
    public double GetError(int cnd) {
        double ans = 0;
        for (int i = 0; i < out_num; i++) {
            double diff = out[2][i] - list.get(cnd).get(in_num + i);
            ans += 0.5 * diff * diff;
        }
        return ans;
    }

    /**
     * Runs every training row through the network and averages the half squared error.
     *
     * @return the mean error per row; NaN when there are no rows
     */
    public double GetAccu() {
        double ans = 0;
        int num = list.size();
        for (int i = 0; i < num; i++) {
            ArrayList<Double> sample = list.get(i);
            loadInputs(sample);
            Forward();
            for (int k = 0; k < out_num; k++) {
                double diff = sample.get(in_num + k) - out[2][k];
                ans += 0.5 * diff * diff;
            }
        }
        return ans / num;
    }

    /** Propagates the values in the input layer through the hidden layer to the output layer. */
    public void Forward() {
        for (int j = 0; j < hd_num; j++) {
            double v = 0;
            for (int i = 0; i < in_num; i++) {
                v += in_hd_weight[i][j] * out[0][i];
            }
            v += in_hd_th[j];
            out[1][j] = Sigmoid(v);
        }
        for (int j = 0; j < out_num; j++) {
            double v = 0;
            for (int i = 0; i < hd_num; i++) {
                v += hd_out_weight[i][j] * out[1][i];
            }
            v += hd_out_th[j];
            out[2][j] = Sigmoid(v);
        }
    }

    /** Computes the error terms for training row {@code cnd} and applies the update. */
    public void Backward(int cnd) {
        CalcDelta(cnd);
        UpdateNetWork();
    }

    /**
     * Fills {@code delta} for the output and hidden layer from the current activations and
     * the targets of training row {@code cnd}.
     */
    public void CalcDelta(int cnd) {
        int width = GetMaxNum();
        // Every cell read later is overwritten below, so the buffer can be reused safely.
        if (delta == null || delta[0].length != width) {
            delta = new double[3][width];
        }
        ArrayList<Double> sample = list.get(cnd);
        for (int i = 0; i < out_num; i++) {
            delta[2][i] = (sample.get(in_num + i) - out[2][i]) * SigmoidDerivative(out[2][i]);
        }
        for (int i = 0; i < hd_num; i++) {
            double t = 0;
            for (int j = 0; j < out_num; j++) {
                t += hd_out_weight[i][j] * delta[2][j];
            }
            delta[1][i] = t * SigmoidDerivative(out[1][i]);
        }
    }

    /** Applies one gradient step to all weights and thresholds using the current {@code delta}. */
    public void UpdateNetWork() {
        for (int i = 0; i < hd_num; i++) {
            for (int j = 0; j < out_num; j++) {
                hd_out_weight[i][j] += ETA_W * delta[2][j] * out[1][i];
            }
        }
        for (int i = 0; i < out_num; i++) {
            hd_out_th[i] += ETA_T * delta[2][i];
        }
        for (int i = 0; i < in_num; i++) {
            for (int j = 0; j < hd_num; j++) {
                in_hd_weight[i][j] += ETA_W * delta[1][j] * out[0][i];
            }
        }
        for (int i = 0; i < hd_num; i++) {
            in_hd_th[i] += ETA_T * delta[1][i];
        }
    }

    /** @return 1, -1 or 0 for a positive, negative or zero {@code x} */
    public int Sign(double x) {
        if (x > 0) {
            return 1;
        }
        if (x < 0) {
            return -1;
        }
        return 0;
    }

    /** @return {@code x} if {@code x >= y}, otherwise {@code y} */
    public double Maximum(double x, double y) {
        return x >= y ? x : y;
    }

    /** @return {@code x} if {@code x <= y}, otherwise {@code y} */
    public double Minimum(double x, double y) {
        return x <= y ? x : y;
    }

    /** @return the logistic function {@code 1 / (1 + e^-x)} */
    public double Sigmoid(double x) {
        return 1 / (1 + Math.exp(-x));
    }

    /** @return the logistic derivative expressed through the activation {@code y}: {@code y * (1 - y)} */
    public double SigmoidDerivative(double y) {
        return y * (1 - y);
    }

    /**
     * @return {@code (1 - e^-x) / (1 + e^-x)}, computed as {@code tanh(x / 2)} so that large
     *         negative {@code x} yields -1 instead of the NaN produced by Infinity / Infinity
     */
    public double TSigmoid(double x) {
        return Math.tanh(x / 2);
    }

    /**
     * @return {@code 1 - y * y}
     */
    // NOTE(review): the derivative of tanh(x / 2) with respect to x is (1 - y * y) / 2;
    // confirm whether the missing factor is intended before using this with TSigmoid.
    public double TSigmoidDerivative(double y) {
        return 1 - (y * y);
    }

    /**
     * Classifies every row of {@code arraylist}.
     *
     * @param arraylist rows whose first {@code in_num} values are the network inputs
     * @return one list per row holding the output-layer values, where a value strictly
     *         between 0 and 0.5 becomes 0, a value strictly between 0.5 and 1 becomes 1,
     *         and any other value is passed through unchanged
     */
    public ArrayList<ArrayList<Double>> ForeCast(ArrayList<ArrayList<Double>> arraylist) {
        ArrayList<ArrayList<Double>> alloutlist = new ArrayList<>(arraylist.size());
        for (ArrayList<Double> sample : arraylist) {
            loadInputs(sample);
            Forward();
            ArrayList<Double> outlist = new ArrayList<>(out_num);
            for (int i = 0; i < out_num; i++) {
                double value = out[2][i];
                if (value > 0 && value < 0.5) {
                    value = 0;
                } else if (value > 0.5 && value < 1) {
                    value = 1;
                }
                out[2][i] = value;
                outlist.add(value);
            }
            alloutlist.add(outlist);
        }
        return alloutlist;
    }
}