@@ -31,10 +31,6 @@ import {Normalization} from './tensorflow';
3131
3232const DATASETS_CONFIG_JSON = 'model-builder-datasets-config.json' ;
3333
34- // TODO(nsthorat): Make these parameters in the UI.
35- const BATCH_SIZE = 64 ;
36- const LEARNING_RATE = 0.1 ;
37- const MOMENTUM = 0.1 ;
3834/** How often to evaluate the model against test data. */
3935const EVAL_INTERVAL_MS = 1500 ;
4036/** How often to compute the cost. Downloading the cost stalls the GPU. */
@@ -74,6 +70,9 @@ export let ModelBuilderPolymer: new () => PolymerHTMLElement = PolymerElement({
7470 datasetNames : Array ,
7571 selectedDatasetName : String ,
7672 modelNames : Array ,
73+ learningRate : Number ,
74+ momentum : Number ,
75+ batchSize : Number ,
7776 selectedModelName : String ,
7877 selectedNormalizationOption :
7978 { type : Number , value : Normalization . NORMALIZATION_NEGATIVE_ONE_TO_ONE } ,
@@ -122,6 +121,9 @@ export class ModelBuilder extends ModelBuilderPolymer {
122121 private dataSet : InMemoryDataset ;
123122 private xhrDatasetConfigs : { [ datasetName : string ] : XhrDatasetConfig } ;
124123 private datasetStats : DataStats [ ] ;
124+ private learingRate : number ;
125+ private momentum : number ;
126+ private batchSize : number ;
125127
126128 // Stats.
127129 private showDatasetStats : boolean ;
@@ -183,7 +185,7 @@ export class ModelBuilder extends ModelBuilderPolymer {
183185 totalTimeSec . toFixed ( 1 ) ,
184186 } ;
185187 this . graphRunner = new GraphRunner ( this . math , this . session , eventObserver ) ;
186- this . optimizer = new MomentumOptimizer ( LEARNING_RATE , MOMENTUM ) ;
188+ this . optimizer = new MomentumOptimizer ( this . learingRate , this . momentum ) ;
187189
188190 // Set up datasets.
189191 this . populateDatasets ( ) ;
@@ -218,6 +220,9 @@ export class ModelBuilder extends ModelBuilderPolymer {
218220 this . setupDatasetStats ( ) ;
219221 } ) ;
220222 }
223+ this . learningRate = 0.1 ;
224+ this . momentum = 0.1 ;
225+ this . batchSize = 64 ;
221226
222227 this . applicationState = ApplicationState . IDLE ;
223228 this . loadedWeights = null ;
@@ -318,6 +323,9 @@ export class ModelBuilder extends ModelBuilderPolymer {
318323 const trainingData = this . getTrainingData ( ) ;
319324 const testData = this . getTestData ( ) ;
320325
326+ // Recreate optimizer with the latest learning rate.
327+ this . optimizer = new MomentumOptimizer ( + this . learningRate , + this . momentum ) ;
328+
321329 if ( this . isValid && ( trainingData != null ) && ( testData != null ) ) {
322330 this . recreateCharts ( ) ;
323331 this . graphRunner . resetStatistics ( ) ;
@@ -343,9 +351,10 @@ export class ModelBuilder extends ModelBuilderPolymer {
343351 ] ;
344352
345353 this . graphRunner . train (
346- this . costTensor , trainFeeds , BATCH_SIZE , this . optimizer ,
354+ this . costTensor , trainFeeds , this . batchSize , this . optimizer ,
347355 undefined /** numBatches */ , this . accuracyTensor , accuracyFeeds ,
348- BATCH_SIZE , MetricReduction . MEAN , EVAL_INTERVAL_MS , COST_INTERVAL_MS ) ;
356+ this . batchSize , MetricReduction . MEAN , EVAL_INTERVAL_MS ,
357+ COST_INTERVAL_MS ) ;
349358
350359 this . showTrainStats = true ;
351360 this . applicationState = ApplicationState . TRAINING ;
@@ -628,7 +637,7 @@ export class ModelBuilder extends ModelBuilderPolymer {
628637 }
629638
630639 displayBatchesTrained ( totalBatchesTrained : number ) {
631- this . examplesTrained = BATCH_SIZE * totalBatchesTrained ;
640+ this . examplesTrained = this . batchSize * totalBatchesTrained ;
632641 }
633642
634643 displayCost ( avgCost : Scalar ) {
0 commit comments