神经网络不会举一反三，郁闷。。。
            1 //////////////////////////////////////////////////////////////////////////
            2 // File Name: pnn.cpp
            3 // Author:    Ruoruo(du#in.tum.de)
            4 //////////////////////////////////////////////////////////////////////////
            5 #include "stdafx.h"
            6 #include "cv.h"
            7 #include "highgui.h"
            8 #include <ml.h>
            9 #include <time.h>
           10 #include <ctype.h>
           11 #include <vector>
           12 #include <math.h> 
           13 #include <iostream>
           14 using namespace std;
           15 
           16 static CvScalar colors[] = 
           17     {
           18         {{0,0,255}},
           19         {{0,128,255}},
           20         {{0,255,255}},
           21         {{0,255,0}},
           22         {{255,128,0}},
           23         {{255,255,0}},
           24         {{255,0,0}},
           25         {{255,0,255}}
           26     };
           27 
           28 int main( int argc, char** argv )
           29 {
           30     vector<float> point;
           31     vector<float> result;
           32 
           33     float p[10= { 1.32.7,
           34                     1.53.0,
           35                     1.72.8,
           36                     1.62.6,
           37                     1.22.9 };
           38     float res[5= { 0,0,0,0,0 };
           39     int i;
           40     for(i=0; i<10; i++)
           41     {
           42         point.push_back(p[i]);
           43         if(i<5) result.push_back(res[i]);
           44     }
           45 
           46     CvMat* input = cvCreateMat( 52, CV_32FC1 );
           47     cvInitMatHeader( input, 52, CV_32FC1, p );
           48     CvMat* output = cvCreateMat( 51, CV_32FC1 );
           49     cvInitMatHeader( output, 51, CV_32FC1, res );
           50     IplImage* img = cvCreateImage(cvSize(450450), IPL_DEPTH_8U, 3);
           51     img->origin = 1;
           52     for(i= 0; i<5; i++)
           53     {
           54         cvCircle(img, cvPoint((int)(p[i*2]*100), (int)(p[i*2+1]* 100)), 5, colors[(int)res[i]%8], 1, CV_AA, 0);
           55     }
           56     
           57     int layer_num[3= { 241 };
           58     CvMat* layer_size = cvCreateMatHeader( 13, CV_32S );
           59     cvInitMatHeader( layer_size, 13, CV_32S, layer_num );
           60     CvANN_MLP pnn;
           61     pnn.create( layer_size, CvANN_MLP::SIGMOID_SYM, 11 );
           62     CvANN_MLP_TrainParams params;
           63     params.term_crit = cvTermCriteria( CV_TERMCRIT_ITER | CV_TERMCRIT_EPS, \
           64         3000.0000001 );
           65     params.train_method = 0;
           66     params.bp_dw_scale = 0.1;
           67     params.bp_moment_scale = 0.1;
           68     cout<<"begin training"<<endl;
           69     pnn.train( input, output, 00params );
           70     cout<<"end training"<<endl;
           71     pnn.save( "pNN_DATA.xml" );
           72 
           73     //begin to test
           74     float testp[24= { 1.42.75,
           75                         4.30.2
           76                         4.24.3,
           77                         1.452.85,
           78                         4.20.4,
           79                         4.14.0,
           80                         4.30.5,
           81                         4.04.2,
           82                         1.52.7
           83                         4.14.2
           84                         1.62.7,
           85                         4.00.3 };
           86     /*float testp[24] = { 1.4, 2.75,
           87                         4.3, 0.2, 
           88                         4.2, 4.3,
           89                         1.45, 2.85,
           90                         4.25, 0.3,
           91                         4.25, 4.25,
           92                         3.5, 1.2,
           93                         3.0, 3.7,
           94                         4.0, 2.7, 
           95                         0.2, 0.2, 
           96                         2.8, 2.7,
           97                         2.7, 2.8 };*/
           98     CvMat* test_point = cvCreateMat( 12, CV_32FC1 );    
           99     CvMat* test_result = cvCreateMat( 11, CV_32FC1 );
          100     CvFont font;
          101     double hScale=0.5;
          102     double vScale=0.5;
          103     int lineWidth=1;
          104     cvInitFont(&font, CV_FONT_HERSHEY_COMPLEX|CV_FONT_ITALIC, hScale,vScale,0,lineWidth);
          105 
          106     for(i= 0; i<12; i++)
          107     {
          108         cvSetReal2D( test_point, 00, testp[2*i] );
          109         cvSetReal2D( test_point, 01, testp[2*i+1] );
          110         pnn.predict(test_point, test_result);
          111         cout<<cvmGet(test_result,0,0)<<endl;
          112 
          113         float delta = 1;
          114         int best_class = 0;
          115         int max_class = 0;
          116         for(int ii=0; ii<result.size(); ii++){
          117             if(fabs(cvmGet(test_result,0,0- (float)result[ii])<delta){
          118                 delta = fabs(cvmGet(test_result,0,0- (float)result[ii]);
          119                 best_class = result[ii];
          120             }
          121             if(result[ii]>=max_class)
          122                 max_class = result[ii];
          123         }
          124 
          125         point.push_back(testp[2*i]);
          126         point.push_back(testp[2*i+1]);
          127 
          128         if( delta>0.06 ){
          129             int new_result = max_class+1;
          130             cvmSet( test_result,0,0,new_result );
          131             result.push_back((float)new_result );
          132         }
          133         else{
          134             cvmSet( test_result,0,0,best_class );
          135             result.push_back((float)best_class );
          136         }
          137 
          138         int new_point_size = point.size();
          139         int new_result_size = result.size();
          140             
          141         CvMat* input = cvCreateMat( new_result_size, 2, CV_32FC1 );
          142         CvMat* output = cvCreateMat( new_result_size, 1, CV_32FC1 );
          143 
          144         for(int ii=0; ii<new_result_size; ii++)
          145         {
          146             cvmSet( input, ii, 0, point[2*ii]);
          147             cvmSet( input, ii, 1, point[2*ii+1]);
          148             cvmSet( output, ii, 0, result[ii]);
          149         }
          150         //cout<<"begin training again"<<endl;
          151         pnn.train( input, output, 00params );
          152         //cout<<"end training"<<endl;
          153 
          154         cvCircle( img, cvPoint((int)(testp[i*2]*100), (int)(testp[i*2+1]* 100)), 0, colors[(int)cvmGet(test_result,0,0)%8], 10, CV_AA, 0 );
          155 
          156         char buffer[10];
          157         _itoa(i,buffer,10);
          158         string point_id(buffer);
          159         cvPutText(img, point_id.c_str(), cvPoint(testp[2*i]*100,testp[2*i+1]*100), &font, cvScalar(255,255,255));
          160 
          161         cout<<i<<""<<"("<<testp[i*2]<<""<<testp[i*2+1]<<")"<<"\t"<<cvmGet(test_result,0,0)<<endl;
          162     }
          163 
          164     cvNamedWindow( "Coordinates" , 1 ); 
          165     cvShowImage( "Coordinates" ,img);
          166 
          167     cvWaitKey( 0 );
          168 
          169     cvDestroyWindow("Coordinates");
          170     cvReleaseImage(&img);
          171 
          172     return 0;
          173 }

          只有注册用户登录后才能发表评论。


          网站导航:
           

          posts - 9, comments - 0, trackbacks - 0, articles - 0

          Copyright © 近似凱珊卓

          主站蜘蛛池模板: 和静县| 康乐县| 广东省| 佳木斯市| 尤溪县| 镇坪县| 遵义市| 平塘县| 吉林市| 高密市| 宜兰市| 如东县| 左权县| 泰和县| 崇礼县| 皮山县| 洛扎县| 青海省| 东兰县| 临泉县| 轮台县| 安乡县| 长宁区| 得荣县| 安庆市| 石河子市| 石嘴山市| 高要市| 红桥区| 广水市| 大同县| 滕州市| 东源县| 清原| 平和县| 监利县| 博爱县| 青神县| 永昌县| 吴忠市| 特克斯县|