44import android .graphics .Canvas ;
55import android .graphics .Color ;
66import android .graphics .Paint ;
7- import android .os .AsyncTask ;
87import android .util .AttributeSet ;
98import android .view .View ;
109
1110import androidx .annotation .Nullable ;
1211
13- import org .deeplearning4j .nn .conf .MultiLayerConfiguration ;
14- import org .deeplearning4j .nn .conf .NeuralNetConfiguration ;
15- import org .deeplearning4j .nn .conf .layers .DenseLayer ;
16- import org .deeplearning4j .nn .conf .layers .OutputLayer ;
17- import org .deeplearning4j .nn .multilayer .MultiLayerNetwork ;
18- import org .deeplearning4j .nn .weights .WeightInit ;
19- import org .deeplearning4j .optimize .listeners .ScoreIterationListener ;
20- import org .nd4j .evaluation .classification .Evaluation ;
21- import org .nd4j .linalg .activations .Activation ;
22- import org .nd4j .linalg .api .buffer .DataType ;
2312import org .nd4j .linalg .api .ndarray .INDArray ;
24- import org .nd4j .linalg .dataset .DataSet ;
25- import org .nd4j .linalg .factory .Nd4j ;
26- import org .nd4j .linalg .learning .config .Nesterovs ;
27- import org .nd4j .linalg .lossfunctions .LossFunctions ;
2813
29- import java .io .BufferedReader ;
30- import java .io .IOException ;
31- import java .io .InputStreamReader ;
32- import java .util .ArrayList ;
33-
34- public class ScatterView extends View {
14+ public class ScatterView extends View implements OnTrainingUpdateEventListener {
3515
3616 private final Paint redPaint ;
3717 private final Paint greenPaint ;
3818 private final Paint lightGreenPaint ;
3919 private final Paint lightRedPaint ;
40- private float [][] data ;
41- private DataSet ds ;
4220
43- private final int nPointsPerAxis = 100 ;
44- private INDArray xyGrid ; //x,y grid to calculate the output image. Needs to be calculated once, then re-used.
45- private INDArray modelOut = null ;
21+
22+ private INDArray modelOut = null ; // nn output for the grid.
23+
24+ private TrainingTask task ;
4625
4726 public ScatterView (Context context , @ Nullable AttributeSet attrs ) {
4827 super (context , attrs );
49- data = null ;
5028 redPaint = new Paint ();
5129 redPaint .setColor (Color .RED );
5230 greenPaint = new Paint ();
@@ -57,16 +35,9 @@ public ScatterView(Context context, @Nullable AttributeSet attrs) {
5735 lightRedPaint = new Paint ();
5836 lightRedPaint .setColor (Color .rgb (255 , 153 , 152 ));
5937
60- AsyncTask .execute (() -> {
61- try {
62- calcGrid ();
63- ReadCSV ();
64- BuildNN ();
65-
66- } catch (IOException e ) {
67- e .printStackTrace ();
68- }
69- });
38+ task = new TrainingTask ();
39+ task .setListener (this );
40+ task .executeTask ();
7041 }
7142
7243 @ Override
@@ -75,9 +46,10 @@ public void onDraw(Canvas canvas) {
7546 int w = this .getWidth ();
7647
7748 //draw the nn predictions:
78- if ((modelOut != null ) && (null != xyGrid )){
79- int halfRectHeight = h / nPointsPerAxis ;
80- int halfRectWidth = w / nPointsPerAxis ;
49+ if (modelOut != null ) {
50+ int halfRectHeight = h / task .getnPointsPerAxis ();
51+ int halfRectWidth = w / task .getnPointsPerAxis ();
52+ INDArray xyGrid = task .getXyGrid ();
8153 int nRows = xyGrid .rows ();
8254
8355 for (int i = 0 ; i < nRows ; i ++){
@@ -91,9 +63,9 @@ public void onDraw(Canvas canvas) {
9163 }
9264
9365 //draw the data set if we have one.
94- if (null != data ) {
66+ if (null != task . getData () ) {
9567
96- for (float [] datum : data ) {
68+ for (float [] datum : task . getData () ) {
9769 int x = (int ) (datum [1 ] * w );
9870 int y = (int ) (datum [2 ] * h );
9971 Paint p = (datum [0 ] == 0.0f ) ? redPaint : greenPaint ;
@@ -102,124 +74,11 @@ public void onDraw(Canvas canvas) {
10274 }
10375 }
10476
105- /**
106- * this is not the regular way to read a csv file into a data set with DL4j.
107- * In this example we have put the data in the assets folder so that the demo works offline.
108- */
109- private void ReadCSV () throws IOException {
110- InputStreamReader is = new InputStreamReader (MainActivity .getInstance ().getApplicationContext ().getAssets ()
111- .open ("linear_data_train.csv" ));
112-
113- BufferedReader reader = new BufferedReader (is );
114- ArrayList <String > rawSVC = new ArrayList <>();
115- String line ;
116- while ((line = reader .readLine ()) != null ) {
117- rawSVC .add (line );
118- }
119-
120- float [][] tmpData = new float [rawSVC .size ()][3 ];
121-
122- int index = 0 ;
123- for (String l : rawSVC ){
124- String [] values = l .split ("," );
125- for (int col = 0 ; col < 3L ; col ++){
126- tmpData [index ][col ] = Float .parseFloat (values [col ]);
127- }
128-
129- index ++;
130- }
131-
132- normalizeColumn (1 , tmpData );
133- normalizeColumn (2 , tmpData );
134-
135- this .data = tmpData ;
136- INDArray arrData = Nd4j .createFromArray (tmpData );
137- INDArray arrFeatures = arrData .getColumns (1 , 2 );
138- INDArray c1 = arrData .getColumns (0 );
139- INDArray c2 = c1 .mul (-1 ).addi (1.0 );
140- INDArray labels = Nd4j .hstack (c1 , c2 );
141- ds = new DataSet (arrFeatures , labels );
142- }
143-
144- /**
145- * Normalize the data in a given column. Normally one would use datavec.
146- * @param c column to normalise.
147- * @param tmpData java float array.
148- */
149- private void normalizeColumn (int c , float [][] tmpData ){
150- int numPoints = tmpData .length ;
151- float min = tmpData [0 ][c ];
152- float max = tmpData [0 ][c ];
153- for (float [] tmpDatum : tmpData ) {
154- float x = tmpDatum [c ];
155- if (x < min ) {
156- min = x ;
157- }
158- if (x > max ) {
159- max = x ;
160- }
161- }
162-
163- for (int i =0 ; i <numPoints ; i ++){
164- float x = tmpData [i ][c ];
165- tmpData [i ][c ] = (x - min ) / (max - min );
166- }
167- }
168-
169- private void BuildNN (){
170- int seed = 123 ;
171- double learningRate = 0.005 ;
172- int numInputs = 2 ;
173- int numOutputs = 2 ;
174- int numHiddenNodes = 20 ;
175- int nEpochs = 2000 ;
176-
177- MultiLayerConfiguration conf = new NeuralNetConfiguration .Builder ()
178- .seed (seed )
179- .weightInit (WeightInit .XAVIER )
180- .updater (new Nesterovs (learningRate , 0.9 ))
181- .list ()
182- .layer (new DenseLayer .Builder ().nIn (numInputs ).nOut (numHiddenNodes )
183- .activation (Activation .RELU )
184- .build ())
185- .layer (new OutputLayer .Builder (LossFunctions .LossFunction .NEGATIVELOGLIKELIHOOD )
186- .activation (Activation .SOFTMAX )
187- .nIn (numHiddenNodes ).nOut (numOutputs ).build ())
188- .build ();
189-
190- MultiLayerNetwork model = new MultiLayerNetwork (conf );
191- model .init ();
192- model .setListeners (new ScoreIterationListener (10 ));
193-
194- for (int i = 0 ; i <nEpochs ; i ++){
195- model .fit (ds );
196- INDArray tmp = model .output (xyGrid );
197-
198- this .post (() -> {
199- this .modelOut = tmp ; // update from within the UI thread.
200- this .invalidate (); // have the UI thread redraw at its own convenience.
201- });
202- }
203-
204- Evaluation eval = new Evaluation (numOutputs );
205- INDArray features = ds .getFeatures ();
206- INDArray labels = ds .getLabels ();
207- INDArray predicted = model .output (features ,false );
208- eval .eval (labels , predicted );
209- System .out .println (eval .stats ());
210-
211- this .invalidate ();
212- }
213- /**
214- * The x,y grid to calculate the NN output. Only needs to be calculated once.
215- */
216- private void calcGrid (){
217- // x coordinates of the pixels for the NN.
218- INDArray xPixels = Nd4j .linspace (0 , 1.0 , nPointsPerAxis , DataType .DOUBLE );
219- // y coordinates of the pixels for the NN.
220- INDArray yPixels = Nd4j .linspace (0 , 1.0 , nPointsPerAxis , DataType .DOUBLE );
221- //create the mesh:
222- INDArray [] mesh = Nd4j .meshgrid (xPixels , yPixels );
223- xyGrid = Nd4j .vstack (mesh [0 ].ravel (), mesh [1 ].ravel ()).transpose ();
77+ @ Override
78+ public void OnTrainingUpdate (INDArray modelOut ) {
79+ this .post (() -> {
80+ this .modelOut = modelOut ; // update from within the UI thread.
81+ this .invalidate (); // have the UI thread redraw at its own convenience.
82+ });
22483 }
22584}
0 commit comments