Skip to content

Commit 2f020cb

Browse files
author
Arthur Douillard
committed
[common] Add support for several run seeds.
to squash
1 parent 22d0b26 commit 2f020cb

File tree

3 files changed

+16
-3
lines changed

3 files changed

+16
-3
lines changed

inclearn/__main__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,4 +4,7 @@
44
args = parser.get_parser().parse_args()
55
args = vars(args) # Converting argparse Namespace to a dict.
66

7+
if args["seed_range"] is not None:
8+
args["seed"] = list(range(args["seed_range"][0], args["seed_range"][1] + 1))
9+
710
train(args)

inclearn/parser.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,9 @@ def get_parser():
5656
help="GPU index to use, for cpu use -1.")
5757
parser.add_argument("--name", default="exp",
5858
help="Experience name")
59-
parser.add_argument("-seed", "--seed", default=1, type=int,
59+
parser.add_argument("-seed", "--seed", default=[1], type=int, nargs="+",
6060
help="Random seed.")
61+
parser.add_argument("-seed-range", "--seed-range", type=int, nargs=2,
62+
help="Seed range going from first number to second (both included).")
6163

6264
return parser

inclearn/train.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import copy
12
import random
23

34
import numpy as np
@@ -7,6 +8,13 @@
78

89

910
def train(args):
11+
seed_list = copy.deepcopy(args["seed"])
12+
for seed in seed_list:
13+
args["seed"] = seed
14+
_train(args)
15+
16+
17+
def _train(args):
1018
_set_seed(args["seed"])
1119

1220
factory.set_device(args)
@@ -16,6 +24,7 @@ def train(args):
1624

1725
train_loader, val_loader = train_set.get_loader(args["validation"])
1826
test_loader, _ = test_set.get_loader()
27+
#val_loader = test_loader
1928

2029
model = factory.get_model(args)
2130

@@ -26,7 +35,6 @@ def train(args):
2635
break
2736

2837
# Setting current task's classes:
29-
3038
train_set.set_classes_range(low=task * args["increment"],
3139
high=(task + 1) * args["increment"])
3240
test_set.set_classes_range(high=(task + 1) * args["increment"])
@@ -40,10 +48,10 @@ def train(args):
4048
)
4149

4250
model.before_task(train_loader, val_loader)
51+
print("train", task * args["increment"], (task + 1) * args["increment"])
4352
model.train_task(train_loader, val_loader)
4453
model.after_task(train_loader)
4554

46-
print(test_loader.dataset._low_range, test_loader.dataset._high_range)
4755
ypred, ytrue = model.eval_task(test_loader)
4856
acc_stats = utils.compute_accuracy(ypred, ytrue, task_size=args["increment"])
4957
print(acc_stats)

0 commit comments

Comments
 (0)