
Commit bc0434b
Author: Arthur Douillard
[e2e] Add first working draft of 'End-to-End Incremental Learning'.
1 parent 0d8e5d6

File tree: 3 files changed, +336 −8 lines

inclearn/factory.py (10 additions, 7 deletions)
@@ -25,8 +25,12 @@ def get_resnet(resnet_type, **kwargs):
 def get_model(args):
     if args["model"] == "icarl":
         return models.ICarl(args)
+    elif args["model"] == "lwf":
+        return models.LwF(args)
+    elif args["model"] == "e2e":
+        return models.End2End(args)
 
-    raise NotImplementedError(arg["model"])
+    raise NotImplementedError(args["model"])
 
 
 def get_data(args, train=True, classes_order=None):
@@ -39,12 +43,11 @@ def get_data(args, train=True, classes_order=None):
     else:
         raise NotImplementedError(dataset_name)
 
-    return dataset(
-        increment=args["increment"],
-        train=train,
-        randomize_class=args["random_classes"],
-        classes_order=classes_order
-    )
+    return dataset(increment=args["increment"],
+                   train=train,
+                   randomize_class=args["random_classes"],
+                   classes_order=classes_order)
+
 
 def set_device(args):
     device_type = args["device"]
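For context, a minimal sketch of how the new "e2e" branch would be exercised through the factory. Only the "model" key appears in this diff; the remaining keys are inferred from the End2End constructor below, and all concrete values are illustrative assumptions, not part of this commit:

from inclearn import factory

# Hypothetical args; keys besides "model" are inferred from End2End.__init__
# below, and the values here are illustrative only.
args = {
    "model": "e2e",
    "device": "cuda",
    "optimizer": "sgd",
    "lr": 0.1,
    "weight_decay": 0.0001,
    "epochs": 70,
    "scheduling": [10, 20, 30],
    "lr_decay": 0.1,
    "memory_size": 2000,
    "increment": 10,
    "temperature": 2,
    "convnet": "resnet18",
}

model = factory.get_model(args)  # dispatches to models.End2End(args)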

inclearn/models/__init__.py (2 additions, 1 deletion)
@@ -1,4 +1,5 @@
+from .e2e import End2End
 from .icarl import ICarl
 from .lwf import LwF
 
-__all__ = ["ICarl", "LwF"]
+__all__ = ["ICarl", "LwF", "End2End"]

inclearn/models/e2e.py (new file, 324 additions)
import numpy as np
import torch
from torch import nn
from torch.nn import functional as F
from tqdm import trange

from inclearn import factory, utils
from inclearn.models.base import IncrementalLearner


class End2End(IncrementalLearner):
    """Implementation of End-to-End Incremental Learning.

    :param args: An argparse parsed arguments object.
    """

    def __init__(self, args):
        super().__init__()

        self._device = args["device"]
        self._opt_name = args["optimizer"]
        self._lr = args["lr"]
        self._weight_decay = args["weight_decay"]
        self._n_epochs = args["epochs"]

        self._scheduling = args["scheduling"]
        self._lr_decay = args["lr_decay"]

        # Total examplar memory budget and initial number of classes.
        self._k = args["memory_size"]
        self._n_classes = args["increment"]

        self._temperature = args["temperature"]

        self._features_extractor = factory.get_resnet(
            args["convnet"], nf=64, zero_init_residual=True
        )
        self._classifier = nn.Linear(self._features_extractor.out_dim, self._n_classes, bias=False)
        torch.nn.init.kaiming_normal_(self._classifier.weight)

        self._examplars = {}
        self._means = None

        self.to(self._device)

    def forward(self, x):
        x = self._features_extractor(x)
        x = self._classifier(x)
        return x
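
    # Note: attributes referenced below but never set in this file, such as
    # self._task, self._task_size, self._new_task_index and self._n_train_data,
    # presumably come from the IncrementalLearner base class, which is not
    # part of this commit.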

    # ----------
    # Public API
    # ----------

    def _before_task(self, train_loader, val_loader):
        """Set up before the task training can begin.

        1. Precomputes the previous model's predictions (logits).
        2. Extends the classifier to support the new classes.

        :param train_loader: The training dataloader.
        :param val_loader: The validation dataloader.
        """
        if self._task == 0:
            self._previous_preds = None
        else:
            print("Computing previous predictions...")
            self._previous_preds = self._compute_predictions(train_loader)
            if val_loader:
                self._previous_preds_val = self._compute_predictions(val_loader)

        self._add_n_classes(self._task_size)

    def _train_task(self, train_loader, val_loader):
        # Phase 1: train on all new data plus the stored examplars.
        optimizer = factory.get_optimizer(self.parameters(), self._opt_name, 0.1, 0.0001)
        scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, [10, 20, 30], gamma=0.1)
        self._train(train_loader, 1, optimizer, scheduler)

        if self._task == 0:
            return

        # Phase 2: fine-tune on a balanced subset of new + old examplars.
        self._build_examplars(train_loader)
        train_loader.dataset.set_idxes(self.examplars)  # Fine-tune only on the balanced dataset.
        optimizer = factory.get_optimizer(self.parameters(), self._opt_name, 0.01, 0.0001)
        scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, [10, 20], gamma=0.1)
        self._train(train_loader, 1, optimizer, scheduler)
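    # The two-phase schedule above follows the End-to-End Incremental Learning
    # paper (Castro et al., ECCV 2018): first train on all new data plus the
    # stored examplars, then fine-tune on a class-balanced subset so old and
    # new classes contribute equally. Passing n_epochs=1 to _train while the
    # MultiStepLR milestones sit at [10, 20, 30] suggests the epoch counts are
    # still draft placeholders in this commit.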

    def _after_task(self, data_loader):
        # Shrink the per-class memory, then rebuild examplars for the new classes.
        self._reduce_examplars()
        self._build_examplars(data_loader)

    def _eval_task(self, data_loader):
        ypred, ytrue = self._classify(data_loader)
        assert ypred.shape == ytrue.shape

        return ypred, ytrue

    def get_memory_indexes(self):
        return self.examplars

    # -----------
    # Private API
    # -----------

    def _train(self, train_loader, n_epochs, optimizer, scheduler):
        # Debug: current dataset size.
        print("nb ", len(train_loader.dataset))

        prog_bar = trange(n_epochs, desc="Losses.")

        for epoch in prog_bar:
            _clf_loss, _distil_loss = 0., 0.
            c = 0

            # Step the LR schedule once per epoch.
            scheduler.step()

            for i, ((_, idxes), inputs, targets) in enumerate(train_loader, start=1):
                optimizer.zero_grad()

                c += len(idxes)
                inputs, targets = inputs.to(self._device), targets.to(self._device)
                logits = self.forward(inputs)

                clf_loss, distil_loss = self._compute_loss(logits, targets, idxes)

                if not utils._check_loss(clf_loss) or not utils._check_loss(distil_loss):
                    # Drop into the debugger when a loss goes NaN or infinite.
                    import pdb
                    pdb.set_trace()

                loss = clf_loss + distil_loss

                loss.backward()
                optimizer.step()

                _clf_loss += clf_loss.item()
                _distil_loss += distil_loss.item()

                if i % 10 == 0 or i >= len(train_loader):
                    prog_bar.set_description(
                        "Clf loss: {}; Distill loss: {}".format(
                            round(clf_loss.item(), 3), round(distil_loss.item(), 3)
                        )
                    )

            prog_bar.set_description(
                "Clf loss: {}; Distill loss: {}".format(
                    round(_clf_loss / c, 3), round(_distil_loss / c, 3)
                )
            )

    def _compute_loss(self, logits, targets, idxes):
        """Computes the classification loss & the distillation loss.

        The distillation loss is zero for the first task.

        :param logits: Logits produced by the model.
        :param targets: The targets.
        :param idxes: The real indexes of the just-processed images. Needed to
                      match the previous predictions.
        :return: A tuple of the classification loss and the distillation loss.
        """
        if self._task == 0:
            clf_loss = F.cross_entropy(logits, targets)
            distil_loss = torch.zeros(1, device=self._device)
        else:
            # Disable the cross-entropy loss for the old targets:
            for i in range(self._new_task_index):
                targets[targets == i] = -1
            clf_loss = F.cross_entropy(logits, targets, ignore_index=-1)

            # Temperature-scaled distillation on the old classes' logits.
            # NOTE: temperature scaling is usually softmax(logits / T); raising
            # logits to the power 1 / T, as done here, yields NaNs for negative
            # logits, which is what the _check_loss guard in _train catches.
            distil_loss = F.binary_cross_entropy(
                F.softmax(logits[..., :self._new_task_index] ** (1 / self._temperature), dim=1),
                F.softmax(self._previous_preds[idxes] ** (1 / self._temperature), dim=1)
            )

        return clf_loss, distil_loss
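    # For reference, the distillation term in the paper is a cross-entropy
    # between temperature-softened distributions over the old classes,
    #
    #     L_dist = - sum_i q_i * log(p_i)
    #     q_i = exp(o_i / T) / sum_j exp(o_j / T)    (previous model's logits)
    #     p_i = exp(o'_i / T) / sum_j exp(o'_j / T)  (current model's logits)
    #
    # typically with T = 2; the binary cross-entropy over powered logits above
    # is this draft's approximation of that term.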

    def _compute_predictions(self, loader):
        """Precomputes the logits before a task.

        :param loader: A DataLoader.
        :return: A tensor storing the whole current dataset's logits.
        """
        logits = torch.zeros(self._n_train_data, self._n_classes, device=self._device)

        for idxes, inputs, _ in loader:
            inputs = inputs.to(self._device)
            idxes = idxes[1].to(self._device)

            logits[idxes] = self.forward(inputs).detach()

        return logits

    def _classify(self, loader):
        """Classifies the images given by the data loader.

        :param loader: A DataLoader.
        :return: A numpy array of the predicted targets and a numpy array of
                 the ground-truth targets.
        """
        ypred = []
        ytrue = []

        for _, inputs, targets in loader:
            inputs = inputs.to(self._device)
            logits = self.forward(inputs)
            # argmax over the softmax equals argmax over the raw logits.
            preds = F.softmax(logits, dim=1).argmax(dim=1)

            # .cpu() so the predictions can be converted to numpy below.
            ypred.extend(preds.cpu())
            ytrue.extend(targets)

        return np.array(ypred), np.array(ytrue)

    @property
    def _m(self):
        """Returns the number of examplars per class."""
        return self._k // self._n_classes

    def _add_n_classes(self, n):
        # Grow the linear classifier, keeping the old classes' weights.
        self._n_classes += n

        weights = self._classifier.weight.data
        self._classifier = nn.Linear(self._features_extractor.out_dim, self._n_classes,
                                     bias=False).to(self._device)
        torch.nn.init.kaiming_normal_(self._classifier.weight)

        self._classifier.weight.data[:self._n_classes - n] = weights

        print("Now {} examplars per class.".format(self._m))

    def _extract_features(self, loader):
        """Extracts the features of a whole loader, also returning their mean
        and the real dataset indexes."""
        features = []
        idxes = []

        for (real_idxes, _), inputs, _ in loader:
            inputs = inputs.to(self._device)
            features.append(self._features_extractor(inputs).detach())
            idxes.extend(real_idxes.numpy().tolist())

        features = torch.cat(features)
        mean = torch.mean(features, dim=0, keepdim=False)

        return features, mean, idxes

    @staticmethod
    def _get_closest(centers, features):
        """Returns the index of the center closest to each feature.

        :param centers: Centers to compare, in this case the class means.
        :param features: A tensor of features extracted by the convnet.
        :return: A numpy array of the closest centers' indexes.
        """
        pred_labels = []

        for feature in features:
            distances = End2End._dist(centers, feature)
            pred_labels.append(distances.argmin().item())

        return np.array(pred_labels)

    @staticmethod
    def _dist(a, b):
        """Computes the squared L2 distance between two tensors.

        :param a: A tensor.
        :param b: A tensor.
        :return: A tensor of distances, broadcast to the shape of the larger
                 input tensor.
        """
        return torch.pow(a - b, 2).sum(-1)

    def _build_examplars(self, loader):
        """Builds new examplars.

        :param loader: A DataLoader.
        """
        lo, hi = self._task * self._task_size, self._n_classes
        print("Building examplars for classes {} -> {}.".format(lo, hi))
        for class_idx in range(lo, hi):
            loader.dataset.set_classes_range(class_idx, class_idx)
            self._examplars[class_idx] = self._build_class_examplars(loader)

    def _build_class_examplars(self, loader):
        """Builds examplars for a single class.

        Examplars are selected as the closest to the class mean.

        :param loader: A DataLoader providing images of a single class.
        :return: The real indexes of the chosen examplars.
        """
        features, class_mean, idxes = self._extract_features(loader)

        # NOTE: the mean is L2-normalized but the features are not, so the
        # ranking mixes scales; normalizing the features too may be intended.
        class_mean = F.normalize(class_mean, dim=0)
        distances_to_mean = self._dist(class_mean, features)

        nb_examplars = min(self._m, len(features))

        fake_idxes = distances_to_mean.argsort().cpu().numpy()[:nb_examplars]
        return [idxes[idx] for idx in fake_idxes]
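    # Selection here simply ranks samples by distance to the class mean; for
    # comparison, the iCaRL paper selects examplars by herding, greedily
    # picking samples so that the running mean of the chosen examplars tracks
    # the true class mean.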

    @property
    def examplars(self):
        """Returns all the real examplars' indexes.

        :return: A numpy array of indexes.
        """
        return np.array(
            [
                examplar_idx for class_examplars in self._examplars.values()
                for examplar_idx in class_examplars
            ]
        )

    def _reduce_examplars(self):
        # Keep only the first _m examplars per class as the budget shrinks.
        print("Reducing examplars.")
        for class_idx in range(len(self._examplars)):
            self._examplars[class_idx] = self._examplars[class_idx][:self._m]
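
For orientation, a sketch of the task loop this class is meant to plug into. The hook names (_before_task, _train_task, _after_task, _eval_task) come from the file above; the IncrementalLearner base class that actually drives them is not part of this commit, so the explicit hook calls, n_tasks, args, and the get_task_loaders helper below are assumptions for illustration only:

# Hypothetical driver; in the code base, the IncrementalLearner base class
# presumably orchestrates these hooks and sets attributes such as _task,
# _task_size, _new_task_index and _n_train_data between tasks.
model = End2End(args)

for task in range(n_tasks):
    train_loader, val_loader, test_loader = get_task_loaders(task)  # hypothetical helper

    model._before_task(train_loader, val_loader)  # cache old logits, grow the classifier
    model._train_task(train_loader, val_loader)   # full training, then balanced fine-tuning
    model._after_task(train_loader)               # shrink and rebuild the examplar memory

    ypred, ytrue = model._eval_task(test_loader)  # softmax-argmax predictions
    print("Task {}: accuracy {:.3f}".format(task, (ypred == ytrue).mean()))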
