This repository was archived by the owner on Aug 15, 2019. It is now read-only.

Commit d52d023

manrajgrover authored and Nikhil Thorat committed
Add Adam Optimizer with eager mode (#689)
1 parent 8326870 commit d52d023

File tree

5 files changed, +308 -15 lines changed
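In eager mode, the new optimizer is reached through the dl.train.adam factory added in src/optimizers/optimizer_constructors.ts and driven with Optimizer.minimize. Below is a minimal usage sketch based on the 'basic' test in this commit; the package import name is an assumption (inside this repo the test imports from '../index').

import * as dl from 'deeplearn';  // assumed package name; in-repo code imports from '../index'

// Hyperparameters used by the 'basic' test in this commit.
const learningRate = 0.1;
const beta1 = 0.8;
const beta2 = 0.9;
const optimizer = dl.train.adam(learningRate, beta1, beta2);

// A trainable variable and a loss that depends on it.
const x = dl.variable(dl.tensor1d([2, 4]));
const loss = () => x.square().sum() as dl.Scalar;

// One eager optimization step; with returnCost = true the loss value is returned.
const cost = optimizer.minimize(loss, /* returnCost */ true);

console.log(x.dataSync());  // after one step, approximately [1.9, 3.9]

// Release the tensors held by the cost, the variable, and the optimizer state.
cost.dispose();
x.dispose();
optimizer.dispose();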

src/index.ts

Lines changed: 1 addition & 1 deletion
@@ -38,7 +38,6 @@ export {Graph, SymbolicTensor} from './graph/graph';
 export {GraphRunner, GraphRunnerEventObserver, MetricReduction} from './graph/graph_runner';
 // tslint:disable-next-line:max-line-length
 export {ConstantInitializer, Initializer, OnesInitializer, RandomNormalInitializer, RandomTruncatedNormalInitializer, RandomUniformInitializer, TensorInitializer, VarianceScalingInitializer, ZerosInitializer} from './graph/initializers';
-export {AdamOptimizer} from './graph/optimizers/adam_optimizer';
 export {AdamaxOptimizer} from './graph/optimizers/adamax_optimizer';
 export {CostReduction, FeedEntry, Session} from './graph/session';
 export {MathBackendCPU, NDArrayMathCPU} from './kernels/backend_cpu';
@@ -50,6 +49,7 @@ export {Model} from './model';
 export {LSTMCell} from './ops/lstm';
 export {AdadeltaOptimizer} from './optimizers/adadelta_optimizer';
 export {AdagradOptimizer} from './optimizers/adagrad_optimizer';
+export {AdamOptimizer} from './optimizers/adam_optimizer';
 export {MomentumOptimizer} from './optimizers/momentum_optimizer';
 export {Optimizer} from './optimizers/optimizer';
 export {RMSPropOptimizer} from './optimizers/rmsprop_optimizer';

src/optimizers/adam_optimizer.ts

Lines changed: 211 additions & 0 deletions
New file contents:

/**
 * @license
 * Copyright 2018 Google Inc. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */

import {ENV} from '../environment';
import {keep, tidy} from '../globals';
import {Node} from '../graph/graph';
import {SessionRuntime} from '../graph/session';
// tslint:disable-next-line:max-line-length
import {SummedTensorArrayMap, TensorArrayMap} from '../graph/tensor_array_map';
import {NDArrayMath} from '../math';
import {scalar, zerosLike} from '../ops/ops';
import {Scalar, Tensor, Variable} from '../tensor';
import {variable} from '../tensor';
import {NamedVariableMap} from '../types';

import {Optimizer} from './optimizer';

export class AdamOptimizer extends Optimizer {
  private c: Scalar;
  private eps: Scalar;
  private beta1: Scalar;
  private beta2: Scalar;
  private accBeta1: Variable;
  private accBeta2: Variable;
  private oneMinusBeta1: Scalar;
  private oneMinusBeta2: Scalar;
  private one: Scalar;

  private accumulatedFirstMoment: NamedVariableMap = {};
  private accumulatedSecondMoment: NamedVariableMap = {};

  constructor(
      protected learningRate: number, beta1: number, beta2: number,
      epsilon = 1e-8, specifiedVariableList?: Node[]) {
    super(learningRate, specifiedVariableList);
    this.c = keep(scalar(-learningRate));
    this.eps = keep(scalar(epsilon));
    // b1, b2 keep initial value of beta* hyperparameters.
    this.beta1 = keep(scalar(beta1));
    this.beta2 = keep(scalar(beta2));
    // accB* will be updated by batch.
    this.accBeta1 = variable(scalar(beta1));
    this.accBeta2 = variable(scalar(beta2));
    this.oneMinusBeta1 = keep(scalar(1 - beta1));
    this.oneMinusBeta2 = keep(scalar(1 - beta2));
    this.one = keep(scalar(1));
  }

  applyGradients(variableGradients: NamedVariableMap) {
    tidy(() => {
      const oneMinusAccBeta1 = this.one.sub(this.accBeta1);
      const oneMinusAccBeta2 = this.one.sub(this.accBeta2);

      for (const variableName in variableGradients) {
        const value = ENV.engine.registeredVariables[variableName];
        if (this.accumulatedFirstMoment[variableName] == null) {
          const trainable = false;
          this.accumulatedFirstMoment[variableName] =
              variable(zerosLike(value), trainable);
        }
        if (this.accumulatedSecondMoment[variableName] == null) {
          const trainable = false;
          this.accumulatedSecondMoment[variableName] =
              variable(zerosLike(value), trainable);
        }

        const gradient = variableGradients[variableName];
        const firstMoment = this.accumulatedFirstMoment[variableName];
        const secondMoment = this.accumulatedSecondMoment[variableName];

        const newFirstMoment =
            this.beta1.mul(firstMoment).add(this.oneMinusBeta1.mul(gradient));
        const newSecondMoment =
            this.beta2.mul(secondMoment)
                .add(this.oneMinusBeta2.mul(gradient.square()));

        const biasCorrectedFirstMoment = newFirstMoment.div(oneMinusAccBeta1);
        const biasCorrectedSecondMoment = newSecondMoment.div(oneMinusAccBeta2);

        this.accumulatedFirstMoment[variableName].assign(newFirstMoment);
        this.accumulatedSecondMoment[variableName].assign(newSecondMoment);

        const newValue = this.c
                             .mul(biasCorrectedFirstMoment.div(this.eps.add(
                                 biasCorrectedSecondMoment.sqrt())))
                             .add(value);
        value.assign(newValue);
      }

      this.accBeta1.assign(this.accBeta1.mul(this.beta1));
      this.accBeta2.assign(this.accBeta2.mul(this.beta2));
    });
  }

  beforeBatch(
      math: NDArrayMath, batchSize: number, runtime: SessionRuntime,
      activationArrayMap: TensorArrayMap,
      gradientArrayMap: SummedTensorArrayMap) {
    super.beforeBatch(
        math, batchSize, runtime, activationArrayMap, gradientArrayMap);

    if (this.firstMomentGraph.size() === 0) {
      this.variableNodes.forEach(node => {
        this.firstMomentGraph.set(node.output, Tensor.zeros(node.output.shape));
      });
    }

    if (this.secondMomentGraph.size() === 0) {
      this.variableNodes.forEach(node => {
        this.secondMomentGraph.set(
            node.output, Tensor.zeros(node.output.shape));
      });
    }
  }

  afterBatch(
      math: NDArrayMath, batchSize: number, runtime: SessionRuntime,
      activationArrayMap: TensorArrayMap,
      gradientArrayMap: SummedTensorArrayMap) {
    tidy(() => {
      const oneMinusAccBeta1 = this.one.sub(this.accBeta1);
      const oneMinusAccBeta2 = this.one.sub(this.accBeta2);

      this.variableNodes.forEach(node => {
        const oldVariable = activationArrayMap.get(node.output);
        const gradient = this.variableGradients.get(node.output);

        const oldFirstMoment = this.firstMomentGraph.get(node.output);
        const oldSecondMoment = this.secondMomentGraph.get(node.output);

        const newFirstMoment = math.scaledArrayAdd(
            this.beta1, oldFirstMoment, this.oneMinusBeta1, gradient);
        const newSecondMoment = math.scaledArrayAdd(
            this.beta2, oldSecondMoment, this.oneMinusBeta2, gradient.square());

        const biasCorrectedFirstMoment = newFirstMoment.div(oneMinusAccBeta1);
        const biasCorrectedSecondMoment = newSecondMoment.div(oneMinusAccBeta2);
        const variable = math.scaledArrayAdd(
            this.cGraph,
            biasCorrectedFirstMoment.div(
                this.eps.add(biasCorrectedSecondMoment.sqrt())),
            this.one, oldVariable);
        activationArrayMap.set(node.output, keep(variable));
        node.data = variable;

        this.firstMomentGraph.set(node.output, keep(newFirstMoment));
        this.secondMomentGraph.set(node.output, keep(newSecondMoment));

        oldVariable.dispose();
        gradient.dispose();
        oldFirstMoment.dispose();
        oldSecondMoment.dispose();
      });
      this.accBeta1.assign(this.accBeta1.mul(this.beta1));
      this.accBeta2.assign(this.accBeta2.mul(this.beta2));
    });

    this.variableGradients.dispose();
    this.variableGradients = new TensorArrayMap();
  }

  dispose() {
    super.dispose();
    this.c.dispose();
    this.eps.dispose();
    this.beta1.dispose();
    this.beta2.dispose();
    this.accBeta1.dispose();
    this.accBeta2.dispose();
    this.oneMinusBeta1.dispose();
    this.oneMinusBeta2.dispose();
    this.one.dispose();

    if (this.firstMomentGraph != null) {
      this.firstMomentGraph.dispose();
    }

    if (this.secondMomentGraph != null) {
      this.secondMomentGraph.dispose();
    }

    if (this.accumulatedFirstMoment != null) {
      Object.keys(this.accumulatedFirstMoment)
          .forEach(name => this.accumulatedFirstMoment[name].dispose());
    }

    if (this.accumulatedSecondMoment != null) {
      Object.keys(this.accumulatedSecondMoment)
          .forEach(name => this.accumulatedSecondMoment[name].dispose());
    }
  }

  // Average of gradient
  private firstMomentGraph = new TensorArrayMap();
  // Average of squared gradient
  private secondMomentGraph = new TensorArrayMap();
}
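The loop in applyGradients is the standard Adam recurrence: first_m = beta1 * first_m + (1 - beta1) * grad, second_m = beta2 * second_m + (1 - beta2) * grad^2, bias corrections divide by (1 - accBeta1) and (1 - accBeta2), and the variable moves by -learningRate * m_hat / (eps + sqrt(v_hat)). The following is a standalone arithmetic sketch of the first step (plain TypeScript, no library calls), reproducing the numbers asserted in the test below:

// Same numbers as the 'basic' test: x = [2, 4], loss = sum(x^2),
// lr = 0.1, beta1 = 0.8, beta2 = 0.9, eps = 1e-8.
const lr = 0.1, beta1 = 0.8, beta2 = 0.9, eps = 1e-8;
let x = [2, 4];
let firstMoment = [0, 0];   // accumulatedFirstMoment starts at zerosLike(value)
let secondMoment = [0, 0];  // accumulatedSecondMoment starts at zerosLike(value)
let accBeta1 = beta1;       // accBeta1 starts at beta1
let accBeta2 = beta2;       // accBeta2 starts at beta2

const grad = x.map(v => 2 * v);  // d/dx sum(x^2) = 2x -> [4, 8]

firstMoment = firstMoment.map((m, i) => beta1 * m + (1 - beta1) * grad[i]);        // [0.8, 1.6]
secondMoment = secondMoment.map((v, i) => beta2 * v + (1 - beta2) * grad[i] ** 2);  // [1.6, 6.4]

const mHat = firstMoment.map(m => m / (1 - accBeta1));   // [4, 8]
const vHat = secondMoment.map(v => v / (1 - accBeta2));  // [16, 64]

x = x.map((xi, i) => xi - lr * mHat[i] / (eps + Math.sqrt(vHat[i])));
console.log(x);  // ~[1.9, 3.9], matching expectArraysClose(x, [1.9, 3.9]) in the test

accBeta1 *= beta1;  // running products used for the next step's bias correction
accBeta2 *= beta2;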

src/graph/optimizers/adam_optimizer_test.ts renamed to src/optimizers/adam_optimizer_test.ts

Lines changed: 75 additions & 12 deletions
@@ -14,33 +14,96 @@
  * limitations under the License.
  * =============================================================================
  */
-import {InputProvider} from '../../data/input_provider';
-import {ENV} from '../../environment';
-import * as dl from '../../index';
-import {Tensor1D} from '../../tensor';
-import * as test_util from '../../test_util';
-import {Graph} from '../graph';
-import {Session} from '../session';
+import {InputProvider} from '../data/input_provider';
+import {ENV} from '../environment';
+import {Graph} from '../graph/graph';
+import {Session} from '../graph/session';
+import * as dl from '../index';
+import {ALL_ENVS, describeWithFlags, expectArraysClose} from '../test_util';
 import {AdamOptimizer} from './adam_optimizer';
 
-describe('adam optimizer', () => {
+describeWithFlags('AdamOptimizer', ALL_ENVS, () => {
   it('basic', () => {
+    const learningRate = .1;
+    const beta1 = .8;
+    const beta2 = .9;
+    const optimizer = dl.train.adam(learningRate, beta1, beta2);
+
+    const x = dl.variable(dl.tensor1d([2, 4]));
+
+    const f = () => x.square().sum() as dl.Scalar;
+
+    let numTensors = dl.memory().numTensors;
+
+    let cost = optimizer.minimize(f, /* returnCost */ true);
+
+    // Cost & 2 accumulators should be the only additional arrays.
+    expect(dl.memory().numTensors).toBe(numTensors + 3);
+    // new_first_m = [
+    //    beta1 * old_first_m_w1 + (1-beta1) * grad_w1,
+    //    beta1 * old_first_m_w2 + (1-beta1) * grad_w2
+    // ] = [.8, 1.6]
+    // new_second_m = [
+    //    beta2 * old_second_m_w1 + (1-beta2) * grad_w1**2,
+    //    beta2 * old_second_m_w2 + (1-beta2) * grad_w2**2
+    // ] = [1.6, 6.4]
+    // m = [new_first_m/(1-acc_beta1)] = [4, 8]
+    // v = [new_second_m/(1-acc_beta2)] = [16, 64]
+    // x = [x - lr * m / sqrt(v)] = [1.9, 3.9]
+    //
+    expectArraysClose(x, [1.9, 3.9]);
+
+    cost.dispose();
+    numTensors = dl.memory().numTensors;
+
+    cost = optimizer.minimize(f, /* returnCost */ false);
+
+    // new_first_m = [
+    //    beta1 * old_first_m_w1 + (1-beta1) * grad_w1,
+    //    beta1 * old_first_m_w2 + (1-beta1) * grad_w2
+    // ] = [1.4, 2.84]
+    // new_second_m = [
+    //    beta2 * old_second_m_w1 + (1-beta2) * grad_w1**2,
+    //    beta2 * old_second_m_w2 + (1-beta2) * grad_w2**2
+    // ] = [2.884, 11.884]
+    // m = [new_first_m/(1-acc_beta1)] = [3.888888, 7.88889]
+    // v = [new_second_m/(1-acc_beta2)] = [15.1789, 62.5473]
+    // x = [x - lr * m / sqrt(v)] = [1.8000001, 3.8002]
+    //
+    expectArraysClose(x, [1.8000001, 3.8002]);
+    // There should be no new additional Tensors.
+    expect(dl.memory().numTensors).toBe(numTensors);
+
+    expect(cost).toBe(null);
+
+    x.dispose();
+    optimizer.dispose();
+
+    // There should be no more Tensors.
+    expect(dl.memory().numTensors).toBe(0);
+  });
+
+  it('graph', () => {
     const math = ENV.math;
 
     const inputProvider: InputProvider = {
       getNextCopy() {
-        return Tensor1D.new([2, 4]);
+        return dl.tensor1d([2, 4]);
       },
       disposeCopy(example) {}
     };
 
     dl.tidy(() => {
+      const learningRate = .1;
+      const beta1 = .8;
+      const beta2 = .9;
+
      const g = new Graph();
       const x = g.placeholder('x', [2]);
       const w = g.variable('w', dl.zeros([1, 2]));
       const b = g.variable('b', dl.zeros([1]));
       const y = g.reduceSum(g.add(g.matmul(w, x), b));
-      const optimizer = new AdamOptimizer(0.1, 0.8, 0.9);
+      const optimizer = new AdamOptimizer(learningRate, beta1, beta2);
       const session = new Session(g, math);
       // w = reduce_sum(w_1*x_1 + w_2*x_2 + b)
       // new_first_m = [beta1*old_first_m_w1 + (1-beta1)*grad_w1,
@@ -59,7 +122,7 @@ describe('adam optimizer', () => {
       //
       session.train(y, [{tensor: x, data: inputProvider}], 1, optimizer);
       const dydw = session.activationArrayMap.get(w).dataSync();
-      test_util.expectArraysClose(dydw, new Float32Array([-0.1, -0.1]), 1e-5);
+      expectArraysClose(dydw, new Float32Array([-0.1, -0.1]), 1e-2);
       // new_first_m = [beta1*old_first_m_w1 + (1-beta1)*grad_w1,
       //                beta1*old_first_m_w2 + (1-beta1)*grad_w2]
       //             = [0.8*0.4 + 0.2*2, 0.8*0.8 + 0.2*4]
@@ -77,7 +140,7 @@ describe('adam optimizer', () => {
       //             = [-0.2, -0.2]
       session.train(y, [{tensor: x, data: inputProvider}], 1, optimizer);
       const dydw2 = session.activationArrayMap.get(w).dataSync();
-      test_util.expectArraysClose(dydw2, new Float32Array([-.2, -.2]), 2e-5);
+      expectArraysClose(dydw2, new Float32Array([-.2, -.2]), 1e-2);
     });
   });
 });

src/optimizers/optimizer_constructors.ts

Lines changed: 17 additions & 0 deletions
@@ -18,6 +18,7 @@
 import {doc} from '../doc';
 import {AdadeltaOptimizer} from './adadelta_optimizer';
 import {AdagradOptimizer} from './adagrad_optimizer';
+import {AdamOptimizer} from './adam_optimizer';
 import {MomentumOptimizer} from './momentum_optimizer';
 import {RMSPropOptimizer} from './rmsprop_optimizer';
 import {SGDOptimizer} from './sgd_optimizer';
@@ -71,6 +72,22 @@ export class OptimizerConstructors {
         undefined /** @deprecated specifiedVariableList */, epsilon);
   }
 
+  /**
+   * Constructs a `AdamOptimizer` that uses the Adam algorithm.
+   * See https://arxiv.org/abs/1412.6980
+   *
+   * @param learningRate
+   * @param beta1
+   * @param beta2
+   */
+  @doc({heading: 'Training', subheading: 'Optimizers', namespace: 'train'})
+  static adam(learningRate = 0.001, beta1 = 0.9, beta2 = 0.999, epsilon = 1e-8):
+      AdamOptimizer {
+    return new AdamOptimizer(
+        learningRate, beta1, beta2, epsilon,
+        undefined /** @deprecated specifiedVariableList */);
+  }
+
   /**
   * Constructs a `AdadeltaOptimizer` that uses the Adadelta algorithm.
   * See https://arxiv.org/abs/1212.5701
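When called with no arguments, dl.train.adam falls back to the defaults above (learningRate = 0.001, beta1 = 0.9, beta2 = 0.999, epsilon = 1e-8), which are the values recommended in the Adam paper. A brief sketch; as before, the package import name is an assumption:

import * as dl from 'deeplearn';  // assumed package name

// All hyperparameters defaulted to the paper-recommended values.
const defaultAdam = dl.train.adam();

// Leading arguments can be overridden individually, e.g. only the learning rate:
const fasterAdam = dl.train.adam(0.01);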
