Skip to content
This repository was archived by the owner on Aug 15, 2019. It is now read-only.

Commit 988ae5c

Browse files
LewuatheNikhil Thorat
authored andcommitted
Hyper parameters of optimizer can be set from UI (#90)
* Hyper parameters of optimizer can be set from UI - learning rate - momentum - batch size They are configurable from UI.
1 parent 34d7c8c commit 988ae5c

File tree

2 files changed

+28
-8
lines changed

2 files changed

+28
-8
lines changed

demos/model-builder/model-builder.html

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -239,6 +239,17 @@
239239
</template>
240240
</paper-listbox>
241241
</paper-dropdown-menu>
242+
243+
<div class="subtitle">Hyperparameters</div>
244+
<paper-input no-animations label="Learning Rate" id="learning-rate-input" disabled="[[!datasetDownloaded]]" value={{learningRate}}>
245+
</paper-input>
246+
247+
<paper-input no-animations label="Momentum" id="momentum" disabled="[[!datasetDownloaded]]" value={{momentum}}>
248+
</paper-input>
249+
250+
<paper-input no-animations label="Batch Size" id="batch-size" disabled="[[!datasetDownloaded]]" value={{batchSize}}>
251+
</paper-input>
252+
242253
<div hidden$="[[isValid]]" class="model-error">
243254
<div hidden$="[[!datasetDownloaded]]"">
244255
<paper-tooltip animation-delay="0" fit-to-visible-bounds>

demos/model-builder/model-builder.ts

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -31,10 +31,6 @@ import {Normalization} from './tensorflow';
3131

3232
const DATASETS_CONFIG_JSON = 'model-builder-datasets-config.json';
3333

34-
// TODO(nsthorat): Make these parameters in the UI.
35-
const BATCH_SIZE = 64;
36-
const LEARNING_RATE = 0.1;
37-
const MOMENTUM = 0.1;
3834
/** How often to evaluate the model against test data. */
3935
const EVAL_INTERVAL_MS = 1500;
4036
/** How often to compute the cost. Downloading the cost stalls the GPU. */
@@ -74,6 +70,9 @@ export let ModelBuilderPolymer: new () => PolymerHTMLElement = PolymerElement({
7470
datasetNames: Array,
7571
selectedDatasetName: String,
7672
modelNames: Array,
73+
learningRate: Number,
74+
momentum: Number,
75+
batchSize: Number,
7776
selectedModelName: String,
7877
selectedNormalizationOption:
7978
{type: Number, value: Normalization.NORMALIZATION_NEGATIVE_ONE_TO_ONE},
@@ -122,6 +121,9 @@ export class ModelBuilder extends ModelBuilderPolymer {
122121
private dataSet: InMemoryDataset;
123122
private xhrDatasetConfigs: {[datasetName: string]: XhrDatasetConfig};
124123
private datasetStats: DataStats[];
124+
private learingRate: number;
125+
private momentum: number;
126+
private batchSize: number;
125127

126128
// Stats.
127129
private showDatasetStats: boolean;
@@ -183,7 +185,7 @@ export class ModelBuilder extends ModelBuilderPolymer {
183185
totalTimeSec.toFixed(1),
184186
};
185187
this.graphRunner = new GraphRunner(this.math, this.session, eventObserver);
186-
this.optimizer = new MomentumOptimizer(LEARNING_RATE, MOMENTUM);
188+
this.optimizer = new MomentumOptimizer(this.learingRate, this.momentum);
187189

188190
// Set up datasets.
189191
this.populateDatasets();
@@ -218,6 +220,9 @@ export class ModelBuilder extends ModelBuilderPolymer {
218220
this.setupDatasetStats();
219221
});
220222
}
223+
this.learningRate = 0.1;
224+
this.momentum = 0.1;
225+
this.batchSize = 64;
221226

222227
this.applicationState = ApplicationState.IDLE;
223228
this.loadedWeights = null;
@@ -318,6 +323,9 @@ export class ModelBuilder extends ModelBuilderPolymer {
318323
const trainingData = this.getTrainingData();
319324
const testData = this.getTestData();
320325

326+
// Recreate optimizer with the latest learning rate.
327+
this.optimizer = new MomentumOptimizer(+this.learningRate, +this.momentum);
328+
321329
if (this.isValid && (trainingData != null) && (testData != null)) {
322330
this.recreateCharts();
323331
this.graphRunner.resetStatistics();
@@ -343,9 +351,10 @@ export class ModelBuilder extends ModelBuilderPolymer {
343351
];
344352

345353
this.graphRunner.train(
346-
this.costTensor, trainFeeds, BATCH_SIZE, this.optimizer,
354+
this.costTensor, trainFeeds, this.batchSize, this.optimizer,
347355
undefined /** numBatches */, this.accuracyTensor, accuracyFeeds,
348-
BATCH_SIZE, MetricReduction.MEAN, EVAL_INTERVAL_MS, COST_INTERVAL_MS);
356+
this.batchSize, MetricReduction.MEAN, EVAL_INTERVAL_MS,
357+
COST_INTERVAL_MS);
349358

350359
this.showTrainStats = true;
351360
this.applicationState = ApplicationState.TRAINING;
@@ -628,7 +637,7 @@ export class ModelBuilder extends ModelBuilderPolymer {
628637
}
629638

630639
displayBatchesTrained(totalBatchesTrained: number) {
631-
this.examplesTrained = BATCH_SIZE * totalBatchesTrained;
640+
this.examplesTrained = this.batchSize * totalBatchesTrained;
632641
}
633642

634643
displayCost(avgCost: Scalar) {

0 commit comments

Comments
 (0)