Skip to content
This repository was archived by the owner on Aug 15, 2019. It is now read-only.

Commit 8326870

Browse files
author
Nikhil Thorat
authored
Don't dispose input tensor when calling dl.variable(). (#712)
1 parent c25b259 commit 8326870

12 files changed

+61
-32
lines changed

src/optimizers/adadelta_optimizer.ts

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -55,13 +55,17 @@ export class AdadeltaOptimizer extends Optimizer {
5555
const value = ENV.engine.registeredVariables[variableName];
5656
if (this.accumulatedGrads[variableName] == null) {
5757
const trainable = false;
58-
this.accumulatedGrads[variableName] =
59-
variable(zerosLike(value), trainable);
58+
tidy(() => {
59+
this.accumulatedGrads[variableName] =
60+
variable(zerosLike(value), trainable);
61+
});
6062
}
6163
if (this.accumulatedUpdates[variableName] == null) {
6264
const trainable = false;
63-
this.accumulatedUpdates[variableName] =
64-
variable(zerosLike(value), trainable);
65+
tidy(() => {
66+
this.accumulatedUpdates[variableName] =
67+
variable(zerosLike(value), trainable);
68+
});
6569
}
6670

6771
const gradient = variableGradients[variableName];

src/optimizers/adadelta_optimizer_test.ts

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -76,8 +76,8 @@ describeWithFlags('AdadeltaOptimizer', ALL_ENVS, () => {
7676
x.dispose();
7777
optimizer.dispose();
7878

79-
// There should be no more Tensors.
80-
expect(dl.memory().numTensors).toBe(0);
79+
// The only tensor remaining is the argument to variable().
80+
expect(dl.memory().numTensors).toBe(1);
8181
});
8282

8383
it('graph', () => {

src/optimizers/adagrad_optimizer.ts

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,8 +50,10 @@ export class AdagradOptimizer extends Optimizer {
5050
const value = ENV.engine.registeredVariables[variableName];
5151
if (this.accumulatedGrads[variableName] == null) {
5252
const trainable = false;
53-
this.accumulatedGrads[variableName] = variable(
54-
fill(value.shape, this.initialAccumulatorValue), trainable);
53+
tidy(() => {
54+
this.accumulatedGrads[variableName] = variable(
55+
fill(value.shape, this.initialAccumulatorValue), trainable);
56+
});
5557
}
5658

5759
const gradient = variableGradients[variableName];

src/optimizers/adagrad_optimizer_test.ts

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -68,8 +68,8 @@ describeWithFlags('AdagradOptimizer', ALL_ENVS, () => {
6868
x.dispose();
6969
optimizer.dispose();
7070

71-
// There should be no more Tensors.
72-
expect(dl.memory().numTensors).toBe(0);
71+
// The only tensor remaining is the argument to variable().
72+
expect(dl.memory().numTensors).toBe(1);
7373
});
7474

7575
it('graph', () => {

src/optimizers/momentum_optimizer.ts

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,8 +46,10 @@ export class MomentumOptimizer extends SGDOptimizer {
4646
const value = ENV.engine.registeredVariables[variableName];
4747
if (this.accumulations[variableName] == null) {
4848
const trainable = false;
49-
this.accumulations[variableName] =
50-
variable(zerosLike(value), trainable);
49+
tidy(() => {
50+
this.accumulations[variableName] =
51+
variable(zerosLike(value), trainable);
52+
});
5153
}
5254

5355
const accumulation = this.accumulations[variableName];

src/optimizers/momentum_optimizer_test.ts

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -67,8 +67,8 @@ describeWithFlags('MomentumOptimizer', ALL_ENVS, () => {
6767
x.dispose();
6868
optimizer.dispose();
6969

70-
// There should be no more Tensors.
71-
expect(dl.memory().numTensors).toBe(0);
70+
// The only tensor remaining is the argument to variable().
71+
expect(dl.memory().numTensors).toBe(1);
7272
});
7373

7474
it('graph', () => {

src/optimizers/optimizer_test.ts

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -68,8 +68,8 @@ describeWithFlags('optimizer', ALL_ENVS, () => {
6868
x.dispose();
6969
bias.dispose();
7070
strayVariable.dispose();
71-
// There should be no more Tensors.
72-
expect(dl.memory().numTensors).toBe(0);
71+
// The only tensors remaining are the arguments to variable().
72+
expect(dl.memory().numTensors).toBe(3);
7373
});
7474

7575
it('varList array of all variables', () => {

src/optimizers/rmsprop_optimizer.ts

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -58,13 +58,17 @@ export class RMSPropOptimizer extends Optimizer {
5858
const value = ENV.engine.registeredVariables[variableName];
5959
if (this.accumulatedMeanSquares[variableName] == null) {
6060
const trainable = false;
61-
this.accumulatedMeanSquares[variableName] =
62-
variable(zerosLike(value), trainable);
61+
tidy(() => {
62+
this.accumulatedMeanSquares[variableName] =
63+
variable(zerosLike(value), trainable);
64+
});
6365
}
6466
if (this.accumulatedMoments[variableName] == null) {
6567
const trainable = false;
66-
this.accumulatedMoments[variableName] =
67-
variable(zerosLike(value), trainable);
68+
tidy(() => {
69+
this.accumulatedMoments[variableName] =
70+
variable(zerosLike(value), trainable);
71+
});
6872
}
6973

7074
const accumulatedMeanSquare = this.accumulatedMeanSquares[variableName];

src/optimizers/rmsprop_optimizer_test.ts

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,8 @@ describeWithFlags('RMSPropOptimizer', ALL_ENVS, () => {
7878

7979
x.dispose();
8080
optimizer.dispose();
81-
expect(dl.memory().numTensors).toBe(0);
81+
// The only tensor remaining is the argument to variable().
82+
expect(dl.memory().numTensors).toBe(1);
8283
});
8384

8485
it('graph', () => {

src/optimizers/sgd_optimizer_test.ts

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -54,8 +54,8 @@ describeWithFlags('SGDOptimizer', ALL_ENVS, () => {
5454

5555
optimizer.dispose();
5656
x.dispose();
57-
// There should be no more Tensors.
58-
expect(dl.memory().numTensors).toBe(0);
57+
// The only tensor remaining is the argument to variable().
58+
expect(dl.memory().numTensors).toBe(1);
5959
});
6060

6161
it('graph', () => {

0 commit comments

Comments
 (0)